Beispiel #1
0
def prgq(volft, kb, nx, delta, ref_a, sym, MPI=False):
    """
    Generate reference projections of a volume at evenly spaced Euler angles.

    The angles come from ``even_angles(delta, symmetry=sym, method=ref_a,
    phiEqpsi="Minus")``.  Each projection is produced by ``prgs`` with zero
    shifts and tagged with its "phi", "theta" and "psi" attributes.

    Parameters
    ----------
    volft, kb : passed straight to ``prgs`` (presumably the output of
        ``projection.prep_vol`` -- TODO confirm with callers).
    nx : int
        Size of the nx x nx blank placeholder images used on ranks that
        receive a projection by broadcast.
    delta, ref_a, sym : angular step, angle-generation method and symmetry
        forwarded to ``even_angles``.
    MPI : bool
        When True, projections are split across the ranks of MPI_COMM_WORLD
        with ``MPI_start_end`` and then broadcast so every rank ends up with
        the full list.

    Returns
    -------
    list
        One projection image per reference angle, in ``even_angles`` order.
        NOTE(review): an earlier docstring called these "ffts of
        projections"; confirm the representation against ``prgs``.
    """
    from projection import prgs
    from applications import MPI_start_end
    from utilities import even_angles, model_blank

    ref_angles = even_angles(delta, symmetry=sym, method=ref_a, phiEqpsi="Minus")
    num_ref = len(ref_angles)

    if MPI:
        from mpi import mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        ncpu = mpi_comm_size(MPI_COMM_WORLD)
    else:
        ncpu = 1
        myid = 0
    ref_start, ref_end = MPI_start_end(num_ref, ncpu, myid)

    # Compute this rank's share; every other slot gets an nx x nx blank image
    # because bcast_EMData_to_all (below) is handed an existing image object
    # on every rank -- presumably it fills the receivers' copy in place.
    prjref = []
    for i in xrange(num_ref):
        if ref_start <= i < ref_end:
            angles = ref_angles[i]
            prjref.append(prgs(volft, kb, [angles[0], angles[1], angles[2], 0.0, 0.0]))
        else:
            prjref.append(model_blank(nx, nx))

    if MPI:
        from utilities import bcast_EMData_to_all
        # MPI_start_end hands out contiguous, rank-ordered ranges, so walking
        # the ranks in order visits the projections in index order (the same
        # collective sequence on every rank) without re-scanning all ranks for
        # every projection.
        for rootid in xrange(ncpu):
            owner_start, owner_end = MPI_start_end(num_ref, ncpu, rootid)
            for i in xrange(owner_start, owner_end):
                bcast_EMData_to_all(prjref[i], myid, rootid)

    for proj, angles in zip(prjref, ref_angles):
        proj.set_attr_dict({"phi": angles[0], "theta": angles[1], "psi": angles[2]})

    return prjref
Beispiel #2
0
def calculate_list_of_independent_viper_run_indices_used_for_outlier_elimination(no_of_viper_runs_analyzed_together,
        no_of_viper_runs_analyzed_together_from_user_options, masterdir, rviper_iter, criterion_name):
    """
    Choose which earlier VIPER runs to analyze together with the newest run.

    Let n = no_of_viper_runs_analyzed_together and
    k = no_of_viper_runs_analyzed_together_from_user_options.  Every
    combination of (k - 1) runs taken from runs 0 .. n-2 is extended with the
    newest run (index n - 1) and scored with ``measure_for_outlier_criterion``.
    The combinations are dealt round-robin over the ranks of MPI_COMM_WORLD,
    and the per-rank scores are summed onto rank 0 with ``mpi_reduce``.

    On rank 0 the combination with the smallest score is selected and a
    marker is prepended:
      * MUST_END_PROGRAM_THIS_ITERATION if that score equals
        TRIPLET_WITH_ANGLE_ERROR_LESS_THAN_THRESHOLD_HAS_BEEN_FOUND;
      * otherwise DUMMY_INDEX_USED_AS_BUFFER, and if ``criterion_name`` does
        not pass its PERCENT_THRESHOLD_Y test the whole result is replaced by
        [EMPTY_VIPER_RUN_INDICES_LIST].
    The result without its first element is written as JSON to
    "<mainoutputdir>list_of_viper_runs_included_in_outlier_elimination.json".

    Every rank must call this function (it performs mpi_reduce and
    mpi_barrier).  Rank 0 returns the list described above; all other ranks
    return [EMPTY_VIPER_RUN_INDICES_LIST].
    """
    import itertools
    import json

    from utilities import combinations_of_n_taken_by_k

    newest_run_index = no_of_viper_runs_analyzed_together - 1
    runs_taken_per_combination = no_of_viper_runs_analyzed_together_from_user_options - 1

    # generate all possible combinations of (no_of_viper_runs_analyzed_together - 1) taken
    # (no_of_viper_runs_analyzed_together_from_user_options - 1) at a time
    number_of_additional_combinations_for_this_viper_iteration = combinations_of_n_taken_by_k(
        newest_run_index, runs_taken_per_combination)

    criterion_measure = [0.0] * number_of_additional_combinations_for_this_viper_iteration
    all_n_minus_1_combinations_taken_k_minus_1_at_a_time = list(
        itertools.combinations(range(newest_run_index), runs_taken_per_combination))

    no_of_processors = mpi_comm_size(MPI_COMM_WORLD)
    my_rank = mpi_comm_rank(MPI_COMM_WORLD)

    for idx, tuple_of_projection_indices in enumerate(all_n_minus_1_combinations_taken_k_minus_1_at_a_time):
        if my_rank == idx % no_of_processors:
            list_of_viper_run_indices = list(tuple_of_projection_indices) + [newest_run_index]
            criterion_measure[idx] = measure_for_outlier_criterion(criterion_name, masterdir, rviper_iter,
                                                                   list_of_viper_run_indices)
            plot_errors_between_any_number_of_projections(masterdir, rviper_iter, list_of_viper_run_indices,
                                                          criterion_measure[idx])

    # Entries not computed by a rank stay 0.0 there, so the sum on rank 0
    # holds every combination's score.
    criterion_measure = mpi_reduce(criterion_measure, number_of_additional_combinations_for_this_viper_iteration,
                                   MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD)

    result = [EMPTY_VIPER_RUN_INDICES_LIST]

    if my_rank == 0:
        # Indices of the combinations ordered by ascending score (stable sort).
        index_of_sorted_criterion_measure_list = [
            i for i, _ in sorted(enumerate(criterion_measure), key=lambda pair: pair[1])]
        best_index = index_of_sorted_criterion_measure_list[0]

        selected = list(all_n_minus_1_combinations_taken_k_minus_1_at_a_time[best_index]) + [newest_run_index]

        mainoutputdir = masterdir + DIR_DELIM + NAME_OF_MAIN_DIR + ("%03d" + DIR_DELIM) % (rviper_iter)

        if criterion_measure[best_index] == TRIPLET_WITH_ANGLE_ERROR_LESS_THAN_THRESHOLD_HAS_BEEN_FOUND:
            selected.insert(0, MUST_END_PROGRAM_THIS_ITERATION)
        else:
            selected.insert(0, DUMMY_INDEX_USED_AS_BUFFER)
            if criterion_name == "80th percentile":
                pass_criterion = criterion_measure[best_index] < PERCENT_THRESHOLD_Y
            elif criterion_name == "fastest increase in the last quartile":
                # NOTE(review): this tests the LARGEST score while `selected`
                # is built from the smallest one -- confirm this is intended.
                pass_criterion = criterion_measure[index_of_sorted_criterion_measure_list[-1]] > PERCENT_THRESHOLD_Y
            else:
                pass_criterion = False

            if not pass_criterion:
                selected = [EMPTY_VIPER_RUN_INDICES_LIST]

        # The leading marker is dropped from the file; the context manager
        # closes the file even if json.dump raises.
        with open(mainoutputdir + "list_of_viper_runs_included_in_outlier_elimination.json", 'w') as f:
            json.dump(selected[1:], f)

        result = selected

    # Single synchronization point reached by every rank (rank 0 only after
    # the JSON file has been written).
    mpi_barrier(MPI_COMM_WORLD)

    return result
Beispiel #3
0
def main():
	"""
	Command-line entry point: parse the options and call ``resample``.

	Requires exactly three positional arguments -- prjstack, outdir and
	bufprefix -- plus the options declared below.  With --MPI, MPI is
	initialised here and finalised after ``resample`` returns.  Prints the
	usage line and returns None if the positional arguments are wrong.
	"""
	import sys

	arglist = list(sys.argv)

	progname = os.path.basename(arglist[0])
	usage = progname + " prjstack outdir bufprefix --delta --d --nvol --nbufvol --seedbase --snr --npad --CTF --MPI --verbose"
	parser = OptionParser(usage, version=SPARXVERSION)
	parser.add_option("--nvol",     type="int",                         help="number of resample volumes to be generated")
	parser.add_option("--nbufvol",  type="int",          default=1,     help="number of fftvols in the memory")
	parser.add_option("--delta",    type="float",        default=10.0,  help="angular step for cones")
	parser.add_option("--d",        type="float",        default=0.1,   help="fraction of projections to leave out")
	parser.add_option("--CTF",      action="store_true", default=False, help="use CTF")
	parser.add_option("--snr",      type="float",        default=1.0,   help="Signal-to-Noise Ratio")
	parser.add_option("--npad",     type="int",          default=2,     help="times of padding")
	parser.add_option("--seedbase", type="int",          default=-1,    help="random seed base")
	parser.add_option("--MPI",      action="store_true", default=False, help="use MPI")
	parser.add_option("--verbose",  type="int",          default=0,     help="verbose level: 0 no, 1 yes")

	(options, args) = parser.parse_args(arglist[1:])

	# All three positional arguments are used below; the previous check also
	# accepted a single argument, which then crashed with IndexError.
	if len(args) != 3:
		print("usage: " + usage)
		return None

	prjfile, outdir, bufprefix = args

	if options.MPI:
		from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
		sys.argv = mpi_init(len(sys.argv), sys.argv)
		myid = mpi_comm_rank(MPI_COMM_WORLD)
		ncpu = mpi_comm_size(MPI_COMM_WORLD)
	else:
		myid = 0
		ncpu = 1

	if global_def.CACHE_DISABLE:
		from utilities import disable_bdb_cache
		disable_bdb_cache()

	resample(prjfile, outdir, bufprefix, options.nbufvol, options.nvol, options.seedbase,
	         options.delta, options.d, options.snr, options.CTF, options.npad,
	         options.MPI, myid, ncpu, options.verbose)

	if options.MPI:
		from mpi import mpi_finalize
		mpi_finalize()
Beispiel #4
0
def main():
    """
    Command-line entry point: parse the options and call ``resample``.

    Requires exactly three positional arguments -- prjstack, outdir and
    bufprefix -- plus the options declared below.  With --MPI, MPI is
    initialised here and finalised after ``resample`` returns.  Prints the
    usage line and returns None if the positional arguments are wrong.
    """
    import sys

    arglist = list(sys.argv)

    progname = os.path.basename(arglist[0])
    usage = progname + " prjstack outdir bufprefix --delta --d --nvol --nbufvol --seedbase --snr --npad --CTF --MPI --verbose"
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option("--nvol",     type="int",                         help="number of resample volumes to be generated")
    parser.add_option("--nbufvol",  type="int",          default=1,     help="number of fftvols in the memory")
    parser.add_option("--delta",    type="float",        default=10.0,  help="angular step for cones")
    parser.add_option("--d",        type="float",        default=0.1,   help="fraction of projections to leave out")
    parser.add_option("--CTF",      action="store_true", default=False, help="use CTF")
    parser.add_option("--snr",      type="float",        default=1.0,   help="Signal-to-Noise Ratio")
    parser.add_option("--npad",     type="int",          default=2,     help="times of padding")
    parser.add_option("--seedbase", type="int",          default=-1,    help="random seed base")
    parser.add_option("--MPI",      action="store_true", default=False, help="use MPI")
    parser.add_option("--verbose",  type="int",          default=0,     help="verbose level: 0 no, 1 yes")

    (options, args) = parser.parse_args(arglist[1:])

    # All three positional arguments are used below; the previous check also
    # accepted a single argument, which then crashed with IndexError.
    if len(args) != 3:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("usage: " + usage)
        return None

    prjfile, outdir, bufprefix = args

    if options.MPI:
        from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
        sys.argv = mpi_init(len(sys.argv), sys.argv)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        ncpu = mpi_comm_size(MPI_COMM_WORLD)
    else:
        myid = 0
        ncpu = 1

    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    resample(prjfile, outdir, bufprefix, options.nbufvol, options.nvol, options.seedbase,
             options.delta, options.d, options.snr, options.CTF, options.npad,
             options.MPI, myid, ncpu, options.verbose)

    if options.MPI:
        from mpi import mpi_finalize
        mpi_finalize()
Beispiel #5
0
def prgq(volft, kb, nx, delta, ref_a, sym, MPI=False):
    """
    Generate reference projections of a volume at evenly spaced Euler angles.

    The angles come from ``even_angles(delta, symmetry=sym, method=ref_a,
    phiEqpsi="Minus")``.  Each projection is produced by ``prgs`` with zero
    shifts and tagged with its "phi", "theta" and "psi" attributes.

    Parameters
    ----------
    volft, kb : passed straight to ``prgs`` (presumably the output of
        ``projection.prep_vol`` -- TODO confirm with callers).
    nx : int
        Size of the nx x nx blank placeholder images used on ranks that
        receive a projection by broadcast.
    delta, ref_a, sym : angular step, angle-generation method and symmetry
        forwarded to ``even_angles``.
    MPI : bool
        When True, projections are split across the ranks of MPI_COMM_WORLD
        with ``MPI_start_end`` and then broadcast so every rank ends up with
        the full list.

    Returns
    -------
    list
        One projection image per reference angle, in ``even_angles`` order.
        NOTE(review): an earlier docstring called these "ffts of
        projections"; confirm the representation against ``prgs``.
    """
    from projection import prgs
    from applications import MPI_start_end
    from utilities import even_angles, model_blank

    ref_angles = even_angles(delta, symmetry=sym, method=ref_a, phiEqpsi="Minus")
    num_ref = len(ref_angles)

    if MPI:
        from mpi import mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        ncpu = mpi_comm_size(MPI_COMM_WORLD)
    else:
        ncpu = 1
        myid = 0
    ref_start, ref_end = MPI_start_end(num_ref, ncpu, myid)

    # Compute this rank's share; every other slot gets an nx x nx blank image
    # because bcast_EMData_to_all (below) is handed an existing image object
    # on every rank -- presumably it fills the receivers' copy in place.
    prjref = []
    for i in xrange(num_ref):
        if ref_start <= i < ref_end:
            angles = ref_angles[i]
            prjref.append(prgs(volft, kb, [angles[0], angles[1], angles[2], 0.0, 0.0]))
        else:
            prjref.append(model_blank(nx, nx))

    if MPI:
        from utilities import bcast_EMData_to_all
        # MPI_start_end hands out contiguous, rank-ordered ranges, so walking
        # the ranks in order visits the projections in index order (the same
        # collective sequence on every rank) without re-scanning all ranks for
        # every projection.
        for rootid in xrange(ncpu):
            owner_start, owner_end = MPI_start_end(num_ref, ncpu, rootid)
            for i in xrange(owner_start, owner_end):
                bcast_EMData_to_all(prjref[i], myid, rootid)

    for proj, angles in zip(prjref, ref_angles):
        proj.set_attr_dict({"phi": angles[0], "theta": angles[1], "psi": angles[2]})

    return prjref
Beispiel #6
0
def wrap_mpi_split(mpi_comm, number_of_subcomm):
    """
    Split ``mpi_comm`` into ``number_of_subcomm`` sub-communicators.

    Ranks are dealt round-robin: main rank r joins sub-communicator
    ``r % number_of_subcomm`` and is ordered inside it by
    ``r // number_of_subcomm``.  Consequently main rank j (for
    j < number_of_subcomm) is rank 0 of sub-communicator j, which is why
    ``subcomms_roots`` is simply ``[0, 1, ..., number_of_subcomm - 1]``.

    Returns an ``mpi_env_type`` instance with main_comm, main_rank,
    subcomm_id, sub_rank, sub_comm, subcomms_count and subcomms_roots set.

    Raises RuntimeError when more sub-communicators are requested than the
    communicator has ranks.
    """
    from mpi import mpi_comm_rank, mpi_comm_size, mpi_comm_split
    from air import mpi_env_type

    main_size = mpi_comm_size(mpi_comm)
    if number_of_subcomm > main_size:
        raise RuntimeError("number_of_subcomm > main_size")

    me = mpi_env_type()
    me.main_comm = mpi_comm
    me.main_rank = mpi_comm_rank(mpi_comm)
    me.subcomm_id = me.main_rank % number_of_subcomm
    # Floor division keeps the split key an int on Python 3; for ints it is
    # exactly what Python 2's `/` already did.
    me.sub_rank = me.main_rank // number_of_subcomm
    me.sub_comm = mpi_comm_split(mpi_comm, me.subcomm_id, me.sub_rank)
    me.subcomms_count = number_of_subcomm
    # A real list on both Python 2 and 3 (range() is lazy on Python 3).
    me.subcomms_roots = list(range(number_of_subcomm))

    return me
Beispiel #7
0
def main():
    """
    Compute local resolution in real space from two volumes.

    Positional arguments: firstvolume secondvolume [maskfile] outputfile.
    With three arguments no mask file is read and a sphere is used instead:
    its radius is --radius if positive, otherwise (size - wn) // 2.  With four
    arguments the third one is a mask, binarized at 0.5.

    With --MPI the work is delegated to ``statistics.locres``.  Otherwise,
    for every Fourier shell [fl, fh) both volumes are band-pass filtered
    (``filt_tophatb``), a local correlation is formed over wn x wn x wn boxes
    (``Util.box_convolution``), and every masked voxel is assigned the shell's
    mid frequency the first time its local correlation drops below --cutoff.
    The loop stops early once no masked voxel is left unassigned.

    The per-voxel frequency volume is written to outputfile and, when --fsc is
    given, the overall FSC curve rows are written there with write_text_row.
    """
    import os
    import sys
    from optparse import OptionParser

    arglist = list(sys.argv)
    progname = os.path.basename(arglist[0])
    usage = (progname + " firstvolume  secondvolume maskfile outputfile --wn --step --cutoff  --radius  --fsc --MPI\n"
             "\n"
             "\tCompute local resolution in real space within area outlined by the maskfile and within regions wn x wn x wn\n"
             "\t")
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--wn",     type="int",    default=7,    help="Size of window within which local real-space FSC is computed (default 7)")
    parser.add_option("--step",   type="float",  default=1.0,  help="Shell step in Fourier size in pixels (default 1.0)")
    parser.add_option("--cutoff", type="float",  default=0.5,  help="resolution cut-off for FSC (default 0.5)")
    parser.add_option("--radius", type="int",    default=-1,   help="if there is no maskfile, sphere with r=radius will be used, by default the radius is (nx-wn)//2")
    parser.add_option("--fsc",    type="string", default=None, help="overall FSC curve (might be truncated) (default no curve)")
    parser.add_option("--MPI",    action="store_true", default=False, help="use MPI version")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        print("See usage " + usage)
        sys.exit()

    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    cutoff = options.cutoff
    nk = int(options.wn)

    if options.MPI:
        from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
        sys.argv = mpi_init(len(sys.argv), sys.argv)

        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        main_node = 0

        if myid == main_node:
            # Use the parsed positional arguments: sys.argv[1]/[2] are wrong
            # whenever an option (e.g. --MPI) precedes the volume names.
            vi = get_im(args[0])
            ui = get_im(args[1])
            dis = [vi.get_xsize(), vi.get_ysize(), vi.get_zsize()]
        else:
            # Same length as the root's list so all ranks agree on the size.
            dis = [0, 0, 0]

        dis = bcast_list_to_all(dis, myid, source_node=main_node)
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])

        if myid != main_node:
            vi = model_blank(nx, ny, nz)
            ui = model_blank(nx, ny, nz)

        if len(args) == 3:
            radius = options.radius if options.radius > 0 else (min(nx, ny, nz) - nk) // 2
            m = model_circle(radius, nx, ny, nz)
            outvol = args[2]
        else:
            if myid == main_node:
                m = binarize(get_im(args[2]), 0.5)
            else:
                m = model_blank(nx, ny, nz)
            outvol = args[3]
        bcast_EMData_to_all(m, myid, main_node)

        from statistics import locres
        freqvol, resolut = locres(vi, ui, m, nk, cutoff, options.step, myid, main_node, number_of_proc)
        if myid == main_node:
            freqvol.write_image(outvol)
            if options.fsc is not None:
                write_text_row(resolut, options.fsc)

        from mpi import mpi_finalize
        mpi_finalize()

    else:
        vi = get_im(args[0])
        ui = get_im(args[1])

        # Only the x size is read; the volumes are treated as nn x nn x nn.
        nn = vi.get_xsize()

        if len(args) == 3:
            radius = options.radius if options.radius > 0 else (nn - nk) // 2
            m = model_circle(radius, nn, nn, nn)
            outvol = args[2]
        else:
            m = binarize(get_im(args[2]), 0.5)
            outvol = args[3]

        # Complement of the mask: used to keep the division below finite
        # outside the masked region.
        mc = model_blank(nn, nn, nn, 1.0) - m

        vf = fft(vi)
        uf = fft(ui)

        # `//` reproduces Python 2's integer `nn/2` on Python 3 as well.
        lp = int(nn // 2 / options.step + 0.5)
        step = 0.5 / lp

        freqvol = model_blank(nn, nn, nn)
        resolut = []
        for i in xrange(1, lp):
            fl = step * i
            fh = fl + step
            # Same text the former `print lp,i,step,fl,fh` statement emitted.
            sys.stdout.write("%s %s %s %s %s\n" % (lp, i, step, fl, fh))
            v = fft(filt_tophatb(vf, fl, fh))
            u = fft(filt_tophatb(uf, fl, fh))
            tmp1 = Util.muln_img(v, v)
            tmp2 = Util.muln_img(u, u)

            # Global (within-mask) correlation for this shell -> FSC curve row.
            do = Util.infomask(square_root(Util.muln_img(tmp1, tmp2)), m, True)[0]
            tmp3 = Util.muln_img(u, v)
            dp = Util.infomask(tmp3, m, True)[0]
            resolut.append([i, (fl + fh) / 2.0, dp / do])

            # Local correlation over nk^3 boxes.
            tmp1 = Util.box_convolution(tmp1, nk)
            tmp2 = Util.box_convolution(tmp2, nk)
            tmp3 = Util.box_convolution(tmp3, nk)

            Util.mul_img(tmp1, tmp2)
            tmp1 = square_root(tmp1)

            Util.mul_img(tmp1, m)
            Util.add_img(tmp1, mc)

            Util.mul_img(tmp3, m)
            Util.add_img(tmp3, mc)

            Util.div_img(tmp3, tmp1)
            Util.mul_img(tmp3, m)

            freq = (fl + fh) / 2.0
            bailout = True
            for x in xrange(nn):
                for y in xrange(nn):
                    for z in xrange(nn):
                        if m.get_value_at(x, y, z) > 0.5 and freqvol.get_value_at(x, y, z) == 0.0:
                            # A masked voxel was still unassigned at this shell,
                            # so keep going to the next shell.
                            bailout = False
                            if tmp3.get_value_at(x, y, z) < cutoff:
                                freqvol.set_value_at(x, y, z, freq)
            if bailout:
                break

        freqvol.write_image(outvol)
        if options.fsc is not None:
            write_text_row(resolut, options.fsc)
Beispiel #8
0
def main():
    def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
        if mirror:
            m = 1
            alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0,
                                                       540.0 - psi, 0, 0, 1.0)
        else:
            m = 0
            alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0,
                                                       360.0 - psi, 0, 0, 1.0)
        return alpha, sx, sy, m

    progname = os.path.basename(sys.argv[0])
    usage = progname + " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=15. --aa=0.01  --sym=symmetry --CTF"
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--output_dir",
                      type="string",
                      default="./",
                      help="output directory")
    parser.add_option("--ave2D",
                      type="string",
                      default=False,
                      help="write to the disk a stack of 2D averages")
    parser.add_option("--var2D",
                      type="string",
                      default=False,
                      help="write to the disk a stack of 2D variances")
    parser.add_option("--ave3D",
                      type="string",
                      default=False,
                      help="write to the disk reconstructed 3D average")
    parser.add_option("--var3D",
                      type="string",
                      default=False,
                      help="compute 3D variability (time consuming!)")
    parser.add_option("--img_per_grp",
                      type="int",
                      default=10,
                      help="number of neighbouring projections")
    parser.add_option("--no_norm",
                      action="store_true",
                      default=False,
                      help="do not use normalization")
    #parser.add_option("--radius", 	    type="int"         ,	default=-1   ,				help="radius for 3D variability" )
    parser.add_option("--npad",
                      type="int",
                      default=2,
                      help="number of time to pad the original images")
    parser.add_option("--sym", type="string", default="c1", help="symmetry")
    parser.add_option(
        "--fl",
        type="float",
        default=0.0,
        help=
        "cutoff freqency in absolute frequency (0.0-0.5). (Default - no filtration)"
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.0,
        help=
        "fall off of the filter. Put 0.01 if user has no clue about falloff (Default - no filtration)"
    )
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="use CFT correction")
    parser.add_option("--VERBOSE",
                      action="store_true",
                      default=False,
                      help="Long output for debugging")
    #parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
    #parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
    #parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
    #parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
    #parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
    parser.add_option(
        "--VAR",
        action="store_true",
        default=False,
        help="stack on input consists of 2D variances (Default False)")
    parser.add_option(
        "--decimate",
        type="float",
        default=1.0,
        help=
        "image decimate rate, a number larger (expand image) or less (shrink image) than 1. default is 1"
    )
    parser.add_option(
        "--window",
        type="int",
        default=0,
        help=
        "reduce images to a small image size without changing pixel_size. Default value is zero."
    )
    #parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
    parser.add_option(
        "--nvec",
        type="int",
        default=0,
        help="number of eigenvectors, default = 0 meaning no PCA calculated")
    parser.add_option(
        "--symmetrize",
        action="store_true",
        default=False,
        help="Prepare input stack for handling symmetry (Default False)")

    (options, args) = parser.parse_args()
    #####
    from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, mpi_recv, MPI_COMM_WORLD
    from mpi import mpi_barrier, mpi_reduce, mpi_bcast, mpi_send, MPI_FLOAT, MPI_SUM, MPI_INT, MPI_MAX
    from applications import MPI_start_end
    from reconstruction import recons3d_em, recons3d_em_MPI
    from reconstruction import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
    from utilities import print_begin_msg, print_end_msg, print_msg
    from utilities import read_text_row, get_image, get_im
    from utilities import bcast_EMData_to_all, bcast_number_to_all
    from utilities import get_symt

    #  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

    from EMAN2db import db_open_dict

    # Set up global variables related to bdb cache
    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    # Set up global variables related to ERROR function
    global_def.BATCH = True

    # detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ
    if RUNNING_UNDER_MPI:
        global_def.MPI = True

    if options.symmetrize:
        if RUNNING_UNDER_MPI:
            try:
                sys.argv = mpi_init(len(sys.argv), sys.argv)
                try:
                    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
                    if (number_of_proc > 1):
                        ERROR(
                            "Cannot use more than one CPU for symmetry prepration",
                            "sx3dvariability", 1)
                except:
                    pass
            except:
                pass
        if options.output_dir != "./" and not os.path.exists(
                options.output_dir):
            os.mkdir(options.output_dir)
        #  Input
        #instack = "Clean_NORM_CTF_start_wparams.hdf"
        #instack = "bdb:data"

        from logger import Logger, BaseLogger_Files
        if os.path.exists(os.path.join(options.output_dir, "log.txt")):
            os.remove(os.path.join(options.output_dir, "log.txt"))
        log_main = Logger(BaseLogger_Files())
        log_main.prefix = os.path.join(options.output_dir, "./")

        instack = args[0]
        sym = options.sym.lower()
        if (sym == "c1"):
            ERROR("There is no need to symmetrize stack for C1 symmetry",
                  "sx3dvariability", 1)

        line = ""
        for a in sys.argv:
            line += " " + a
        log_main.add(line)

        if (instack[:4] != "bdb:"):
            if output_dir == "./": stack = "bdb:data"
            else: stack = "bdb:" + options.output_dir + "/data"
            delete_bdb(stack)
            junk = cmdexecute("sxcpy.py  " + instack + "  " + stack)
        else:
            stack = instack

        qt = EMUtil.get_all_attributes(stack, 'xform.projection')

        na = len(qt)
        ts = get_symt(sym)
        ks = len(ts)
        angsa = [None] * na

        for k in xrange(ks):
            #Qfile = "Q%1d"%k
            if options.output_dir != "./":
                Qfile = os.path.join(options.output_dir, "Q%1d" % k)
            else:
                Qfile = os.path.join(options.output_dir, "Q%1d" % k)
            #delete_bdb("bdb:Q%1d"%k)
            delete_bdb("bdb:" + Qfile)
            #junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            junk = cmdexecute("e2bdb.py  " + stack + "  --makevstack=bdb:" +
                              Qfile)
            #DB = db_open_dict("bdb:Q%1d"%k)
            DB = db_open_dict("bdb:" + Qfile)
            for i in xrange(na):
                ut = qt[i] * ts[k]
                DB.set_attr(i, "xform.projection", ut)
                #bt = ut.get_params("spider")
                #angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
            #write_text_row(angsa, 'ptsma%1d.txt'%k)
            #junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            #junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
            DB.close()
        if options.output_dir == "./": delete_bdb("bdb:sdata")
        else: delete_bdb("bdb:" + options.output_dir + "/" + "sdata")
        #junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
        sdata = "bdb:" + options.output_dir + "/" + "sdata"
        print(sdata)
        junk = cmdexecute("e2bdb.py   " + options.output_dir +
                          "  --makevstack=" + sdata + " --filt=Q")
        #junk = cmdexecute("ls  EMAN2DB/sdata*")
        #a = get_im("bdb:sdata")
        a = get_im(sdata)
        a.set_attr("variabilitysymmetry", sym)
        #a.write_image("bdb:sdata")
        a.write_image(sdata)

    else:

        sys.argv = mpi_init(len(sys.argv), sys.argv)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        main_node = 0

        if len(args) == 1:
            stack = args[0]
        else:
            print(("usage: " + usage))
            print(("Please run '" + progname + " -h' for detailed options"))
            return 1

        t0 = time()
        # obsolete flags
        options.MPI = True
        options.nvec = 0
        options.radiuspca = -1
        options.iter = 40
        options.abs = 0.0
        options.squ = 0.0

        if options.fl > 0.0 and options.aa == 0.0:
            ERROR("Fall off has to be given for the low-pass filter",
                  "sx3dvariability", 1, myid)
        if options.VAR and options.SND:
            ERROR("Only one of var and SND can be set!", "sx3dvariability",
                  myid)
            exit()
        if options.VAR and (options.ave2D or options.ave3D or options.var2D):
            ERROR(
                "When VAR is set, the program cannot output ave2D, ave3D or var2D",
                "sx3dvariability", 1, myid)
            exit()
        #if options.SND and (options.ave2D or options.ave3D):
        #	ERROR("When SND is set, the program cannot output ave2D or ave3D", "sx3dvariability", 1, myid)
        #	exit()
        if options.nvec > 0:
            ERROR("PCA option not implemented", "sx3dvariability", 1, myid)
            exit()
        if options.nvec > 0 and options.ave3D == None:
            ERROR("When doing PCA analysis, one must set ave3D",
                  "sx3dvariability",
                  myid=myid)
            exit()
        import string
        options.sym = options.sym.lower()

        # if global_def.CACHE_DISABLE:
        # 	from utilities import disable_bdb_cache
        # 	disable_bdb_cache()
        # global_def.BATCH = True

        if myid == main_node:
            if options.output_dir != "./" and not os.path.exists(
                    options.output_dir):
                os.mkdir(options.output_dir)

        img_per_grp = options.img_per_grp
        nvec = options.nvec
        radiuspca = options.radiuspca

        from logger import Logger, BaseLogger_Files
        #if os.path.exists(os.path.join(options.output_dir, "log.txt")): os.remove(os.path.join(options.output_dir, "log.txt"))
        log_main = Logger(BaseLogger_Files())
        log_main.prefix = os.path.join(options.output_dir, "./")

        if myid == main_node:
            line = ""
            for a in sys.argv:
                line += " " + a
            log_main.add(line)
            log_main.add("-------->>>Settings given by all options<<<-------")
            log_main.add("instack  		    :" + stack)
            log_main.add("output_dir        :" + options.output_dir)
            log_main.add("var3d   		    :" + options.var3D)

        if myid == main_node:
            line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
            #print_begin_msg("sx3dvariability")
            msg = "sx3dvariability"
            log_main.add(msg)
            print(line, msg)
            msg = ("%-70s:  %s\n" % ("Input stack", stack))
            log_main.add(msg)
            print(line, msg)

        symbaselen = 0
        if myid == main_node:
            nima = EMUtil.get_image_count(stack)
            img = get_image(stack)
            nx = img.get_xsize()
            ny = img.get_ysize()
            if options.sym != "c1":
                imgdata = get_im(stack)
                try:
                    i = imgdata.get_attr("variabilitysymmetry").lower()
                    if (i != options.sym):
                        ERROR(
                            "The symmetry provided does not agree with the symmetry of the input stack",
                            "sx3dvariability",
                            myid=myid)
                except:
                    ERROR(
                        "Input stack is not prepared for symmetry, please follow instructions",
                        "sx3dvariability",
                        myid=myid)
                from utilities import get_symt
                i = len(get_symt(options.sym))
                if ((nima / i) * i != nima):
                    ERROR(
                        "The length of the input stack is incorrect for symmetry processing",
                        "sx3dvariability",
                        myid=myid)
                symbaselen = nima / i
            else:
                symbaselen = nima
        else:
            nima = 0
            nx = 0
            ny = 0
        nima = bcast_number_to_all(nima)
        nx = bcast_number_to_all(nx)
        ny = bcast_number_to_all(ny)
        Tracker = {}
        Tracker["total_stack"] = nima
        if options.decimate == 1.:
            if options.window != 0:
                nx = options.window
                ny = options.window
        else:
            if options.window == 0:
                nx = int(nx * options.decimate)
                ny = int(ny * options.decimate)
            else:
                nx = int(options.window * options.decimate)
                ny = nx
        Tracker["nx"] = nx
        Tracker["ny"] = ny
        Tracker["nz"] = nx
        symbaselen = bcast_number_to_all(symbaselen)
        if radiuspca == -1: radiuspca = nx / 2 - 2

        if myid == main_node:
            line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
            msg = "%-70s:  %d\n" % ("Number of projection", nima)
            log_main.add(msg)
            print(line, msg)
        img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
        """
		if options.SND:
			from projection		import prep_vol, prgs
			from statistics		import im_diff
			from utilities		import get_im, model_circle, get_params_proj, set_params_proj
			from utilities		import get_ctf, generate_ctf
			from filter			import filt_ctf
		
			imgdata = EMData.read_images(stack, range(img_begin, img_end))

			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)

			bcast_EMData_to_all(vol, myid)
			volft, kb = prep_vol(vol)

			mask = model_circle(nx/2-2, nx, ny)
			varList = []
			for i in xrange(img_begin, img_end):
				phi, theta, psi, s2x, s2y = get_params_proj(imgdata[i-img_begin])
				ref_prj = prgs(volft, kb, [phi, theta, psi, -s2x, -s2y])
				if options.CTF:
					ctf_params = get_ctf(imgdata[i-img_begin])
					ref_prj = filt_ctf(ref_prj, generate_ctf(ctf_params))
				diff, A, B = im_diff(ref_prj, imgdata[i-img_begin], mask)
				diff2 = diff*diff
				set_params_proj(diff2, [phi, theta, psi, s2x, s2y])
				varList.append(diff2)
			mpi_barrier(MPI_COMM_WORLD)
		"""
        if options.VAR:
            #varList   = EMData.read_images(stack, range(img_begin, img_end))
            varList = []
            this_image = EMData()
            for index_of_particle in xrange(img_begin, img_end):
                this_image.read_image(stack, index_of_particle)
                varList.append(
                    image_decimate_window_xform_ctf(this_image,
                                                    options.decimate,
                                                    options.window,
                                                    options.CTF))
        else:
            from utilities import bcast_number_to_all, bcast_list_to_all, send_EMData, recv_EMData
            from utilities import set_params_proj, get_params_proj, params_3D_2D, get_params2D, set_params2D, compose_transform2
            from utilities import model_blank, nearest_proj, model_circle
            from applications import pca
            from statistics import avgvar, avgvar_ctf, ccc
            from filter import filt_tanl
            from morphology import threshold, square_root
            from projection import project, prep_vol, prgs
            from sets import Set

            if myid == main_node:
                t1 = time()
                proj_angles = []
                aveList = []
                tab = EMUtil.get_all_attributes(stack, 'xform.projection')
                for i in xrange(nima):
                    t = tab[i].get_params('spider')
                    phi = t['phi']
                    theta = t['theta']
                    psi = t['psi']
                    x = theta
                    if x > 90.0: x = 180.0 - x
                    x = x * 10000 + psi
                    proj_angles.append([x, t['phi'], t['theta'], t['psi'], i])
                t2 = time()
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = "%-70s:  %d\n" % ("Number of neighboring projections",
                                        img_per_grp)
                log_main.add(msg)
                print(line, msg)
                msg = "...... Finding neighboring projections\n"
                log_main.add(msg)
                print(line, msg)
                if options.VERBOSE:
                    msg = "Number of images per group: %d" % img_per_grp
                    log_main.add(msg)
                    print(line, msg)
                    msg = "Now grouping projections"
                    log_main.add(msg)
                    print(line, msg)
                proj_angles.sort()
            proj_angles_list = [0.0] * (nima * 4)
            if myid == main_node:
                for i in xrange(nima):
                    proj_angles_list[i * 4] = proj_angles[i][1]
                    proj_angles_list[i * 4 + 1] = proj_angles[i][2]
                    proj_angles_list[i * 4 + 2] = proj_angles[i][3]
                    proj_angles_list[i * 4 + 3] = proj_angles[i][4]
            proj_angles_list = bcast_list_to_all(proj_angles_list, myid,
                                                 main_node)
            proj_angles = []
            for i in xrange(nima):
                proj_angles.append([
                    proj_angles_list[i * 4], proj_angles_list[i * 4 + 1],
                    proj_angles_list[i * 4 + 2],
                    int(proj_angles_list[i * 4 + 3])
                ])
            del proj_angles_list
            proj_list, mirror_list = nearest_proj(proj_angles, img_per_grp,
                                                  range(img_begin, img_end))

            all_proj = Set()
            for im in proj_list:
                for jm in im:
                    all_proj.add(proj_angles[jm][3])

            all_proj = list(all_proj)
            if options.VERBOSE:
                print("On node %2d, number of images needed to be read = %5d" %
                      (myid, len(all_proj)))

            index = {}
            for i in xrange(len(all_proj)):
                index[all_proj[i]] = i
            mpi_barrier(MPI_COMM_WORLD)

            if myid == main_node:
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = ("%-70s:  %.2f\n" %
                       ("Finding neighboring projections lasted [s]",
                        time() - t2))
                log_main.add(msg)
                print(msg)
                msg = ("%-70s:  %d\n" %
                       ("Number of groups processed on the main node",
                        len(proj_list)))
                log_main.add(msg)
                print(line, msg)
                if options.VERBOSE:
                    print("Grouping projections took: ", (time() - t2) / 60,
                          "[min]")
                    print("Number of groups on main node: ", len(proj_list))
            mpi_barrier(MPI_COMM_WORLD)

            if myid == main_node:
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = ("...... calculating the stack of 2D variances \n")
                log_main.add(msg)
                print(line, msg)
                if options.VERBOSE:
                    print("Now calculating the stack of 2D variances")

            proj_params = [0.0] * (nima * 5)
            aveList = []
            varList = []
            if nvec > 0:
                eigList = [[] for i in xrange(nvec)]

            if options.VERBOSE:
                print("Begin to read images on processor %d" % (myid))
            ttt = time()
            #imgdata = EMData.read_images(stack, all_proj)
            imgdata = []
            for index_of_proj in xrange(len(all_proj)):
                #img     = EMData()
                #img.read_image(stack, all_proj[index_of_proj])
                dmg = image_decimate_window_xform_ctf(
                    get_im(stack, all_proj[index_of_proj]), options.decimate,
                    options.window, options.CTF)
                #print dmg.get_xsize(), "init"
                imgdata.append(dmg)
            if options.VERBOSE:
                print("Reading images on processor %d done, time = %.2f" %
                      (myid, time() - ttt))
                print("On processor %d, we got %d images" %
                      (myid, len(imgdata)))
            mpi_barrier(MPI_COMM_WORLD)
            '''	
			imgdata2 = EMData.read_images(stack, range(img_begin, img_end))
			if options.fl > 0.0:
				for k in xrange(len(imgdata2)):
					imgdata2[k] = filt_tanl(imgdata2[k], options.fl, options.aa)
			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata2, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata2, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			if myid == main_node:
				vol.write_image("vol_ctf.hdf")
				print_msg("Writing to the disk volume reconstructed from averages as		:  %s\n"%("vol_ctf.hdf"))
			del vol, imgdata2
			mpi_barrier(MPI_COMM_WORLD)
			'''
            from applications import prepare_2d_forPCA
            from utilities import model_blank
            for i in xrange(len(proj_list)):
                ki = proj_angles[proj_list[i][0]][3]
                if ki >= symbaselen: continue
                mi = index[ki]
                phiM, thetaM, psiM, s2xM, s2yM = get_params_proj(imgdata[mi])

                grp_imgdata = []
                for j in xrange(img_per_grp):
                    mj = index[proj_angles[proj_list[i][j]][3]]
                    phi, theta, psi, s2x, s2y = get_params_proj(imgdata[mj])
                    alpha, sx, sy, mirror = params_3D_2D_NEW(
                        phi, theta, psi, s2x, s2y, mirror_list[i][j])
                    if thetaM <= 90:
                        if mirror == 0:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, phiM - phi, 0.0, 0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, 180 - (phiM - phi), 0.0,
                                0.0, 1.0)
                    else:
                        if mirror == 0:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, -(phiM - phi), 0.0, 0.0,
                                1.0)
                        else:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, -(180 - (phiM - phi)), 0.0,
                                0.0, 1.0)
                    set_params2D(imgdata[mj], [alpha, sx, sy, mirror, 1.0])
                    grp_imgdata.append(imgdata[mj])
                    #print grp_imgdata[j].get_xsize(), imgdata[mj].get_xsize()

                if not options.no_norm:
                    #print grp_imgdata[j].get_xsize()
                    mask = model_circle(nx / 2 - 2, nx, nx)
                    for k in xrange(img_per_grp):
                        ave, std, minn, maxx = Util.infomask(
                            grp_imgdata[k], mask, False)
                        grp_imgdata[k] -= ave
                        grp_imgdata[k] /= std
                    del mask

                if options.fl > 0.0:
                    from filter import filt_ctf, filt_table
                    from fundamentals import fft, window2d
                    nx2 = 2 * nx
                    ny2 = 2 * ny
                    if options.CTF:
                        from utilities import pad
                        for k in xrange(img_per_grp):
                            grp_imgdata[k] = window2d(
                                fft(
                                    filt_tanl(
                                        filt_ctf(
                                            fft(
                                                pad(grp_imgdata[k], nx2, ny2,
                                                    1, 0.0)),
                                            grp_imgdata[k].get_attr("ctf"),
                                            binary=1), options.fl,
                                        options.aa)), nx, ny)
                            #grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
                            #grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
                    else:
                        for k in xrange(img_per_grp):
                            grp_imgdata[k] = filt_tanl(grp_imgdata[k],
                                                       options.fl, options.aa)
                            #grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
                            #grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
                else:
                    from utilities import pad, read_text_file
                    from filter import filt_ctf, filt_table
                    from fundamentals import fft, window2d
                    nx2 = 2 * nx
                    ny2 = 2 * ny
                    if options.CTF:
                        from utilities import pad
                        for k in xrange(img_per_grp):
                            grp_imgdata[k] = window2d(
                                fft(
                                    filt_ctf(fft(
                                        pad(grp_imgdata[k], nx2, ny2, 1, 0.0)),
                                             grp_imgdata[k].get_attr("ctf"),
                                             binary=1)), nx, ny)
                            #grp_imgdata[k] = window2d(fft( filt_table( filt_tanl( filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1,0.0)), grp_imgdata[k].get_attr("ctf"), binary=1), options.fl, options.aa), fifi) ),nx,ny)
                            #grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)
                '''
				if i < 10 and myid == main_node:
					for k in xrange(10):
						grp_imgdata[k].write_image("grp%03d.hdf"%i, k)
				'''
                """
				if myid == main_node and i==0:
					for pp in xrange(len(grp_imgdata)):
						grp_imgdata[pp].write_image("pp.hdf", pp)
				"""
                ave, grp_imgdata = prepare_2d_forPCA(grp_imgdata)
                """
				if myid == main_node and i==0:
					for pp in xrange(len(grp_imgdata)):
						grp_imgdata[pp].write_image("qq.hdf", pp)
				"""

                var = model_blank(nx, ny)
                for q in grp_imgdata:
                    Util.add_img2(var, q)
                Util.mul_scalar(var, 1.0 / (len(grp_imgdata) - 1))
                # Switch to std dev
                var = square_root(threshold(var))
                #if options.CTF:	ave, var = avgvar_ctf(grp_imgdata, mode="a")
                #else:	            ave, var = avgvar(grp_imgdata, mode="a")
                """
				if myid == main_node:
					ave.write_image("avgv.hdf",i)
					var.write_image("varv.hdf",i)
				"""

                set_params_proj(ave, [phiM, thetaM, 0.0, 0.0, 0.0])
                set_params_proj(var, [phiM, thetaM, 0.0, 0.0, 0.0])

                aveList.append(ave)
                varList.append(var)

                if options.VERBOSE:
                    print("%5.2f%% done on processor %d" %
                          (i * 100.0 / len(proj_list), myid))
                if nvec > 0:
                    eig = pca(input_stacks=grp_imgdata,
                              subavg="",
                              mask_radius=radiuspca,
                              nvec=nvec,
                              incore=True,
                              shuffle=False,
                              genbuf=True)
                    for k in xrange(nvec):
                        set_params_proj(eig[k], [phiM, thetaM, 0.0, 0.0, 0.0])
                        eigList[k].append(eig[k])
                    """
					if myid == 0 and i == 0:
						for k in xrange(nvec):
							eig[k].write_image("eig.hdf", k)
					"""

            del imgdata
            #  To this point, all averages, variances, and eigenvectors are computed

            if options.ave2D:
                from fundamentals import fpol
                if myid == main_node:
                    km = 0
                    for i in xrange(number_of_proc):
                        if i == main_node:
                            for im in xrange(len(aveList)):
                                aveList[im].write_image(
                                    os.path.join(options.output_dir,
                                                 options.ave2D), km)
                                km += 1
                        else:
                            nl = mpi_recv(1, MPI_INT, i,
                                          SPARX_MPI_TAG_UNIVERSAL,
                                          MPI_COMM_WORLD)
                            nl = int(nl[0])
                            for im in xrange(nl):
                                ave = recv_EMData(i, im + i + 70000)
                                """
								nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								nm = int(nm[0])
								members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('members', map(int, members))
								members = mpi_recv(nm, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('pix_err', map(float, members))
								members = mpi_recv(3, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('refprojdir', map(float, members))
								"""
                                tmpvol = fpol(ave, Tracker["nx"],
                                              Tracker["nx"], 1)
                                tmpvol.write_image(
                                    os.path.join(options.output_dir,
                                                 options.ave2D), km)
                                km += 1
                else:
                    mpi_send(len(aveList), 1, MPI_INT, main_node,
                             SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    for im in xrange(len(aveList)):
                        send_EMData(aveList[im], main_node, im + myid + 70000)
                        """
						members = aveList[im].get_attr('members')
						mpi_send(len(members), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						mpi_send(members, len(members), MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						members = aveList[im].get_attr('pix_err')
						mpi_send(members, len(members), MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						try:
							members = aveList[im].get_attr('refprojdir')
							mpi_send(members, 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						except:
							mpi_send([-999.0,-999.0,-999.0], 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						"""

            if options.ave3D:
                from fundamentals import fpol
                if options.VERBOSE:
                    print("Reconstructing 3D average volume")
                ave3D = recons3d_4nn_MPI(myid,
                                         aveList,
                                         symmetry=options.sym,
                                         npad=options.npad)
                bcast_EMData_to_all(ave3D, myid)
                if myid == main_node:
                    line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                    ave3D = fpol(ave3D, Tracker["nx"], Tracker["nx"],
                                 Tracker["nx"])
                    ave3D.write_image(
                        os.path.join(options.output_dir, options.ave3D))
                    msg = ("%-70s:  %s\n" % (
                        "Writing to the disk volume reconstructed from averages as",
                        options.ave3D))
                    log_main.add(msg)
                    print(line, msg)
            del ave, var, proj_list, stack, phi, theta, psi, s2x, s2y, alpha, sx, sy, mirror, aveList

            if nvec > 0:
                for k in xrange(nvec):
                    if options.VERBOSE:
                        print("Reconstruction eigenvolumes", k)
                    cont = True
                    ITER = 0
                    mask2d = model_circle(radiuspca, nx, nx)
                    while cont:
                        #print "On node %d, iteration %d"%(myid, ITER)
                        eig3D = recons3d_4nn_MPI(myid,
                                                 eigList[k],
                                                 symmetry=options.sym,
                                                 npad=options.npad)
                        bcast_EMData_to_all(eig3D, myid, main_node)
                        if options.fl > 0.0:
                            eig3D = filt_tanl(eig3D, options.fl, options.aa)
                        if myid == main_node:
                            eig3D.write_image(
                                os.path.join(options.outpout_dir,
                                             "eig3d_%03d.hdf" % (k, ITER)))
                        Util.mul_img(eig3D,
                                     model_circle(radiuspca, nx, nx, nx))
                        eig3Df, kb = prep_vol(eig3D)
                        del eig3D
                        cont = False
                        icont = 0
                        for l in xrange(len(eigList[k])):
                            phi, theta, psi, s2x, s2y = get_params_proj(
                                eigList[k][l])
                            proj = prgs(eig3Df, kb,
                                        [phi, theta, psi, s2x, s2y])
                            cl = ccc(proj, eigList[k][l], mask2d)
                            if cl < 0.0:
                                icont += 1
                                cont = True
                                eigList[k][l] *= -1.0
                        u = int(cont)
                        u = mpi_reduce([u], 1, MPI_INT, MPI_MAX, main_node,
                                       MPI_COMM_WORLD)
                        icont = mpi_reduce([icont], 1, MPI_INT, MPI_SUM,
                                           main_node, MPI_COMM_WORLD)

                        if myid == main_node:
                            line = strftime("%Y-%m-%d_%H:%M:%S",
                                            localtime()) + " =>"
                            u = int(u[0])
                            msg = (" Eigenvector: ", k, " number changed ",
                                   int(icont[0]))
                            log_main.add(msg)
                            print(line, msg)
                        else:
                            u = 0
                        u = bcast_number_to_all(u, main_node)
                        cont = bool(u)
                        ITER += 1

                    del eig3Df, kb
                    mpi_barrier(MPI_COMM_WORLD)
                del eigList, mask2d

            if options.ave3D: del ave3D
            if options.var2D:
                from fundamentals import fpol
                if myid == main_node:
                    km = 0
                    for i in xrange(number_of_proc):
                        if i == main_node:
                            for im in xrange(len(varList)):
                                tmpvol = fpol(varList[im], Tracker["nx"],
                                              Tracker["nx"], 1)
                                tmpvol.write_image(
                                    os.path.join(options.output_dir,
                                                 options.var2D), km)
                                km += 1
                        else:
                            nl = mpi_recv(1, MPI_INT, i,
                                          SPARX_MPI_TAG_UNIVERSAL,
                                          MPI_COMM_WORLD)
                            nl = int(nl[0])
                            for im in xrange(nl):
                                ave = recv_EMData(i, im + i + 70000)
                                tmpvol = fpol(ave, Tracker["nx"],
                                              Tracker["nx"], 1)
                                tmpvol.write_image(
                                    os.path.join(options.output_dir,
                                                 options.var2D, km))
                                km += 1
                else:
                    mpi_send(len(varList), 1, MPI_INT, main_node,
                             SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    for im in xrange(len(varList)):
                        send_EMData(varList[im], main_node, im + myid +
                                    70000)  #  What with the attributes??

            mpi_barrier(MPI_COMM_WORLD)

        if options.var3D:
            if myid == main_node and options.VERBOSE:
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = ("Reconstructing 3D variability volume")
                log_main.add(msg)
                print(line, msg)
            t6 = time()
            # radiusvar = options.radius
            # if( radiusvar < 0 ):  radiusvar = nx//2 -3
            res = recons3d_4nn_MPI(myid,
                                   varList,
                                   symmetry=options.sym,
                                   npad=options.npad)
            #res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
            if myid == main_node:
                from fundamentals import fpol
                res = fpol(res, Tracker["nx"], Tracker["nx"], Tracker["nx"])
                res.write_image(os.path.join(options.output_dir,
                                             options.var3D))

            if myid == main_node:
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = ("%-70s:  %.2f\n" %
                       ("Reconstructing 3D variability took [s]", time() - t6))
                log_main.add(msg)
                print(line, msg)
                if options.VERBOSE:
                    print("Reconstruction took: %.2f [min]" %
                          ((time() - t6) / 60))

            if myid == main_node:
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = ("%-70s:  %.2f\n" %
                       ("Total time for these computations [s]", time() - t0))
                print(line, msg)
                log_main.add(msg)
                if options.VERBOSE:
                    print("Total time for these computations: %.2f [min]" %
                          ((time() - t0) / 60))
                line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
                msg = ("sx3dvariability")
                print(line, msg)
                log_main.add(msg)

        from mpi import mpi_finalize
        mpi_finalize()

        if RUNNING_UNDER_MPI:
            global_def.MPI = False

        global_def.BATCH = False
Beispiel #9
0
def _filterlocal_mpi(args, options):
    """Run the MPI-parallel variant of the local filter.

    The main node reads the input volume and the local-resolution volume and
    broadcasts the volume dimensions.  Every node then builds the mask (a
    sphere, or the binarized mask file broadcast from the main node) and
    calls ``filter.filterlocal``.  Only the main node writes the output.

    Parameters
    ----------
    args : list of str
        Positional arguments: input volume, local-resolution volume,
        optional mask file, output file (3 or 4 entries, already validated).
    options : optparse.Values
        Parsed options; ``radius`` and ``falloff`` are used.
    """
    import sys
    from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, mpi_finalize, MPI_COMM_WORLD
    from filter import filterlocal

    sys.argv = mpi_init(len(sys.argv), sys.argv)

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    if myid == main_node:
        # Take the file names from the parsed positional arguments (not
        # sys.argv), so options placed before the file names are not
        # mistaken for input volumes.
        vi = get_im(args[0])  # volume to be filtered
        ui = get_im(args[1])  # local-resolution volume
        radius = options.radius
        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()
        dis = [nx, ny, nz]
    else:
        # Non-main nodes hold no volumes; they only receive the dimensions.
        radius = 0
        dis = [0, 0, 0]
        vi = None
        ui = None
    dis = bcast_list_to_all(dis, myid, source_node=main_node)

    if myid != main_node:
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])
    radius = bcast_number_to_all(radius, main_node)

    if len(args) == 3:
        # No mask file: use a sphere; -1 selects the largest one that fits.
        if radius == -1:
            radius = min(nx, ny, nz) // 2 - 1
        m = model_circle(radius, nx, ny, nz)
        outvol = args[2]
    else:
        # Mask file given: main node reads and binarizes it, others receive it.
        if myid == main_node:
            m = binarize(get_im(args[2]), 0.5)
        else:
            m = model_blank(nx, ny, nz)
        outvol = args[3]
        bcast_EMData_to_all(m, myid, main_node)

    filteredvol = filterlocal(ui, vi, m, options.falloff, myid, main_node,
                              number_of_proc)

    if myid == main_node:
        filteredvol.write_image(outvol)

    mpi_finalize()


def _filterlocal_serial(args, options):
    """Run the single-process variant of the local filter.

    The local-resolution values are rounded to two decimals.  Then, for each
    resolution step of 0.01 between the in-mask minimum and maximum, the
    input volume is filtered with ``filt_tanl`` at that cutoff.  Voxels inside
    the mask whose rounded resolution equals the cutoff are copied into the
    output.

    Parameters
    ----------
    args : list of str
        Positional arguments: input volume, local-resolution volume,
        optional mask file, output file (3 or 4 entries, already validated).
    options : optparse.Values
        Parsed options; ``radius`` and ``falloff`` are used.
    """
    vi = get_im(args[0])  # volume to be filtered
    ui = get_im(args[1])  # resolution volume, values are assumed to be from 0 to 0.5

    # Read the dimensions before fftip() turns vi into a Fourier image.
    # The volume is no longer assumed to be cubic; this matches the MPI path.
    nx = vi.get_xsize()
    ny = vi.get_ysize()
    nz = vi.get_zsize()

    falloff = options.falloff

    if len(args) == 3:
        radius = options.radius
        if radius == -1:
            radius = min(nx, ny, nz) // 2 - 1
        m = model_circle(radius, nx, ny, nz)
        outvol = args[2]
    else:
        m = binarize(get_im(args[2]), 0.5)
        outvol = args[3]

    fftip(vi)  # this is the volume to be filtered

    #  Round all resolution numbers to two digits
    for x in range(nx):
        for y in range(ny):
            for z in range(nz):
                ui.set_value_at_fast(x, y, z,
                                     round(ui.get_value_at(x, y, z), 2))
    # st[2] / st[3] are used as the lower / upper bound of the in-mask resolution values.
    st = Util.infomask(ui, m, True)

    filteredvol = model_blank(nx, ny, nz)
    cutoff = max(st[2] - 0.01, 0.0)
    while cutoff < st[3]:
        cutoff = round(cutoff + 0.01, 2)
        # Skip cutoffs that no voxel within the mask is assigned to.
        pt = Util.infomask(
            threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005), m, True)
        if pt[0] != 0.0:
            vovo = fft(filt_tanl(vi, cutoff, falloff))
            for x in range(nx):
                for y in range(ny):
                    for z in range(nz):
                        if m.get_value_at(x, y, z) > 0.5:
                            if round(ui.get_value_at(x, y, z), 2) == cutoff:
                                filteredvol.set_value_at_fast(
                                    x, y, z, vovo.get_value_at(x, y, z))

    filteredvol.write_image(outvol)


def main():
    """Command-line entry point: locally filter a volume by a local-resolution map.

    Positional arguments are the input volume, the local-resolution volume
    (e.g. from sxlocres.py), an optional mask file and the output file.
    Without a mask file, a sphere of ``--radius`` is used.  ``--MPI``
    selects the parallel implementation.  When the positional argument
    count is wrong, it prints the usage and exits.
    """
    import os
    import sys
    from optparse import OptionParser

    arglist = list(sys.argv)
    progname = os.path.basename(arglist[0])
    usage = progname + """ inputvolume  locresvolume maskfile outputfile   --radius --falloff  --MPI

	    Locally filter a volume based on local resolution volume (sxlocres.py) within area outlined by the maskfile
	"""
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "if there is no maskfile, sphere with r=radius will be used, by default the radius is nx/2-1"
    )
    parser.add_option("--falloff",
                      type="float",
                      default=0.1,
                      help="falloff of tanl filter (default 0.1)")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args(arglist[1:])

    # Either 3 (no mask file) or 4 (with mask file) positional arguments.
    if len(args) < 3 or len(args) > 4:
        print("See usage " + usage)
        sys.exit()

    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    if options.MPI:
        _filterlocal_mpi(args, options)
    else:
        _filterlocal_serial(args, options)
Beispiel #10
0
def calculate_volumes_after_rotation_and_save_them(ali3d_options, rviper_iter, masterdir, bdb_stack_location, mpi_rank, mpi_size,
												   no_of_viper_runs_analyzed_together, no_of_viper_runs_analyzed_together_from_user_options, mpi_comm = -1):
	"""Reconstruct one rotated volume per selected VIPER run, then align and average them.

	The run indices are read from
	<mainoutputdir>list_of_viper_runs_included_in_outlier_elimination.json.
	For each run, the particle ids listed in this_iteration_index_keep_images.txt
	are loaded from the stack "<bdb_stack_location>_%03d" % (rviper_iter - 1).
	They are reconstructed with do_volume using that run's
	rotated_reduced_params.txt, and rank 0 writes the result to
	<run dir>/rotated_volume.hdf.

	Rank 0 then aligns every volume to the running average with ali_vol,
	repeating until the masked mean of average*average stops increasing.
	Finally it writes average_volume.hdf and variance_volume.hdf into
	mainoutputdir.

	When there are more MPI processes than kept images, the communicator is
	split, and only ranks 0 .. n_projs-1 take part in the reconstructions.

	Parameters
	----------
	ali3d_options : options object passed unchanged to do_volume.
	rviper_iter : current RVIPER iteration. It selects the output directory
		and the input-stack suffix (rviper_iter - 1).
	masterdir : root output directory.
	bdb_stack_location : base name of the particle stack.
	mpi_rank, mpi_size : rank and size of this process in mpi_comm.
	no_of_viper_runs_analyzed_together,
	no_of_viper_runs_analyzed_together_from_user_options :
		not used by the body (they only appeared in commented-out code);
		kept for interface compatibility.
	mpi_comm : communicator to use; -1 (the default) means MPI_COMM_WORLD.

	Returns None. If every selected run already has a rotated_volume.hdf
	on disk, the function returns immediately without doing any work. If
	the run-index list is empty, it finalizes MPI and exits the process
	with a non-zero status.
	"""
	if mpi_comm == -1:
		mpi_comm = MPI_COMM_WORLD

	# mainoutputdir ends with DIR_DELIM.
	mainoutputdir = masterdir + DIR_DELIM + NAME_OF_MAIN_DIR + ("%03d" + DIR_DELIM) %(rviper_iter)

	import json
	# Use a context manager so the file is closed even if json.load raises.
	with open(mainoutputdir + "list_of_viper_runs_included_in_outlier_elimination.json", 'r') as f:
		run_indices = json.load(f)

	if len(run_indices) == 0:
		print("Error: len(list_of_independent_viper_run_indices_used_for_outlier_elimination)==0")
		mpi_finalize()
		# A non-zero status lets the shell / mpirun see that this is a failure.
		sys.exit(1)

	def rotated_volume_path(run):
		"""Return the path of rotated_volume.hdf for VIPER run number `run`."""
		return mainoutputdir + DIR_DELIM + NAME_OF_RUN_DIR + "%03d"%(run) + DIR_DELIM + "rotated_volume.hdf"

	# If every rotated volume already exists, a previous execution finished this step.
	if all(os.path.exists(rotated_volume_path(run)) for run in run_indices):
		return

	# Per-run projection parameter files, in the same order as run_indices.
	partstack = [mainoutputdir + NAME_OF_RUN_DIR + "%03d"%(run) + DIR_DELIM + "rotated_reduced_params.txt"
				 for run in run_indices]
	partids_file_name = mainoutputdir + "this_iteration_index_keep_images.txt"

	lpartids = [int(x) for x in read_text_file(partids_file_name)]
	n_projs = len(lpartids)
	stack_name = bdb_stack_location + "_%03d"%(rviper_iter - 1)

	def reconstruct_and_save(nproc_in_comm, comm, message_format):
		"""Reconstruct every selected run on `comm`; rank 0 writes each volume and logs it."""
		for idx, run in enumerate(run_indices):
			projdata = getindexdata(stack_name, partids_file_name, partstack[idx], mpi_rank, nproc_in_comm)
			vol = do_volume(projdata, ali3d_options, 0, mpi_comm = comm)
			del projdata
			if mpi_rank == 0:
				vol.write_image(rotated_volume_path(run))
				line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " => "
				print(line + message_format%run)
			del vol

	if mpi_size > n_projs:
		# More processors than images: only ranks < n_projs (color 0) do the work.
		working = int(not(mpi_rank < n_projs))
		mpi_subcomm = mpi_comm_split(mpi_comm, working, mpi_rank - working*n_projs)
		mpi_subsize = mpi_comm_size(mpi_subcomm)
		# NOTE(review): mpi_subcomm is never freed here; consider mpi_comm_free if available.
		if mpi_rank < n_projs:
			# Color-0 ranks are ordered by mpi_rank, so their sub-rank equals mpi_rank.
			reconstruct_and_save(mpi_subsize, mpi_subcomm, "Generated rec_ref_volume_run #%01d \n")
		mpi_barrier(mpi_comm)
	else:
		reconstruct_and_save(mpi_size, mpi_comm, "Generated rec_ref_volume_run #%01d")

	if mpi_rank == 0:
		# Align all rotated volumes, calculate their average and save as an overall result.
		from utilities import get_params3D, set_params3D, get_im, model_circle
		from statistics import ave_var
		from applications import ali_vol

		vls = [None]*len(run_indices)
		for idx, run in enumerate(run_indices):
			vls[idx] = get_im(rotated_volume_path(run))
			set_params3D(vls[idx],[0.,0.,0.,0.,0.,0.,0,1.0])
		asa, sas = ave_var(vls)

		nx = asa.get_xsize()
		radius = nx//2 - .5
		# The mask depends only on nx and radius, so build it once.
		mask = model_circle(radius, nx, nx, nx)
		goal = Util.infomask(asa*asa, mask, True)[0]
		# NOTE(review): the loop stops after the first non-improving pass, but the
		# average written below is that last (non-improved) one, not the best one.
		while True:
			set_params3D(asa,[0.,0.,0.,0.,0.,0.,0,1.0])
			for vol in vls:
				o = ali_vol(vol, asa, 7.0, 5., radius)  # range of angles and shifts, maybe should be adjusted
				p = get_params3D(o)
				del o
				set_params3D(vol, p)
			asa, sas = ave_var(vls)
			st = Util.infomask(asa*asa, mask, True)
			if st[0] > goal:
				goal = st[0]
			else:
				break
		asa.write_image(mainoutputdir + DIR_DELIM + "average_volume.hdf")
		sas.write_image(mainoutputdir + DIR_DELIM + "variance_volume.hdf")
	return
Beispiel #11
0
def main():
	from logger import Logger, BaseLogger_Files
        arglist = []
        i = 0
        while( i < len(sys.argv) ):
            if sys.argv[i]=='-p4pg':
                i = i+2
            elif sys.argv[i]=='-p4wd':
                i = i+2
            else:
                arglist.append( sys.argv[i] )
                i = i+1
	progname = os.path.basename(arglist[0])
	usage = progname + " stack  outdir  <mask> --focus=3Dmask --radius=outer_radius --delta=angular_step" +\
	"--an=angular_neighborhood --maxit=max_iter  --CTF --sym=c1 --function=user_function --independent=indenpendent_runs  --number_of_images_per_group=number_of_images_per_group  --low_pass_frequency=.25  --seed=random_seed"
	parser = OptionParser(usage,version=SPARXVERSION)
	parser.add_option("--focus",                         type   ="string",        default ='',                    help="bineary 3D mask for focused clustering ")
	parser.add_option("--ir",                            type   = "int",          default =1, 	                  help="inner radius for rotational correlation > 0 (set to 1)")
	parser.add_option("--radius",                        type   = "int",          default =-1,	                  help="particle radius in pixel for rotational correlation <nx-1 (set to the radius of the particle)")
	parser.add_option("--maxit",	                     type   = "int",          default =25, 	                  help="maximum number of iteration")
	parser.add_option("--rs",                            type   = "int",          default =1,	                  help="step between rings in rotational correlation >0 (set to 1)" ) 
	parser.add_option("--xr",                            type   ="string",        default ='1',                   help="range for translation search in x direction, search is +/-xr ")
	parser.add_option("--yr",                            type   ="string",        default ='-1',	              help="range for translation search in y direction, search is +/-yr (default = same as xr)")
	parser.add_option("--ts",                            type   ="string",        default ='0.25',                help="step size of the translation search in both directions direction, search is -xr, -xr+ts, 0, xr-ts, xr ")
	parser.add_option("--delta",                         type   ="string",        default ='2',                   help="angular step of reference projections")
	parser.add_option("--an",                            type   ="string",        default ='-1',	              help="angular neighborhood for local searches")
	parser.add_option("--center",                        type   ="int",           default =0,	                  help="0 - if you do not want the volume to be centered, 1 - center the volume using cog (default=0)")
	parser.add_option("--nassign",                       type   ="int",           default =1, 	                  help="number of reassignment iterations performed for each angular step (set to 3) ")
	parser.add_option("--nrefine",                       type   ="int",           default =0, 	                  help="number of alignment iterations performed for each angular step (set to 0)")
	parser.add_option("--CTF",                           action ="store_true",    default =False,                 help="do CTF correction during clustring")
	parser.add_option("--stoprnct",                      type   ="float",         default =3.0,                   help="Minimum percentage of assignment change to stop the program")
	parser.add_option("--sym",                           type   ="string",        default ='c1',                  help="symmetry of the structure ")
	parser.add_option("--function",                      type   ="string",        default ='do_volume_mrk05',     help="name of the reference preparation function")
	parser.add_option("--independent",                   type   ="int",           default = 3,                    help="number of independent run")
	parser.add_option("--number_of_images_per_group",    type   ="int",           default =1000,                  help="number of groups")
	parser.add_option("--low_pass_filter",               type   ="float",         default =-1.0,                  help="absolute frequency of low-pass filter for 3d sorting on the original image size" )
	parser.add_option("--nxinit",                        type   ="int",           default =64,                    help="initial image size for sorting" )
	parser.add_option("--unaccounted",                   action ="store_true",    default =False,                 help="reconstruct the unaccounted images")
	parser.add_option("--seed",                          type   ="int",           default =-1,                    help="random seed for create initial random assignment for EQ Kmeans")
	parser.add_option("--smallest_group",                type   ="int",           default =500,                   help="minimum members for identified group")
	parser.add_option("--sausage",                       action ="store_true",    default =False,                 help="way of filter volume")
	parser.add_option("--chunkdir",                      type   ="string",        default ='',                    help="chunkdir for computing margin of error")
	parser.add_option("--PWadjustment",                  type   ="string",        default ='',                    help="1-D power spectrum of PDB file used for EM volume power spectrum correction")
	parser.add_option("--protein_shape",                 type   ="string",        default ='g',                   help="protein shape. It defines protein preferred orientation angles. Currently it has g and f two types ")
	parser.add_option("--upscale",                       type   ="float",         default =0.5,                   help=" scaling parameter to adjust the power spectrum of EM volumes")
	parser.add_option("--wn",                            type   ="int",           default =0,                     help="optimal window size for data processing")
	parser.add_option("--interpolation",                 type   ="string",        default ="4nn",                 help="3-d reconstruction interpolation method, two options trl and 4nn")
	(options, args) = parser.parse_args(arglist[1:])
	if len(args) < 1  or len(args) > 4:
    		print "usage: " + usage
    		print "Please run '" + progname + " -h' for detailed options"
	else:

		if len(args)>2:
			mask_file = args[2]
		else:
			mask_file = None

		orgstack                        =args[0]
		masterdir                       =args[1]
		global_def.BATCH = True
		#---initialize MPI related variables
		from mpi import mpi_init, mpi_comm_size, MPI_COMM_WORLD, mpi_comm_rank,mpi_barrier,mpi_bcast, mpi_bcast, MPI_INT,MPI_CHAR
		sys.argv = mpi_init(len(sys.argv),sys.argv)
		nproc    = mpi_comm_size(MPI_COMM_WORLD)
		myid     = mpi_comm_rank(MPI_COMM_WORLD)
		mpi_comm = MPI_COMM_WORLD
		main_node= 0
		# import some utilities
		from utilities import get_im,bcast_number_to_all,cmdexecute,write_text_file,read_text_file,wrap_mpi_bcast, get_params_proj, write_text_row
		from applications import recons3d_n_MPI, mref_ali3d_MPI, Kmref_ali3d_MPI
		from statistics import k_means_match_clusters_asg_new,k_means_stab_bbenum
		from applications import mref_ali3d_EQ_Kmeans, ali3d_mref_Kmeans_MPI  
		# Create the main log file
		from logger import Logger,BaseLogger_Files
		if myid ==main_node:
			log_main=Logger(BaseLogger_Files())
			log_main.prefix = masterdir+"/"
		else:
			log_main =None
		#--- fill input parameters into dictionary named after Constants
		Constants		                         ={}
		Constants["stack"]                       = args[0]
		Constants["masterdir"]                   = masterdir
		Constants["mask3D"]                      = mask_file
		Constants["focus3Dmask"]                 = options.focus
		Constants["indep_runs"]                  = options.independent
		Constants["stoprnct"]                    = options.stoprnct
		Constants["number_of_images_per_group"]  = options.number_of_images_per_group
		Constants["CTF"]                         = options.CTF
		Constants["maxit"]                       = options.maxit
		Constants["ir"]                          = options.ir 
		Constants["radius"]                      = options.radius 
		Constants["nassign"]                     = options.nassign
		Constants["rs"]                          = options.rs 
		Constants["xr"]                          = options.xr
		Constants["yr"]                          = options.yr
		Constants["ts"]                          = options.ts
		Constants["delta"]               		 = options.delta
		Constants["an"]                  		 = options.an
		Constants["sym"]                 		 = options.sym
		Constants["center"]              		 = options.center
		Constants["nrefine"]             		 = options.nrefine
		#Constants["fourvar"]            		 = options.fourvar 
		Constants["user_func"]           		 = options.function
		Constants["low_pass_filter"]     		 = options.low_pass_filter # enforced low_pass_filter
		#Constants["debug"]              		 = options.debug
		Constants["main_log_prefix"]     		 = args[1]
		#Constants["importali3d"]        		 = options.importali3d
		Constants["myid"]	             		 = myid
		Constants["main_node"]           		 = main_node
		Constants["nproc"]               		 = nproc
		Constants["log_main"]            		 = log_main
		Constants["nxinit"]              		 = options.nxinit
		Constants["unaccounted"]         		 = options.unaccounted
		Constants["seed"]                		 = options.seed
		Constants["smallest_group"]      		 = options.smallest_group
		Constants["sausage"]             		 = options.sausage
		Constants["chunkdir"]            		 = options.chunkdir
		Constants["PWadjustment"]        		 = options.PWadjustment
		Constants["upscale"]             		 = options.upscale
		Constants["wn"]                  		 = options.wn
		Constants["3d-interpolation"]    		 = options.interpolation
		Constants["protein_shape"]    		     = options.protein_shape 
		# -----------------------------------------------------
		#
		# Create and initialize Tracker dictionary with input options
		Tracker = 			    		{}
		Tracker["constants"]       = Constants
		Tracker["maxit"]           = Tracker["constants"]["maxit"]
		Tracker["radius"]          = Tracker["constants"]["radius"]
		#Tracker["xr"]             = ""
		#Tracker["yr"]             = "-1"  # Do not change!
		#Tracker["ts"]             = 1
		#Tracker["an"]             = "-1"
		#Tracker["delta"]          = "2.0"
		#Tracker["zoom"]           = True
		#Tracker["nsoft"]          = 0
		#Tracker["local"]          = False
		#Tracker["PWadjustment"]   = Tracker["constants"]["PWadjustment"]
		Tracker["upscale"]         = Tracker["constants"]["upscale"]
		#Tracker["upscale"]        = 0.5
		Tracker["applyctf"]        = False  #  Should the data be premultiplied by the CTF.  Set to False for local continuous.
		#Tracker["refvol"]         = None
		Tracker["nxinit"]          = Tracker["constants"]["nxinit"]
		#Tracker["nxstep"]         = 32
		Tracker["icurrentres"]     = -1
		#Tracker["ireachedres"]    = -1
		#Tracker["lowpass"]        = 0.4
		#Tracker["falloff"]        = 0.2
		#Tracker["inires"]         = options.inires  # Now in A, convert to absolute before using
		Tracker["fuse_freq"]       = 50  # Now in A, convert to absolute before using
		#Tracker["delpreviousmax"] = False
		#Tracker["anger"]          = -1.0
		#Tracker["shifter"]        = -1.0
		#Tracker["saturatecrit"]   = 0.95
		#Tracker["pixercutoff"]    = 2.0
		#Tracker["directory"]      = ""
		#Tracker["previousoutputdir"] = ""
		#Tracker["eliminated-outliers"] = False
		#Tracker["mainiteration"]  = 0
		#Tracker["movedback"]      = False
		#Tracker["state"]          = Tracker["constants"]["states"][0] 
		#Tracker["global_resolution"] =0.0
		Tracker["orgstack"]        = orgstack
		#--------------------------------------------------------------------
		# import from utilities
		from utilities import sample_down_1D_curve,get_initial_ID,remove_small_groups,print_upper_triangular_matrix,print_a_line_with_timestamp
		from utilities import print_dict,get_resolution_mrk01,partition_to_groups,partition_independent_runs,get_outliers
		from utilities import merge_groups, save_alist, margin_of_error, get_margin_of_error, do_two_way_comparison, select_two_runs, get_ali3d_params
		from utilities import counting_projections, unload_dict, load_dict, get_stat_proj, create_random_list, get_number_of_groups, recons_mref
		from utilities import apply_low_pass_filter, get_groups_from_partition, get_number_of_groups, get_complementary_elements_total, update_full_dict
		from utilities import count_chunk_members, set_filter_parameters_from_adjusted_fsc, adjust_fsc_down, get_two_chunks_from_stack
		####------------------------------------------------------------------
		#
		# Get the pixel size; if none, set to 1.0, and the original image size
		from utilities import get_shrink_data_huang
		if(myid == main_node):
			line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
			print(line+"Initialization of 3-D sorting")
			a = get_im(orgstack)
			nnxo = a.get_xsize()
			if( Tracker["nxinit"] > nnxo ):
				ERROR("Image size less than minimum permitted $d"%Tracker["nxinit"],"sxsort3d.py",1)
				nnxo = -1
			else:
				if Tracker["constants"]["CTF"]:
					i = a.get_attr('ctf')
					pixel_size = i.apix
					fq = pixel_size/Tracker["fuse_freq"]
				else:
					pixel_size = 1.0
					#  No pixel size, fusing computed as 5 Fourier pixels
					fq = 5.0/nnxo
					del a
		else:
			nnxo = 0
			fq = 0.0
			pixel_size = 1.0
		nnxo = bcast_number_to_all(nnxo, source_node = main_node)
		if( nnxo < 0 ):
			mpi_finalize()
			exit()
		pixel_size = bcast_number_to_all(pixel_size, source_node = main_node)
		fq         = bcast_number_to_all(fq, source_node = main_node)
		if Tracker["constants"]["wn"]==0:
			Tracker["constants"]["nnxo"]          = nnxo
		else:
			Tracker["constants"]["nnxo"]          = Tracker["constants"]["wn"]
			nnxo                                  = Tracker["constants"]["nnxo"]
		Tracker["constants"]["pixel_size"]        = pixel_size
		Tracker["fuse_freq"]                      = fq
		del fq, nnxo, pixel_size
		if(Tracker["constants"]["radius"] < 1):
			Tracker["constants"]["radius"]  = Tracker["constants"]["nnxo"]//2-2
		elif((2*Tracker["constants"]["radius"] +2) > Tracker["constants"]["nnxo"]):
			ERROR("Particle radius set too large!","sxsort3d.py",1,myid)
####-----------------------------------------------------------------------------------------
		# Master directory
		if myid == main_node:
			if masterdir =="":
				timestring = strftime("_%d_%b_%Y_%H_%M_%S", localtime())
				masterdir ="master_sort3d"+timestring
			li =len(masterdir)
			cmd="{} {}".format("mkdir", masterdir)
			os.system(cmd)
		else:
			li=0
		li = mpi_bcast(li,1,MPI_INT,main_node,MPI_COMM_WORLD)[0]
		if li>0:
			masterdir = mpi_bcast(masterdir,li,MPI_CHAR,main_node,MPI_COMM_WORLD)
			import string
			masterdir = string.join(masterdir,"")
		if myid ==main_node:
			print_dict(Tracker["constants"],"Permanent settings of 3-D sorting program")
		######### create a vstack from input stack to the local stack in masterdir
		# stack name set to default
		Tracker["constants"]["stack"]       = "bdb:"+masterdir+"/rdata"
		Tracker["constants"]["ali3d"]       = os.path.join(masterdir, "ali3d_init.txt")
		Tracker["constants"]["ctf_params"]  = os.path.join(masterdir, "ctf_params.txt")
		Tracker["constants"]["partstack"]   = Tracker["constants"]["ali3d"]  # also serves for refinement
		if myid == main_node:
			total_stack = EMUtil.get_image_count(Tracker["orgstack"])
		else:
			total_stack = 0
		total_stack = bcast_number_to_all(total_stack, source_node = main_node)
		mpi_barrier(MPI_COMM_WORLD)
		from time import sleep
		while not os.path.exists(masterdir):
				print  "Node ",myid,"  waiting..."
				sleep(5)
		mpi_barrier(MPI_COMM_WORLD)
		if myid == main_node:
			log_main.add("Sphire sort3d ")
			log_main.add("the sort3d master directory is "+masterdir)
		#####
		###----------------------------------------------------------------------------------
		# Initial data analysis and handle two chunk files
		from random import shuffle
		# Compute the resolution 
		#### make chunkdir dictionary for computing margin of error
		import user_functions
		user_func  = user_functions.factory[Tracker["constants"]["user_func"]]
		chunk_dict = {}
		chunk_list = []
		if myid == main_node:
			chunk_one = read_text_file(os.path.join(Tracker["constants"]["chunkdir"],"chunk0.txt"))
			chunk_two = read_text_file(os.path.join(Tracker["constants"]["chunkdir"],"chunk1.txt"))
		else:
			chunk_one = 0
			chunk_two = 0
		chunk_one = wrap_mpi_bcast(chunk_one, main_node)
		chunk_two = wrap_mpi_bcast(chunk_two, main_node)
		mpi_barrier(MPI_COMM_WORLD)
		######################## Read/write bdb: data on main node ############################
	   	if myid==main_node:
			if(orgstack[:4] == "bdb:"):	cmd = "{} {} {}".format("e2bdb.py", orgstack,"--makevstack="+Tracker["constants"]["stack"])
			else:  cmd = "{} {} {}".format("sxcpy.py", orgstack, Tracker["constants"]["stack"])
	   		cmdexecute(cmd)
			cmd = "{} {} {}".format("sxheader.py  --params=xform.projection", "--export="+Tracker["constants"]["ali3d"],orgstack)
			cmdexecute(cmd)
			cmd = "{} {} {}".format("sxheader.py  --params=ctf", "--export="+Tracker["constants"]["ctf_params"],orgstack)
			cmdexecute(cmd)
		mpi_barrier(MPI_COMM_WORLD)	   		   	
		########-----------------------------------------------------------------------------
		Tracker["total_stack"]              = total_stack
		Tracker["constants"]["total_stack"] = total_stack
		Tracker["shrinkage"]                = float(Tracker["nxinit"])/Tracker["constants"]["nnxo"]
		Tracker["radius"]                   = Tracker["constants"]["radius"]*Tracker["shrinkage"]
		if Tracker["constants"]["mask3D"]:
			Tracker["mask3D"] = os.path.join(masterdir,"smask.hdf")
		else:
			Tracker["mask3D"]  = None
		if Tracker["constants"]["focus3Dmask"]:
			Tracker["focus3D"] = os.path.join(masterdir,"sfocus.hdf")
		else:
			Tracker["focus3D"] = None
		if myid == main_node:
			if Tracker["constants"]["mask3D"]:
				mask_3D = get_shrink_3dmask(Tracker["nxinit"],Tracker["constants"]["mask3D"])
				mask_3D.write_image(Tracker["mask3D"])
			if Tracker["constants"]["focus3Dmask"]:
				mask_3D = get_shrink_3dmask(Tracker["nxinit"],Tracker["constants"]["focus3Dmask"])
				st = Util.infomask(mask_3D, None, True)
				if( st[0] == 0.0 ):  ERROR("sxrsort3d","incorrect focused mask, after binarize all values zero",1)
				mask_3D.write_image(Tracker["focus3D"])
				del mask_3D
		if Tracker["constants"]["PWadjustment"] !='':
			PW_dict              = {}
			nxinit_pwsp          = sample_down_1D_curve(Tracker["constants"]["nxinit"],Tracker["constants"]["nnxo"],Tracker["constants"]["PWadjustment"])
			Tracker["nxinit_PW"] = os.path.join(masterdir,"spwp.txt")
			if myid == main_node:  write_text_file(nxinit_pwsp,Tracker["nxinit_PW"])
			PW_dict[Tracker["constants"]["nnxo"]]   = Tracker["constants"]["PWadjustment"]
			PW_dict[Tracker["constants"]["nxinit"]] = Tracker["nxinit_PW"]
			Tracker["PW_dict"]                      = PW_dict
		mpi_barrier(MPI_COMM_WORLD)
		#-----------------------From two chunks to FSC, and low pass filter-----------------------------------------###
		for element in chunk_one: chunk_dict[element] = 0
		for element in chunk_two: chunk_dict[element] = 1
		chunk_list =[chunk_one, chunk_two]
		Tracker["chunk_dict"] = chunk_dict
		Tracker["P_chunk0"]   = len(chunk_one)/float(total_stack)
		Tracker["P_chunk1"]   = len(chunk_two)/float(total_stack)
		### create two volumes to estimate resolution
		if myid == main_node:
			for index in xrange(2): write_text_file(chunk_list[index],os.path.join(masterdir,"chunk%01d.txt"%index))
		mpi_barrier(MPI_COMM_WORLD)
		vols = []
		for index in xrange(2):
			data,old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nxinit"], os.path.join(masterdir,"chunk%01d.txt"%index), Tracker["constants"]["partstack"],myid,main_node,nproc,preshift=True)
			vol             = recons3d_4nn_ctf_MPI(myid=myid, prjlist=data,symmetry=Tracker["constants"]["sym"], finfo=None)
			if myid == main_node:
				vol.write_image(os.path.join(masterdir, "vol%d.hdf"%index))
			vols.append(vol)
			mpi_barrier(MPI_COMM_WORLD)
		if myid ==main_node:
			low_pass, falloff,currentres = get_resolution_mrk01(vols,Tracker["constants"]["radius"],Tracker["constants"]["nxinit"],masterdir,Tracker["mask3D"])
			if low_pass >Tracker["constants"]["low_pass_filter"]: low_pass= Tracker["constants"]["low_pass_filter"]
		else:
			low_pass    =0.0
			falloff     =0.0
			currentres  =0.0
		bcast_number_to_all(currentres,source_node = main_node)
		bcast_number_to_all(low_pass,source_node   = main_node)
		bcast_number_to_all(falloff,source_node    = main_node)
		Tracker["currentres"]                      = currentres
		Tracker["falloff"]                         = falloff
		if Tracker["constants"]["low_pass_filter"] ==-1.0:
			Tracker["low_pass_filter"] = min(.45,low_pass/Tracker["shrinkage"]) # no better than .45
		else:
			Tracker["low_pass_filter"] = min(.45,Tracker["constants"]["low_pass_filter"]/Tracker["shrinkage"])
		Tracker["lowpass"]             = Tracker["low_pass_filter"]
		Tracker["falloff"]             =.1
		Tracker["global_fsc"]          = os.path.join(masterdir, "fsc.txt")
		############################################################################################
		if myid == main_node:
			log_main.add("The command-line inputs are as following:")
			log_main.add("**********************************************************")
		for a in sys.argv:
			if myid == main_node:log_main.add(a)
		if myid == main_node:
			log_main.add("number of cpus used in this run is %d"%Tracker["constants"]["nproc"])
			log_main.add("**********************************************************")
		from filter import filt_tanl
		### START 3-D sorting
		if myid ==main_node:
			log_main.add("----------3-D sorting  program------- ")
			log_main.add("current resolution %6.3f for images of original size in terms of absolute frequency"%Tracker["currentres"])
			log_main.add("equivalent to %f Angstrom resolution"%(Tracker["constants"]["pixel_size"]/Tracker["currentres"]/Tracker["shrinkage"]))
			log_main.add("the user provided enforced low_pass_filter is %f"%Tracker["constants"]["low_pass_filter"])
			#log_main.add("equivalent to %f Angstrom resolution"%(Tracker["constants"]["pixel_size"]/Tracker["constants"]["low_pass_filter"]))
			for index in xrange(2):
				filt_tanl(get_im(os.path.join(masterdir,"vol%01d.hdf"%index)), Tracker["low_pass_filter"],Tracker["falloff"]).write_image(os.path.join(masterdir, "volf%01d.hdf"%index))
		mpi_barrier(MPI_COMM_WORLD)
		from utilities import get_input_from_string
		delta       = get_input_from_string(Tracker["constants"]["delta"])
		delta       = delta[0]
		from utilities import even_angles
		n_angles    = even_angles(delta, 0, 180)
		this_ali3d  = Tracker["constants"]["ali3d"]
		sampled     = get_stat_proj(Tracker,delta,this_ali3d)
		if myid ==main_node:
			nc = 0
			for a in sampled:
				if len(sampled[a])>0:
					nc += 1
			log_main.add("total sampled direction %10d  at angle step %6.3f"%(len(n_angles), delta)) 
			log_main.add("captured sampled directions %10d percentage covered by data  %6.3f"%(nc,float(nc)/len(n_angles)*100))
		number_of_images_per_group = Tracker["constants"]["number_of_images_per_group"]
		if myid ==main_node: log_main.add("user provided number_of_images_per_group %d"%number_of_images_per_group)
		Tracker["number_of_images_per_group"] = number_of_images_per_group
		number_of_groups = get_number_of_groups(total_stack,number_of_images_per_group)
		Tracker["number_of_groups"] =  number_of_groups
		generation     =0
		partition_dict ={}
		full_dict      ={}
		workdir =os.path.join(masterdir,"generation%03d"%generation)
		Tracker["this_dir"] = workdir
		if myid ==main_node:
			log_main.add("---- generation         %5d"%generation)
			log_main.add("number of images per group is set as %d"%number_of_images_per_group)
			log_main.add("the initial number of groups is  %10d "%number_of_groups)
			cmd="{} {}".format("mkdir",workdir)
			os.system(cmd)
		mpi_barrier(MPI_COMM_WORLD)
		list_to_be_processed = range(Tracker["constants"]["total_stack"])
		Tracker["this_data_list"] = list_to_be_processed
		create_random_list(Tracker)
		#################################
		full_dict ={}
		for iptl in xrange(Tracker["constants"]["total_stack"]):
			 full_dict[iptl]    = iptl
		Tracker["full_ID_dict"] = full_dict
		################################# 	
		for indep_run in xrange(Tracker["constants"]["indep_runs"]):
			Tracker["this_particle_list"] = Tracker["this_indep_list"][indep_run]
			ref_vol =  recons_mref(Tracker)
			if myid == main_node: log_main.add("independent run  %10d"%indep_run)
			mpi_barrier(MPI_COMM_WORLD)
			Tracker["this_data_list"]          = list_to_be_processed
			Tracker["total_stack"]             = len(Tracker["this_data_list"])
			Tracker["this_particle_text_file"] = os.path.join(workdir,"independent_list_%03d.txt"%indep_run) # for get_shrink_data
			if myid == main_node: write_text_file(Tracker["this_data_list"], Tracker["this_particle_text_file"])
			mpi_barrier(MPI_COMM_WORLD)
			outdir  = os.path.join(workdir, "EQ_Kmeans%03d"%indep_run)
			ref_vol = apply_low_pass_filter(ref_vol,Tracker)
			mref_ali3d_EQ_Kmeans(ref_vol, outdir, Tracker["this_particle_text_file"], Tracker)
			partition_dict[indep_run]=Tracker["this_partition"]
		Tracker["partition_dict"]    = partition_dict
		Tracker["total_stack"]       = len(Tracker["this_data_list"])
		Tracker["this_total_stack"]  = Tracker["total_stack"]
		###############################
		do_two_way_comparison(Tracker)
		###############################
		ref_vol_list = []
		from time import sleep
		number_of_ref_class = []
		for igrp in xrange(len(Tracker["two_way_stable_member"])):
			Tracker["this_data_list"]      = Tracker["two_way_stable_member"][igrp]
			Tracker["this_data_list_file"] = os.path.join(workdir,"stable_class%d.txt"%igrp)
			if myid == main_node:
				write_text_file(Tracker["this_data_list"], Tracker["this_data_list_file"])
			data,old_shifts = get_shrink_data_huang(Tracker,Tracker["nxinit"], Tracker["this_data_list_file"], Tracker["constants"]["partstack"], myid, main_node, nproc, preshift = True)
			volref          = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"], finfo = None)
			ref_vol_list.append(volref)
			number_of_ref_class.append(len(Tracker["this_data_list"]))
			if myid == main_node:
				log_main.add("group  %d  members %d "%(igrp,len(Tracker["this_data_list"])))
		Tracker["number_of_ref_class"] = number_of_ref_class
		nx_of_image = ref_vol_list[0].get_xsize()
		if Tracker["constants"]["PWadjustment"]:
			Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
		else:
			Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]	 # no PW adjustment
		if myid == main_node:
			for iref in xrange(len(ref_vol_list)):
				refdata    = [None]*4
				refdata[0] = ref_vol_list[iref]
				refdata[1] = Tracker
				refdata[2] = Tracker["constants"]["myid"]
				refdata[3] = Tracker["constants"]["nproc"]
				volref     = user_func(refdata)
				volref.write_image(os.path.join(workdir,"volf_stable.hdf"),iref)
		mpi_barrier(MPI_COMM_WORLD)
		Tracker["this_data_list"]           = Tracker["this_accounted_list"]
		outdir                              = os.path.join(workdir,"Kmref")  
		empty_group, res_groups, final_list = ali3d_mref_Kmeans_MPI(ref_vol_list,outdir,Tracker["this_accounted_text"],Tracker)
		Tracker["this_unaccounted_list"]    = get_complementary_elements(list_to_be_processed,final_list)
		if myid == main_node:
			log_main.add("the number of particles not processed is %d"%len(Tracker["this_unaccounted_list"]))
			write_text_file(Tracker["this_unaccounted_list"],Tracker["this_unaccounted_text"])
		update_full_dict(Tracker["this_unaccounted_list"], Tracker)
		#######################################
		number_of_groups    = len(res_groups)
		vol_list            = []
		number_of_ref_class = []
		for igrp in xrange(number_of_groups):
			data,old_shifts = get_shrink_data_huang(Tracker, Tracker["constants"]["nnxo"], os.path.join(outdir,"Class%d.txt"%igrp), Tracker["constants"]["partstack"],myid,main_node,nproc,preshift = True)
			volref          = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"], finfo=None)
			vol_list.append(volref)

			if( myid == main_node ):  npergroup = len(read_text_file(os.path.join(outdir,"Class%d.txt"%igrp)))
			else:  npergroup = 0
			npergroup = bcast_number_to_all(npergroup, main_node )
			number_of_ref_class.append(npergroup)

		Tracker["number_of_ref_class"] = number_of_ref_class
		
		mpi_barrier(MPI_COMM_WORLD)
		nx_of_image = vol_list[0].get_xsize()
		if Tracker["constants"]["PWadjustment"]:
			Tracker["PWadjustment"]=Tracker["PW_dict"][nx_of_image]
		else:
			Tracker["PWadjustment"]=Tracker["constants"]["PWadjustment"]	

		if myid == main_node:
			for ivol in xrange(len(vol_list)):
				refdata     =[None]*4
				refdata[0] = vol_list[ivol]
				refdata[1] = Tracker
				refdata[2] = Tracker["constants"]["myid"]
				refdata[3] = Tracker["constants"]["nproc"] 
				volref = user_func(refdata)
				volref.write_image(os.path.join(workdir,"volf_of_Classes.hdf"),ivol)
				log_main.add("number of unaccounted particles  %10d"%len(Tracker["this_unaccounted_list"]))
				log_main.add("number of accounted particles  %10d"%len(Tracker["this_accounted_list"]))
				
		Tracker["this_data_list"]    = Tracker["this_unaccounted_list"]   # reset parameters for the next round calculation
		Tracker["total_stack"]       = len(Tracker["this_unaccounted_list"])
		Tracker["this_total_stack"]  = Tracker["total_stack"]
		number_of_groups             = get_number_of_groups(len(Tracker["this_unaccounted_list"]),number_of_images_per_group)
		Tracker["number_of_groups"]  =  number_of_groups
		while number_of_groups >= 2 :
			generation     +=1
			partition_dict ={}
			workdir =os.path.join(masterdir,"generation%03d"%generation)
			Tracker["this_dir"] = workdir
			if myid ==main_node:
				log_main.add("*********************************************")
				log_main.add("-----    generation             %5d    "%generation)
				log_main.add("number of images per group is set as %10d "%number_of_images_per_group)
				log_main.add("the number of groups is  %10d "%number_of_groups)
				log_main.add(" number of particles for clustering is %10d"%Tracker["total_stack"])
				cmd ="{} {}".format("mkdir",workdir)
				os.system(cmd)
			mpi_barrier(MPI_COMM_WORLD)
			create_random_list(Tracker)
			for indep_run in xrange(Tracker["constants"]["indep_runs"]):
				Tracker["this_particle_list"] = Tracker["this_indep_list"][indep_run]
				ref_vol                       = recons_mref(Tracker)
				if myid == main_node:
					log_main.add("independent run  %10d"%indep_run)
					outdir = os.path.join(workdir, "EQ_Kmeans%03d"%indep_run)
				Tracker["this_data_list"]   = Tracker["this_unaccounted_list"]
				#ref_vol=apply_low_pass_filter(ref_vol,Tracker)
				mref_ali3d_EQ_Kmeans(ref_vol,outdir,Tracker["this_unaccounted_text"],Tracker)
				partition_dict[indep_run]   = Tracker["this_partition"]
				Tracker["this_data_list"]   = Tracker["this_unaccounted_list"]
				Tracker["total_stack"]      = len(Tracker["this_unaccounted_list"])
				Tracker["partition_dict"]   = partition_dict
				Tracker["this_total_stack"] = Tracker["total_stack"]
			total_list_of_this_run          = Tracker["this_unaccounted_list"]
			###############################
			do_two_way_comparison(Tracker)
			###############################
			ref_vol_list        = []
			number_of_ref_class = []
			for igrp in xrange(len(Tracker["two_way_stable_member"])):
				Tracker["this_data_list"]      = Tracker["two_way_stable_member"][igrp]
				Tracker["this_data_list_file"] = os.path.join(workdir,"stable_class%d.txt"%igrp)
				if myid == main_node: write_text_file(Tracker["this_data_list"], Tracker["this_data_list_file"])
				mpi_barrier(MPI_COMM_WORLD)
				data,old_shifts  = get_shrink_data_huang(Tracker,Tracker["constants"]["nxinit"],Tracker["this_data_list_file"],Tracker["constants"]["partstack"],myid,main_node,nproc,preshift = True)
				volref           = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"],finfo= None)
				#volref = filt_tanl(volref, Tracker["constants"]["low_pass_filter"],.1)
				if myid == main_node:volref.write_image(os.path.join(workdir,"vol_stable.hdf"),iref)
				#volref = resample(volref,Tracker["shrinkage"])
				ref_vol_list.append(volref)
				number_of_ref_class.append(len(Tracker["this_data_list"]))
				mpi_barrier(MPI_COMM_WORLD)
			Tracker["number_of_ref_class"]      = number_of_ref_class
			Tracker["this_data_list"]           = Tracker["this_accounted_list"]
			outdir                              = os.path.join(workdir,"Kmref")
			empty_group, res_groups, final_list = ali3d_mref_Kmeans_MPI(ref_vol_list,outdir,Tracker["this_accounted_text"],Tracker)
			# calculate the 3-D structure of original image size for each group
			number_of_groups                    =  len(res_groups)
			Tracker["this_unaccounted_list"]    = get_complementary_elements(total_list_of_this_run,final_list)
			if myid == main_node:
				log_main.add("the number of particles not processed is %d"%len(Tracker["this_unaccounted_list"]))
				write_text_file(Tracker["this_unaccounted_list"],Tracker["this_unaccounted_text"])
			mpi_barrier(MPI_COMM_WORLD)
			update_full_dict(Tracker["this_unaccounted_list"],Tracker)
			vol_list = []
			for igrp in xrange(number_of_groups):
				data,old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nnxo"], os.path.join(outdir,"Class%d.txt"%igrp), Tracker["constants"]["partstack"], myid, main_node, nproc,preshift = True)
				volref = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"],finfo= None)
				vol_list.append(volref)

			mpi_barrier(MPI_COMM_WORLD)
			nx_of_image=ref_vol_list[0].get_xsize()
			if Tracker["constants"]["PWadjustment"]:
				Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
			else:
				Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]	

			if myid == main_node:
				for ivol in xrange(len(vol_list)):
					refdata    = [None]*4
					refdata[0] = vol_list[ivol]
					refdata[1] = Tracker
					refdata[2] = Tracker["constants"]["myid"]
					refdata[3] = Tracker["constants"]["nproc"] 
					volref     = user_func(refdata)
					volref.write_image(os.path.join(workdir, "volf_of_Classes.hdf"),ivol)
				log_main.add("number of unaccounted particles  %10d"%len(Tracker["this_unaccounted_list"]))
				log_main.add("number of accounted particles  %10d"%len(Tracker["this_accounted_list"]))
			del vol_list
			mpi_barrier(MPI_COMM_WORLD)
			number_of_groups            = get_number_of_groups(len(Tracker["this_unaccounted_list"]),number_of_images_per_group)
			Tracker["number_of_groups"] =  number_of_groups
			Tracker["this_data_list"]   = Tracker["this_unaccounted_list"]
			Tracker["total_stack"]      = len(Tracker["this_unaccounted_list"])
		if Tracker["constants"]["unaccounted"]:
			data,old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nnxo"],Tracker["this_unaccounted_text"],Tracker["constants"]["partstack"],myid,main_node,nproc,preshift = True)
			volref          = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"],finfo= None)
			nx_of_image     = volref.get_xsize()
			if Tracker["constants"]["PWadjustment"]:
				Tracker["PWadjustment"]=Tracker["PW_dict"][nx_of_image]
			else:
				Tracker["PWadjustment"]=Tracker["constants"]["PWadjustment"]	
			if( myid == main_node ):
				refdata    = [None]*4
				refdata[0] = volref
				refdata[1] = Tracker
				refdata[2] = Tracker["constants"]["myid"]
				refdata[3] = Tracker["constants"]["nproc"]
				volref     = user_func(refdata)
				#volref    = filt_tanl(volref, Tracker["constants"]["low_pass_filter"],.1)
				volref.write_image(os.path.join(workdir,"volf_unaccounted.hdf"))
		# Finish program
		if myid ==main_node: log_main.add("sxsort3d finishes")
		mpi_barrier(MPI_COMM_WORLD)
		from mpi import mpi_finalize
		mpi_finalize()
		exit()
Beispiel #12
0
def main():
    """Command-line entry point of sxcter.py: automated CTF estimation.

    Parses the command line, initialises MPI when the Open MPI environment
    variable ``OMPI_COMM_WORLD_SIZE`` is present, validates that exactly two
    positional arguments (input image path, output directory) were given, and
    dispatches to one CTF estimator imported from ``morphology``:

    * ``cter_vpp`` when ``--vpp`` is set; it also receives the
      ``--defocus_*`` / ``--phase_*`` search values,
    * ``cter_pap`` when ``--pap`` is set,
    * ``cter_mrk`` otherwise.

    If any ``--defocus_*`` / ``--phase_*`` option appears on the command line
    without ``--vpp``, it is reported through ``ERROR(...)``.
    ``global_def.LOGFILE`` is renamed per program for the run and restored
    afterwards. The function ends with ``sys.exit(0)``, after ``mpi_finalize()``
    when running under MPI.
    """
    program_name = os.path.basename(sys.argv[0])
    usage = program_name + """  input_image_path  output_directory  --selection_list=selection_list  --wn=CTF_WINDOW_SIZE --apix=PIXEL_SIZE  --Cs=CS  --voltage=VOLTAGE  --ac=AMP_CONTRAST  --f_start=FREA_START  --f_stop=FREQ_STOP  --vpp  --kboot=KBOOT  --overlap_x=OVERLAP_X  --overlap_y=OVERLAP_Y  --edge_x=EDGE_X  --edge_y=EDGE_Y  --check_consistency  --stack_mode  --debug_mode

Automated estimation of CTF parameters with error assessment.

All Micrographs Mode - Process all micrographs in a directory: 
	Specify a list of input micrographs using a wild card (*), called here input micrographs path pattern. 
	Use the wild card to indicate the place of variable part of the file names (e.g. serial number, time stamp, and etc). 
	Running from the command line requires enclosing the string by single quotes (') or double quotes ("). 
	sxgui.py will automatically adds single quotes to the string. 
	BDB files can not be selected as input micrographs. 
	Then, specify output directory where all outputs should be saved. 
	In this mode, all micrographs matching the path pattern will be processed.

	mpirun -np 16 sxcter.py './mic*.hdf' outdir_cter --wn=512 --apix=2.29 --Cs=2.0 --voltage=300 --ac=10.0

Selected Micrographs Mode - Process all micrographs in a selection list file:
	In addition to input micrographs path pattern and output directry arguments, 
	specify a name of micrograph selection list text file using --selection_list option 
	(e.g. output of sxgui_unblur.py or sxgui_cter.py). The file extension must be ".txt". 
	In this mode, only micrographs in the selection list which matches the file name part of the pattern (ignoring the directory paths) will be processed. 
	If a micrograph name in the selection list does not exists in the directory specified by the micrograph path pattern, processing of the micrograph will be skipped.

	mpirun -np 16 sxcter.py './mic*.hdf' outdir_cter --selection_list=mic_list.txt --wn=512 --apix=2.29 --Cs=2.0 --voltage=300 --ac=10.0

Single Micrograph Mode - Process a single micrograph: 
	In addition to input micrographs path pattern and output directry arguments, 
	specify a single micrograph name using --selection_list option. 
	In this mode, only the specified single micrograph will be processed. 
	If this micrograph name does not matches the file name part of the pattern (ignoring the directory paths), the process will exit without processing it. 
	If this micrograph name matches the file name part of the pattern but does not exists in the directory which specified by the micrograph path pattern, again the process will exit without processing it. 
	Use single processor for this mode.

	sxcter.py './mic*.hdf' outdir_cter --selection_list=mic0.hdf --wn=512 --apix=2.29 --Cs=2.0 --voltage=300 --ac=10.0

Stack Mode - Process a particle stack (Not supported by SPHIRE GUI)):: 
	Use --stack_mode option, then specify the path of particle stack file (without wild card "*") and output directory as arguments. 
	This mode ignores --selection_list, --wn --overlap_x, --overlap_y, --edge_x, and --edge_y options. 
	Use single processor for this mode. Not supported by SPHIRE GUI (sxgui.py). 

	sxcter.py bdb:stack outdir_cter --apix=2.29 --Cs=2.0 --voltage=300 --ac=10.0 --stack_mode

"""
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option(
        "--selection_list",
        type="string",
        default=None,
        help=
        "Micrograph selecting list: Specify path of a micrograph selection list text file for Selected Micrographs Mode. The file extension must be \'.txt\'. Alternatively, the file name of a single micrograph can be specified for Single Micrograph Mode. (default none)"
    )
    parser.add_option(
        "--wn",
        type="int",
        default=512,
        help=
        "CTF window size [pixels]: The size should be slightly larger than particle box size. This will be ignored in Stack Mode. (default 512)"
    )
    parser.add_option(
        "--apix",
        type="float",
        default=-1.0,
        help=
        "Pixel size [A/Pixels]: The pixel size of input micrograph(s) or images in input particle stack. (default -1.0)"
    )
    parser.add_option(
        "--Cs",
        type="float",
        default=2.0,
        help=
        "Microscope spherical aberration (Cs) [mm]: The spherical aberration (Cs) of microscope used for imaging. (default 2.0)"
    )
    parser.add_option(
        "--voltage",
        type="float",
        default=300.0,
        help=
        "Microscope voltage [kV]: The acceleration voltage of microscope used for imaging. (default 300.0)"
    )
    parser.add_option(
        "--ac",
        type="float",
        default=10.0,
        help=
        "Amplitude contrast [%]: The typical amplitude contrast is in the range of 7% - 14%. The value mainly depends on the thickness of the ice embedding the particles. (default 10.0)"
    )
    parser.add_option(
        "--f_start",
        type="float",
        default=-1.0,
        help=
        "Lowest frequency [1/A]: Lowest frequency to be considered in the CTF estimation. Determined automatically by default. (default -1.0)"
    )
    parser.add_option(
        "--f_stop",
        type="float",
        default=-1.0,
        help=
        "Highest frequency [1/A]: Highest frequency to be considered in the CTF estimation. Determined automatically by default. (default -1.0)"
    )
    parser.add_option(
        "--kboot",
        type="int",
        default=16,
        help=
        "Number of CTF estimates per micrograph: Used for error assessment. (default 16)"
    )
    parser.add_option(
        "--overlap_x",
        type="int",
        default=50,
        help=
        "X overlap [%]: Overlap between the windows in the x direction. This will be ignored in Stack Mode. (default 50)"
    )
    parser.add_option(
        "--overlap_y",
        type="int",
        default=50,
        help=
        "Y overlap [%]: Overlap between the windows in the y direction. This will be ignored in Stack Mode. (default 50)"
    )
    parser.add_option(
        "--edge_x",
        type="int",
        default=0,
        help=
        "Edge x [pixels]: Defines the edge of the tiling area in the x direction. Normally it does not need to be modified. This will be ignored in Stack Mode. (default 0)"
    )
    parser.add_option(
        "--edge_y",
        type="int",
        default=0,
        help=
        "Edge y [pixels]: Defines the edge of the tiling area in the y direction. Normally it does not need to be modified. This will be ignored in Stack Mode. (default 0)"
    )
    parser.add_option(
        "--check_consistency",
        action="store_true",
        default=False,
        help=
        "Check consistency of inputs: Create a text file containing the list of inconsistent Micrograph ID entries (i.e. inconsist_mic_list_file.txt). (default False)"
    )
    parser.add_option(
        "--stack_mode",
        action="store_true",
        default=False,
        help=
        "Use stack mode: Use a stack as the input. Please set the file path of a stack as the first argument and output directory for the second argument. This is advanced option. Not supported by sxgui. (default False)"
    )
    parser.add_option(
        "--debug_mode",
        action="store_true",
        default=False,
        help="Enable debug mode: Print out debug information. (default False)")
    parser.add_option(
        "--vpp",
        action="store_true",
        default=False,
        help="Volta Phase Plate - fit smplitude contrast. (default False)")
    parser.add_option("--defocus_min",
                      type="float",
                      default=0.3,
                      help="Minimum defocus search [um] (default 0.3)")
    parser.add_option("--defocus_max",
                      type="float",
                      default=9.0,
                      help="Maximum defocus search [um] (default 9.0)")
    parser.add_option("--defocus_step",
                      type="float",
                      default=0.1,
                      help="Step defocus search [um] (default 0.1)")
    parser.add_option("--phase_min",
                      type="float",
                      default=5.0,
                      help="Minimum phase search [degrees] (default 5.0)")
    parser.add_option("--phase_max",
                      type="float",
                      default=175.0,
                      help="Maximum phase search [degrees] (default 175.0)")
    parser.add_option("--phase_step",
                      type="float",
                      default=5.0,
                      help="Step phase search [degrees] (default 5.0)")
    parser.add_option("--pap",
                      action="store_true",
                      default=False,
                      help="Use power spectrum for fitting. (default False)")

    (options, args) = parser.parse_args(sys.argv[1:])

    # ====================================================================================
    # Prepare processing
    # ====================================================================================
    # ------------------------------------------------------------------------------------
    # Set up MPI related variables
    # ------------------------------------------------------------------------------------
    # Detect if program is running under MPI (Open MPI exports this variable)
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ

    main_mpi_proc = 0
    if RUNNING_UNDER_MPI:
        # mpi_barrier / MPI_COMM_WORLD are used again below, always under the
        # same RUNNING_UNDER_MPI guard, so importing them here is sufficient.
        from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, mpi_barrier, MPI_COMM_WORLD

        sys.argv = mpi_init(len(sys.argv), sys.argv)
        my_mpi_proc_id = mpi_comm_rank(MPI_COMM_WORLD)
        n_mpi_procs = mpi_comm_size(MPI_COMM_WORLD)
        global_def.MPI = True
    else:
        my_mpi_proc_id = 0
        n_mpi_procs = 1

    # ------------------------------------------------------------------------------------
    # Set up SPHIRE global definitions
    # ------------------------------------------------------------------------------------
    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    # Change the name log file for error message; restored in the clean-up section.
    original_logfilename = global_def.LOGFILE
    global_def.LOGFILE = os.path.splitext(
        program_name)[0] + '_' + original_logfilename + '.txt'

    # ------------------------------------------------------------------------------------
    # Check error conditions of arguments and options, then prepare variables for arguments
    # ------------------------------------------------------------------------------------
    input_image_path = None
    output_directory = None
    error_status = None
    if len(args) != 2:
        error_status = (
            "Please check usage for number of arguments.\n Usage: " +
            usage + "\n" + "Please run %s -h for help." % (program_name),
            getframeinfo(currentframe()))
    else:
        # NOTE: 2015/11/27 Toshio Moriya
        # Require single quotes (') or double quotes (") when input micrograph pattern is give for input_image_path
        #  so that sys.argv does not automatically expand wild card and create a list of file names
        input_image_path = args[0]
        output_directory = args[1]
        # NOTE: 2016/03/17 Toshio Moriya
        # cter_mrk() will take care of all the error conditions
    if_error_then_all_processes_exit_program(error_status)
    #  Toshio, please see how to make it informative
    assert input_image_path is not None, " directory  missing  input_image_path"
    assert output_directory is not None, " directory  missing  output_directory"

    if not options.vpp:
        # The defocus/phase search options are only consumed by cter_vpp();
        # reject them when --vpp is absent. A substring test also catches the
        # "--option=value" spelling. (Replaces the former Python-2-only
        # string.find(), which also shadowed the builtin str.)
        vpp_only_option_names = (
            "--defocus_min", "--defocus_max", "--defocus_step", "--phase_min",
            "--phase_max", "--phase_step"
        )
        offending_token = None
        for command_token in sys.argv:
            if any(name in command_token for name in vpp_only_option_names):
                offending_token = command_token
                break
        if offending_token is not None:
            ERROR(
                "Some options are valid only for Volta Phase Plate command  %s"
                % offending_token, "sxcter", 1, my_mpi_proc_id)

    if my_mpi_proc_id == main_mpi_proc:
        # Each token is followed by two spaces (including the last one).
        command_line = "".join(token + "  " for token in sys.argv)
        print(" ")
        print("Shell line command:")
        print(command_line)

    # Positional arguments shared by every estimator, in the order they expect.
    common_args = (input_image_path, output_directory,
                   options.selection_list, options.wn, options.apix,
                   options.Cs, options.voltage, options.ac,
                   options.f_start, options.f_stop, options.kboot,
                   options.overlap_x, options.overlap_y, options.edge_x,
                   options.edge_y, options.check_consistency,
                   options.stack_mode, options.debug_mode, program_name)
    mpi_args = (RUNNING_UNDER_MPI, main_mpi_proc, my_mpi_proc_id, n_mpi_procs)

    if options.vpp:
        # cter_vpp() takes the search ranges between program_name and the MPI arguments.
        vpp_search_params = [
            options.defocus_min, options.defocus_max, options.defocus_step,
            options.phase_min, options.phase_max, options.phase_step
        ]
        from morphology import cter_vpp
        cter_name = "cter_vpp"
        result = cter_vpp(*(common_args + (vpp_search_params,) + mpi_args))
    elif options.pap:
        from morphology import cter_pap
        cter_name = "cter_pap"
        result = cter_pap(*(common_args + mpi_args))
    else:
        from morphology import cter_mrk
        cter_name = "cter_mrk"
        result = cter_mrk(*(common_args + mpi_args))

    if RUNNING_UNDER_MPI:
        mpi_barrier(MPI_COMM_WORLD)

    if main_mpi_proc == my_mpi_proc_id:
        if options.debug_mode:
            # Name the estimator that actually ran (previously always "cter_mrk").
            print("Returned value from %s() := " % cter_name, result)
        print(" ")
        print("DONE!!!")
        print(" ")

    # ====================================================================================
    # Clean up
    # ====================================================================================
    # ------------------------------------------------------------------------------------
    # Reset SPHIRE global definitions
    # ------------------------------------------------------------------------------------
    global_def.LOGFILE = original_logfilename

    # ------------------------------------------------------------------------------------
    # Clean up MPI related variables
    # ------------------------------------------------------------------------------------
    if RUNNING_UNDER_MPI:
        mpi_barrier(MPI_COMM_WORLD)
        from mpi import mpi_finalize
        mpi_finalize()

    sys.stdout.flush()
    sys.exit(0)
Beispiel #13
0
def run(args):

    progname = optparse.os.path.basename(sys.argv[0])
    usage = (
        progname
        + " stack  [output_directory] --ir=inner_radius --rs=ring_step --xr=x_range --yr=y_range  --ts=translational_search_step  --delta=angular_step --center=center_type --maxit1=max_iter1 --maxit2=max_iter2 --L2threshold=0.1 --ref_a=S --sym=c1"
    )
    usage += """

stack			2D images in a stack file: (default required string)
directory		output directory name: into which the results will be written (if it does not exist, it will be created, if it does exist, the results will be written possibly overwriting previous results) (default required string)
"""

    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--radius",
        type="int",
        default=29,
        help="radius of the particle: has to be less than < int(nx/2)-1 (default 29)",
    )

    parser.add_option(
        "--xr",
        type="string",
        default="0",
        help="range for translation search in x direction: search is +/xr in pixels (default '0')",
    )
    parser.add_option(
        "--yr",
        type="string",
        default="0",
        help="range for translation search in y direction: if omitted will be set to xr, search is +/yr in pixels (default '0')",
    )
    parser.add_option(
        "--mask3D", type="string", default=None, help="3D mask file: (default sphere)"
    )
    parser.add_option(
        "--moon_elimination",
        type="string",
        default="",
        help="elimination of disconnected pieces: two arguments: mass in KDa and pixel size in px/A separated by comma, no space (default none)",
    )
    parser.add_option(
        "--ir",
        type="int",
        default=1,
        help="inner radius for rotational search: > 0 (default 1)",
    )

    # 'radius' and 'ou' are the same as per Pawel's request; 'ou' is hidden from the user
    # the 'ou' variable is not changed to 'radius' in the 'sparx' program. This change is at interface level only for sxviper.
    ##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
    parser.add_option("--ou", type="int", default=-1, help=optparse.SUPPRESS_HELP)

    parser.add_option(
        "--rs",
        type="int",
        default=1,
        help="step between rings in rotational search: >0 (default 1)",
    )
    parser.add_option(
        "--ts",
        type="string",
        default="1.0",
        help="step size of the translation search in x-y directions: search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional (default '1.0')",
    )
    parser.add_option(
        "--delta",
        type="string",
        default="2.0",
        help="angular step of reference projections: (default '2.0')",
    )
    parser.add_option(
        "--center",
        type="float",
        default=-1.0,
        help="centering of 3D template: average shift method; 0: no centering; 1: center of gravity (default -1.0)",
    )
    parser.add_option(
        "--maxit1",
        type="int",
        default=400,
        help="maximum number of iterations performed for the GA part: (default 400)",
    )
    parser.add_option(
        "--maxit2",
        type="int",
        default=50,
        help="maximum number of iterations performed for the finishing up part: (default 50)",
    )
    parser.add_option(
        "--L2threshold",
        type="float",
        default=0.03,
        help="stopping criterion of GA: given as a maximum relative dispersion of volumes' L2 norms: (default 0.03)",
    )
    parser.add_option(
        "--ref_a",
        type="string",
        default="S",
        help="method for generating the quasi-uniformly distributed projection directions: (default S)",
    )
    parser.add_option(
        "--sym",
        type="string",
        default="c1",
        help="point-group symmetry of the structure: (default c1)",
    )

    # parser.add_option("--function", type="string", default="ref_ali3d",         help="name of the reference preparation function (ref_ali3d by default)")
    ##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
    parser.add_option(
        "--function", type="string", default="ref_ali3d", help=optparse.SUPPRESS_HELP
    )

    parser.add_option(
        "--nruns",
        type="int",
        default=6,
        help="GA population: aka number of quasi-independent volumes (default 6)",
    )
    parser.add_option(
        "--doga",
        type="float",
        default=0.1,
        help="do GA when fraction of orientation changes less than 1.0 degrees is at least doga: (default 0.1)",
    )
    ##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
    parser.add_option(
        "--npad",
        type="int",
        default=2,
        help="padding size for 3D reconstruction (default=2)",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=0.25,
        help="cut-off frequency applied to the template volume: using a hyperbolic tangent low-pass filter (default 0.25)",
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.1,
        help="fall-off of hyperbolic tangent low-pass filter: (default 0.1)",
    )
    parser.add_option(
        "--pwreference",
        type="string",
        default="",
        help="text file with a reference power spectrum: (default none)",
    )
    parser.add_option(
        "--debug",
        action="store_true",
        default=False,
        help="debug info printout: (default False)",
    )

    ##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
    parser.add_option(
        "--return_options",
        action="store_true",
        dest="return_options",
        default=False,
        help=optparse.SUPPRESS_HELP,
    )

    parser.add_option("--filament_width", default=-1)
    # parser.add_option("--an",       type="string", default= "-1",               help="NOT USED angular neighborhood for local searches (phi and theta)")
    # parser.add_option("--CTF",      action="store_true", default=False,         help="NOT USED Consider CTF correction during the alignment ")
    # parser.add_option("--snr",      type="float",  default= 1.0,                help="NOT USED Signal-to-Noise Ratio of the data (default 1.0)")
    # (options, args) = parser.parse_args(sys.argv[1:])

    required_option_list = ["radius"]
    (options, args) = parser.parse_args(args)
    # option_dict = vars(options)
    # print parser

    if options.return_options:
        return parser

    if options.moon_elimination == "":
        options.moon_elimination = []
    else:
        options.moon_elimination = list(map(float, options.moon_elimination.split(",")))

    # Making sure all required options appeared.
    for required_option in required_option_list:
        if not options.__dict__[required_option]:
            sp_global_def.sxprint(
                "\n ==%s== mandatory option is missing.\n" % required_option
            )
            sp_global_def.sxprint(
                "Please run '" + progname + " -h' for detailed options"
            )
            sp_global_def.ERROR("Missing parameter. Please see above")
            return

    if len(args) < 2 or len(args) > 3:
        sp_global_def.sxprint("Usage: " + usage)
        sp_global_def.sxprint("Please run '" + progname + " -h' for detailed options")
        sp_global_def.ERROR(
            "Invalid number of parameters used. Please see usage information above."
        )
        return

    log = sp_logger.Logger(sp_logger.BaseLogger_Files())

    # 'radius' and 'ou' are the same as per Pawel's request; 'ou' is hidden from the user
    # the 'ou' variable is not changed to 'radius' in the 'sparx' program. This change is at interface level only for sxviper.
    options.ou = options.radius
    runs_count = options.nruns
    mpi_rank = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    mpi_size = mpi.mpi_comm_size(
        mpi.MPI_COMM_WORLD
    )  # Total number of processes, passed by --np option.

    if mpi_rank == 0:
        all_projs = EMAN2_cppwrap.EMData.read_images(args[0])
        subset = list(range(len(all_projs)))
        # if mpi_size > len(all_projs):
        # 	ERROR('Number of processes supplied by --np needs to be less than or equal to %d (total number of images) ' % len(all_projs), 'sxviper', 1)
        # 	mpi.mpi_finalize()
        # 	return
    else:
        all_projs = None
        subset = None

    outdir = args[1]
    error = 0
    if mpi_rank == 0:
        if mpi_size % options.nruns != 0:
            sp_global_def.ERROR(
                "Number of processes needs to be a multiple of total number of runs. Total runs by default are 3, you can change it by specifying --nruns option.",
                action=0,
            )
            error = 1

        if optparse.os.path.exists(outdir):
            sp_global_def.ERROR(
                "Output directory '%s' exists, please change the name and restart the program"
                % outdir,
                action=0,
            )
            error = 1
        sp_global_def.LOGFILE = optparse.os.path.join(outdir, sp_global_def.LOGFILE)

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    error = sp_utilities.bcast_number_to_all(
        error, source_node=0, mpi_comm=mpi.MPI_COMM_WORLD
    )
    if error == 1:
        return

    if mpi_rank == 0:
        optparse.os.makedirs(outdir)
        sp_global_def.write_command(outdir)

    if outdir[-1] != "/":
        outdir += "/"
    log.prefix = outdir

    # if len(args) > 2:
    # 	ref_vol = get_im(args[2])
    # else:
    # ref_vol = None

    options.user_func = sp_user_functions.factory[options.function]

    options.CTF = False
    options.snr = 1.0
    options.an = -1.0
    options.filament_width = -1
    out_params, out_vol, out_peaks = sp_multi_shc.multi_shc(
        all_projs, subset, runs_count, options, mpi_comm=mpi.MPI_COMM_WORLD, log=log
    )
Beispiel #14
0
def main():
	import	global_def
	from	optparse 	import OptionParser
	from	EMAN2 		import EMUtil
	import	os
	import	sys
	from time import time

	progname = os.path.basename(sys.argv[0])
	usage = progname + " proj_stack output_averages --MPI"
	parser = OptionParser(usage, version=SPARXVERSION)

	parser.add_option("--img_per_group",type="int"         ,	default=100  ,				help="number of images per group" )
	parser.add_option("--radius", 		type="int"         ,	default=-1   ,				help="radius for alignment" )
	parser.add_option("--xr",           type="string"      ,    default="2 1",              help="range for translation search in x direction, search is +/xr")
	parser.add_option("--yr",           type="string"      ,    default="-1",               help="range for translation search in y direction, search is +/yr (default = same as xr)")
	parser.add_option("--ts",           type="string"      ,    default="1 0.5",            help="step size of the translation search in both directions, search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional")
	parser.add_option("--iter", 		type="int"         ,	default=30,                 help="number of iterations within alignment (default = 30)" )
	parser.add_option("--num_ali",      type="int"     	   ,    default=5,         			help="number of alignments performed for stability (default = 5)" )
	parser.add_option("--thld_err",     type="float"       ,    default=1.0,         		help="threshold of pixel error (default = 1.732)" )
	parser.add_option("--grouping" , 	type="string"      ,	default="GRP",				help="do grouping of projections: PPR - per projection, GRP - different size groups, exclusive (default), GEV - grouping equal size")
	parser.add_option("--delta",        type="float"       ,    default=-1.0,         		help="angular step for reference projections (required for GEV method)")
	parser.add_option("--fl",           type="float"       ,    default=0.3,                help="cut-off frequency of hyperbolic tangent low-pass Fourier filter")
	parser.add_option("--aa",           type="float"       ,    default=0.2,                help="fall-off of hyperbolic tangent low-pass Fourier filter")
	parser.add_option("--CTF",          action="store_true",    default=False,              help="Consider CTF correction during the alignment ")
	parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")

	(options,args) = parser.parse_args()
	
	from mpi          import mpi_init, mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD, MPI_TAG_UB
	from mpi          import mpi_barrier, mpi_send, mpi_recv, mpi_bcast, MPI_INT, mpi_finalize, MPI_FLOAT
	from applications import MPI_start_end, within_group_refinement, ali2d_ras
	from pixel_error  import multi_align_stability
	from utilities    import send_EMData, recv_EMData
	from utilities    import get_image, bcast_number_to_all, set_params2D, get_params2D
	from utilities    import group_proj_by_phitheta, model_circle, get_input_from_string

	sys.argv = mpi_init(len(sys.argv), sys.argv)
	myid = mpi_comm_rank(MPI_COMM_WORLD)
	number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
	main_node = 0

	if len(args) == 2:
		stack  = args[0]
		outdir = args[1]
	else:
		ERROR("incomplete list of arguments", "sxproj_stability", 1, myid=myid)
		exit()
	if not options.MPI:
		ERROR("Non-MPI not supported!", "sxproj_stability", myid=myid)
		exit()		 

	if global_def.CACHE_DISABLE:
		from utilities import disable_bdb_cache
		disable_bdb_cache()
	global_def.BATCH = True

	#if os.path.exists(outdir):  ERROR('Output directory exists, please change the name and restart the program', "sxproj_stability", 1, myid)
	#mpi_barrier(MPI_COMM_WORLD)

	
	img_per_grp = options.img_per_group
	radius = options.radius
	ite = options.iter
	num_ali = options.num_ali
	thld_err = options.thld_err

	xrng        = get_input_from_string(options.xr)
	if  options.yr == "-1":  yrng = xrng
	else          :  yrng = get_input_from_string(options.yr)
	step        = get_input_from_string(options.ts)


	if myid == main_node:
		nima = EMUtil.get_image_count(stack)
		img  = get_image(stack)
		nx   = img.get_xsize()
		ny   = img.get_ysize()
	else:
		nima = 0
		nx = 0
		ny = 0
	nima = bcast_number_to_all(nima)
	nx   = bcast_number_to_all(nx)
	ny   = bcast_number_to_all(ny)
	if radius == -1: radius = nx/2-2
	mask = model_circle(radius, nx, nx)

	st = time()
	if options.grouping == "GRP":
		if myid == main_node:
			print "  A  ",myid,"  ",time()-st
			proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
			proj_params = []
			for i in xrange(nima):
				dp = proj_attr[i].get_params("spider")
				phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp["psi"], -dp["tx"], -dp["ty"]
				proj_params.append([phi, theta, psi, s2x, s2y])

			# Here is where the grouping is done, I didn't put enough annotation in the group_proj_by_phitheta,
			# So I will briefly explain it here
			# proj_list  : Returns a list of list of particle numbers, each list contains img_per_grp particle numbers
			#              except for the last one. Depending on the number of particles left, they will either form a
			#              group or append themselves to the last group
			# angle_list : Also returns a list of list, each list contains three numbers (phi, theta, delta), (phi, 
			#              theta) is the projection angle of the center of the group, delta is the range of this group
			# mirror_list: Also returns a list of list, each list contains img_per_grp True or False, which indicates
			#              whether it should take mirror position.
			# In this program angle_list and mirror list are not of interest.

			proj_list_all, angle_list, mirror_list = group_proj_by_phitheta(proj_params, img_per_grp=img_per_grp)
			del proj_params
			print "  B  number of groups  ",myid,"  ",len(proj_list_all),time()-st
		mpi_barrier(MPI_COMM_WORLD)

		# Number of groups, actually there could be one or two more groups, since the size of the remaining group varies
		# we will simply assign them to main node.
		n_grp = nima/img_per_grp-1

		# Divide proj_list_all equally to all nodes, and becomes proj_list
		proj_list = []
		for i in xrange(n_grp):
			proc_to_stay = i%number_of_proc
			if proc_to_stay == main_node:
				if myid == main_node: 	proj_list.append(proj_list_all[i])
			elif myid == main_node:
				mpi_send(len(proj_list_all[i]), 1, MPI_INT, proc_to_stay, MPI_TAG_UB, MPI_COMM_WORLD)
				mpi_send(proj_list_all[i], len(proj_list_all[i]), MPI_INT, proc_to_stay, MPI_TAG_UB, MPI_COMM_WORLD)
			elif myid == proc_to_stay:
				img_per_grp = mpi_recv(1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
				img_per_grp = int(img_per_grp[0])
				temp = mpi_recv(img_per_grp, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
				proj_list.append(map(int, temp))
				del temp
			mpi_barrier(MPI_COMM_WORLD)
		print "  C  ",myid,"  ",time()-st
		if myid == main_node:
			# Assign the remaining groups to main_node
			for i in xrange(n_grp, len(proj_list_all)):
				proj_list.append(proj_list_all[i])
			del proj_list_all, angle_list, mirror_list


	#   Compute stability per projection projection direction, equal number assigned, thus overlaps
	elif options.grouping == "GEV":
		if options.delta == -1.0: ERROR("Angular step for reference projections is required for GEV method","sxproj_stability",1)
		from utilities import even_angles, nearestk_to_refdir, getvec
		refproj = even_angles(options.delta)
		img_begin, img_end = MPI_start_end(len(refproj), number_of_proc, myid)
		# Now each processor keeps its own share of reference projections
		refprojdir = refproj[img_begin: img_end]
		del refproj

		ref_ang = [0.0]*(len(refprojdir)*2)
		for i in xrange(len(refprojdir)):
			ref_ang[i*2]   = refprojdir[0][0]
			ref_ang[i*2+1] = refprojdir[0][1]+i*0.1

		print "  A  ",myid,"  ",time()-st
		proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
		#  the solution below is very slow, do not use it unless there is a problem with the i/O
		"""
		for i in xrange(number_of_proc):
			if myid == i:
				proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
			mpi_barrier(MPI_COMM_WORLD)
		"""
		print "  B  ",myid,"  ",time()-st

		proj_ang = [0.0]*(nima*2)
		for i in xrange(nima):
			dp = proj_attr[i].get_params("spider")
			proj_ang[i*2]   = dp["phi"]
			proj_ang[i*2+1] = dp["theta"]
		print "  C  ",myid,"  ",time()-st
		asi = Util.nearestk_to_refdir(proj_ang, ref_ang, img_per_grp)
		del proj_ang, ref_ang
		proj_list = []
		for i in xrange(len(refprojdir)):
			proj_list.append(asi[i*img_per_grp:(i+1)*img_per_grp])
		del asi
		print "  D  ",myid,"  ",time()-st
		#from sys import exit
		#exit()


	#   Compute stability per projection
	elif options.grouping == "PPR":
		print "  A  ",myid,"  ",time()-st
		proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
		print "  B  ",myid,"  ",time()-st
		proj_params = []
		for i in xrange(nima):
			dp = proj_attr[i].get_params("spider")
			phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp["psi"], -dp["tx"], -dp["ty"]
			proj_params.append([phi, theta, psi, s2x, s2y])
		img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
		print "  C  ",myid,"  ",time()-st
		from utilities import nearest_proj
		proj_list, mirror_list = nearest_proj(proj_params, img_per_grp, range(img_begin, img_begin+1))#range(img_begin, img_end))
		refprojdir = proj_params[img_begin: img_end]
		del proj_params, mirror_list
		print "  D  ",myid,"  ",time()-st
	else:  ERROR("Incorrect projection grouping option","sxproj_stability",1)
	"""
	from utilities import write_text_file
	for i in xrange(len(proj_list)):
		write_text_file(proj_list[i],"projlist%06d_%04d"%(i,myid))
	"""

	###########################################################################################################
	# Begin stability test
	from utilities import get_params_proj, read_text_file
	#if myid == 0:
	#	from utilities import read_text_file
	#	proj_list[0] = map(int, read_text_file("lggrpp0.txt"))


	from utilities import model_blank
	aveList = [model_blank(nx,ny)]*len(proj_list)
	if options.grouping == "GRP":  refprojdir = [[0.0,0.0,-1.0]]*len(proj_list)
	for i in xrange(len(proj_list)):
		print "  E  ",myid,"  ",time()-st
		class_data = EMData.read_images(stack, proj_list[i])
		#print "  R  ",myid,"  ",time()-st
		if options.CTF :
			from filter import filt_ctf
			for im in xrange(len(class_data)):  #  MEM LEAK!!
				atemp = class_data[im].copy()
				btemp = filt_ctf(atemp, atemp.get_attr("ctf"), binary=1)
				class_data[im] = btemp
				#class_data[im] = filt_ctf(class_data[im], class_data[im].get_attr("ctf"), binary=1)
		for im in class_data:
			try:
				t = im.get_attr("xform.align2d") # if they are there, no need to set them!
			except:
				try:
					t = im.get_attr("xform.projection")
					d = t.get_params("spider")
					set_params2D(im, [0.0,-d["tx"],-d["ty"],0,1.0])
				except:
					set_params2D(im, [0.0, 0.0, 0.0, 0, 1.0])
		#print "  F  ",myid,"  ",time()-st
		# Here, we perform realignment num_ali times
		all_ali_params = []
		for j in xrange(num_ali):
			if( xrng[0] == 0.0 and yrng[0] == 0.0 ):
				avet = ali2d_ras(class_data, randomize = True, ir = 1, ou = radius, rs = 1, step = 1.0, dst = 90.0, maxit = ite, check_mirror = True, FH=options.fl, FF=options.aa)
			else:
				avet = within_group_refinement(class_data, mask, True, 1, radius, 1, xrng, yrng, step, 90.0, ite, options.fl, options.aa)
			ali_params = []
			for im in xrange(len(class_data)):
				alpha, sx, sy, mirror, scale = get_params2D(class_data[im])
				ali_params.extend( [alpha, sx, sy, mirror] )
			all_ali_params.append(ali_params)
		#aveList[i] = avet
		#print "  G  ",myid,"  ",time()-st
		del ali_params
		# We determine the stability of this group here.
		# stable_set contains all particles deemed stable, it is a list of list
		# each list has two elements, the first is the pixel error, the second is the image number
		# stable_set is sorted based on pixel error
		#from utilities import write_text_file
		#write_text_file(all_ali_params, "all_ali_params%03d.txt"%myid)
		stable_set, mir_stab_rate, average_pix_err = multi_align_stability(all_ali_params, 0.0, 10000.0, thld_err, False, 2*radius+1)
		#print "  H  ",myid,"  ",time()-st
		if(len(stable_set) > 5):
			stable_set_id = []
			members = []
			pix_err = []
			# First put the stable members into attr 'members' and 'pix_err'
			for s in stable_set:
				# s[1] - number in this subset
				stable_set_id.append(s[1])
				# the original image number
				members.append(proj_list[i][s[1]])
				pix_err.append(s[0])
			# Then put the unstable members into attr 'members' and 'pix_err'
			from fundamentals import rot_shift2D
			avet.to_zero()
			if options.grouping == "GRP":
				aphi = 0.0
				atht = 0.0
				vphi = 0.0
				vtht = 0.0
			l = -1
			for j in xrange(len(proj_list[i])):
				#  Here it will only work if stable_set_id is sorted in the increasing number, see how l progresses
				if j in stable_set_id:
					l += 1
					avet += rot_shift2D(class_data[j], stable_set[l][2][0], stable_set[l][2][1], stable_set[l][2][2], stable_set[l][2][3] )
					if options.grouping == "GRP":
						phi, theta, psi, sxs, sys = get_params_proj(class_data[j])
						if( theta > 90.0):
							phi = (phi+540.0)%360.0
							theta = 180.0 - theta
						aphi += phi
						atht += theta
						vphi += phi*phi
						vtht += theta*theta
				else:
					members.append(proj_list[i][j])
					pix_err.append(99999.99)
			aveList[i] = avet.copy()
			if l>1 :
				l += 1
				aveList[i] /= l
				if options.grouping == "GRP":
					aphi /= l
					atht /= l
					vphi = (vphi - l*aphi*aphi)/l
					vtht = (vtht - l*atht*atht)/l
					from math import sqrt
					refprojdir[i] = [aphi, atht, (sqrt(max(vphi,0.0))+sqrt(max(vtht,0.0)))/2.0]

			# Here more information has to be stored, PARTICULARLY WHAT IS THE REFERENCE DIRECTION
			aveList[i].set_attr('members', members)
			aveList[i].set_attr('refprojdir',refprojdir[i])
			aveList[i].set_attr('pixerr', pix_err)
		else:
			print  " empty group ",i, refprojdir[i]
			aveList[i].set_attr('members',[-1])
			aveList[i].set_attr('refprojdir',refprojdir[i])
			aveList[i].set_attr('pixerr', [99999.])

	del class_data

	if myid == main_node:
		km = 0
		for i in xrange(number_of_proc):
			if i == main_node :
				for im in xrange(len(aveList)):
					aveList[im].write_image(args[1], km)
					km += 1
			else:
				nl = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
				nl = int(nl[0])
				for im in xrange(nl):
					ave = recv_EMData(i, im+i+70000)
					nm = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
					nm = int(nm[0])
					members = mpi_recv(nm, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
					ave.set_attr('members', map(int, members))
					members = mpi_recv(nm, MPI_FLOAT, i, MPI_TAG_UB, MPI_COMM_WORLD)
					ave.set_attr('pixerr', map(float, members))
					members = mpi_recv(3, MPI_FLOAT, i, MPI_TAG_UB, MPI_COMM_WORLD)
					ave.set_attr('refprojdir', map(float, members))
					ave.write_image(args[1], km)
					km += 1
	else:
		mpi_send(len(aveList), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
		for im in xrange(len(aveList)):
			send_EMData(aveList[im], main_node,im+myid+70000)
			members = aveList[im].get_attr('members')
			mpi_send(len(members), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
			mpi_send(members, len(members), MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
			members = aveList[im].get_attr('pixerr')
			mpi_send(members, len(members), MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
			try:
				members = aveList[im].get_attr('refprojdir')
				mpi_send(members, 3, MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
			except:
				mpi_send([-999.0,-999.0,-999.0], 3, MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)

	global_def.BATCH = False
	mpi_barrier(MPI_COMM_WORLD)
	from mpi import mpi_finalize
	mpi_finalize()
Beispiel #15
0
def _print_unblur_progress(n_done, max_nima, total_time):
    """
	Print a progress line for the rank that reports progress.

	Arguments:
	n_done - number of micrographs already processed by this rank
	max_nima - largest number of micrographs assigned to any rank (must be > 0)
	total_time - seconds elapsed since processing started

	Returns:
	None
	"""
    # With nothing processed yet there is no per-micrograph average to report.
    if n_done == 0:
        average_time = 0
    else:
        average_time = total_time / float(n_done)
    sp_global_def.sxprint(
        '{0: 6.2f}% => Elapsed time: {1: 6.2f}min | Estimated total time: {2: 6.2f}min | Time per micrograph: {3: 5.2f}min/mic'
        .format(
            100 * n_done / float(max_nima),
            total_time / float(60),
            (max_nima) * average_time / float(60),
            average_time / float(60),
        ))


def main(args):
    """
	Run the external unblur program on every selected micrograph, split across MPI ranks.

	The main rank resolves the micrograph file names and broadcasts them. Each rank
	processes a contiguous slice of them, piping the generated unblur command into
	args['unblur_path']. Results go to 'corrsum' and/or 'corrsum_dw' (plus matching
	'*_log' directories) under args['output_directory'], depending on
	'additional_dose_unadjusted' / 'skip_dose_adjustment'.

	Arguments:
	args - Arguments as dictionary

	Returns:
	None
	"""

    main_mpi_proc = 0
    my_mpi_proc_id = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    n_mpi_procs = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)

    # Import the file names
    sanity_checks(args, my_mpi_proc_id)
    if my_mpi_proc_id == main_mpi_proc:
        if args['Expert options']:
            sp_global_def.sxprint(
                'Expert option detected! The program will enable expert mode!')
        if args['Magnification correction']:
            sp_global_def.sxprint(
                'Magnification correction option detected! The program will enable magnification correction mode!'
            )
        file_names = load_file_names_by_pattern(
            args['input_micrograph_pattern'], args['selection_file'])
    else:
        file_names = []

    file_names = sp_utilities.wrap_mpi_bcast(file_names, main_mpi_proc)

    # Split the list indices by node; only the first max_proc ranks get work.
    max_proc = min(n_mpi_procs, len(file_names))
    if my_mpi_proc_id < max_proc:
        idx_start, idx_end = sp_applications.MPI_start_end(
            len(file_names), max_proc, my_mpi_proc_id)
    else:
        idx_start = 0
        idx_end = 0

    nima = idx_end - idx_start
    max_nima_list = sp_utilities.wrap_mpi_gatherv([nima], main_mpi_proc,
                                                  mpi.MPI_COMM_WORLD)
    max_nima_list = sp_utilities.wrap_mpi_bcast(max_nima_list, main_mpi_proc,
                                                mpi.MPI_COMM_WORLD)
    max_nima = max(max_nima_list)
    # The (first) rank with the largest share reports progress for everyone.
    mpi_print_id = max_nima_list.index(max_nima)

    try:
        os.makedirs(args['output_directory'])
    except OSError:
        pass
    sp_global_def.write_command(args['output_directory'])

    # Output locations and dose-adjustment variants do not depend on the micrograph,
    # so they are computed once instead of once per file.
    output_dir_name = os.path.join(args['output_directory'], 'corrsum')
    output_dir_name_log = os.path.join(args['output_directory'],
                                       'corrsum_log')
    output_dir_name_dw = os.path.join(args['output_directory'],
                                      'corrsum_dw')
    output_dir_name_dw_log = os.path.join(args['output_directory'],
                                          'corrsum_dw_log')
    if args['additional_dose_unadjusted']:
        unblur_list = (
            (True, output_dir_name_dw, output_dir_name_dw_log),
            (False, output_dir_name, output_dir_name_log),
        )
    elif args['skip_dose_adjustment']:
        unblur_list = ((False, output_dir_name, output_dir_name_log), )
    else:
        unblur_list = ((True, output_dir_name_dw,
                        output_dir_name_dw_log), )

    start_unblur = time.time()
    for idx, file_path in enumerate(file_names[idx_start:idx_end]):
        if my_mpi_proc_id == mpi_print_id:
            _print_unblur_progress(idx, max_nima, time.time() - start_unblur)

        file_name = os.path.basename(os.path.splitext(file_path)[0])
        file_name_out = '{0}.mrc'.format(file_name)
        file_name_log = '{0}.log'.format(file_name)
        file_name_err = '{0}.err'.format(file_name)

        for dose_adjustment, dir_name, log_dir_name in unblur_list:
            # Directories may already exist (other ranks or earlier files); that is fine.
            try:
                os.makedirs(dir_name)
            except OSError:
                pass
            try:
                os.makedirs(log_dir_name)
            except OSError:
                pass
            output_name = os.path.join(dir_name, file_name_out)
            output_name_log = os.path.join(log_dir_name, file_name_log)
            output_name_err = os.path.join(log_dir_name, file_name_err)
            unblur_command = create_unblur_command(
                file_path,
                output_name,
                args['pixel_size'],
                args['bin_factor'],
                dose_adjustment,
                args['voltage'],
                args['exposure_per_frame'],
                args['pre_exposure'],
                args['Expert options'],
                args['min_shift_initial'],
                args['outer_radius'],
                args['b_factor'],
                args['half_width_vert'],
                args['half_width_hor'],
                args['termination'],
                args['max_iterations'],
                not args['dont_restore_noise_power'],
                args['gain_file'],
                args['first_frame'],
                args['last_frame'],
                args['Magnification correction'],
                args['distortion_angle'],
                args['major_scale'],
                args['minor_scale'],
            )

            # unblur reads its parameters from stdin, so the command text is piped in.
            execute_command = r'echo "{0}" | {1}'.format(
                unblur_command, args['unblur_path'])
            with open(output_name_log, 'w') as log, open(output_name_err,
                                                         'w') as err:
                start = time.time()
                child = subprocess.Popen(execute_command,
                                         shell=True,
                                         stdout=log,
                                         stderr=err)
                child.wait()
                if child.returncode != 0:
                    sp_global_def.sxprint(
                        'Process failed for image {0}.\nPlease make sure that the unblur path is correct\nand check the respective logfile.'
                        .format(file_path))
                log.write('Time => {0:.2f} for command: {1}'.format(
                    time.time() - start, execute_command))

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    # The reporting rank processed exactly max_nima micrographs. When there were no
    # micrographs at all the loop never ran, so there is nothing to report (and no
    # division by zero).
    if my_mpi_proc_id == mpi_print_id and max_nima > 0:
        _print_unblur_progress(max_nima, max_nima, time.time() - start_unblur)
Beispiel #16
0
def ali3d_MPI(stack,
              ref_vol,
              outdir,
              maskfile=None,
              ir=1,
              ou=-1,
              rs=1,
              xr="4 2 2 1",
              yr="-1",
              ts="1 1 0.5 0.25",
              delta="10 6 4 4",
              an="-1",
              center=0,
              maxit=5,
              term=95,
              CTF=False,
              fourvar=False,
              snr=1.0,
              ref_a="S",
              sym="c1",
              sort=True,
              cutoff=999.99,
              pix_cutoff="0",
              two_tail=False,
              model_jump="1 1 1 1 1",
              restart=False,
              save_half=False,
              protos=None,
              oplane=None,
              lmask=-1,
              ilmask=-1,
              findseam=False,
              vertstep=None,
              hpars="-1",
              hsearch="0.0 50.0",
              full_output=False,
              compare_repro=False,
              compare_ref_free="-1",
              ref_free_cutoff="-1 -1 -1 -1",
              wcmask=None,
              debug=False,
              recon_pad=4,
              olmask=75):

    from alignment import Numrinit, prepare_refrings
    from utilities import model_circle, get_image, drop_image, get_input_from_string
    from utilities import bcast_list_to_all, bcast_number_to_all, reduce_EMData_to_root, bcast_EMData_to_all
    from utilities import send_attr_dict
    from utilities import get_params_proj, file_type
    from fundamentals import rot_avg_image
    import os
    import types
    from utilities import print_begin_msg, print_end_msg, print_msg
    from mpi import mpi_bcast, mpi_comm_size, mpi_comm_rank, MPI_FLOAT, MPI_COMM_WORLD, mpi_barrier, mpi_reduce
    from mpi import mpi_reduce, MPI_INT, MPI_SUM, mpi_finalize
    from filter import filt_ctf
    from projection import prep_vol, prgs
    from statistics import hist_list, varf3d_MPI, fsc_mask
    from numpy import array, bincount, array2string, ones

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0
    if myid == main_node:
        if os.path.exists(outdir):
            ERROR(
                'Output directory exists, please change the name and restart the program',
                "ali3d_MPI", 1)
        os.mkdir(outdir)
    mpi_barrier(MPI_COMM_WORLD)

    if debug:
        from time import sleep
        while not os.path.exists(outdir):
            print "Node ", myid, "  waiting..."
            sleep(5)

        info_file = os.path.join(outdir, "progress%04d" % myid)
        finfo = open(info_file, 'w')
    else:
        finfo = None
    mjump = get_input_from_string(model_jump)
    xrng = get_input_from_string(xr)
    if yr == "-1": yrng = xrng
    else: yrng = get_input_from_string(yr)
    step = get_input_from_string(ts)
    delta = get_input_from_string(delta)
    ref_free_cutoff = get_input_from_string(ref_free_cutoff)
    pix_cutoff = get_input_from_string(pix_cutoff)

    lstp = min(len(xrng), len(yrng), len(step), len(delta))
    if an == "-1":
        an = [-1] * lstp
    else:
        an = get_input_from_string(an)
    # make sure pix_cutoff is set for all iterations
    if len(pix_cutoff) < lstp:
        for i in xrange(len(pix_cutoff), lstp):
            pix_cutoff.append(pix_cutoff[-1])
    # don't waste time on sub-pixel alignment for low-resolution ang incr
    for i in range(len(step)):
        if (delta[i] > 4 or delta[i] == -1) and step[i] < 1:
            step[i] = 1

    first_ring = int(ir)
    rstep = int(rs)
    last_ring = int(ou)
    max_iter = int(maxit)
    center = int(center)

    nrefs = EMUtil.get_image_count(ref_vol)
    nmasks = 0
    if maskfile:
        # read number of masks within each maskfile (mc)
        nmasks = EMUtil.get_image_count(maskfile)
        # open masks within maskfile (mc)
        maskF = EMData.read_images(maskfile, xrange(nmasks))
    vol = EMData.read_images(ref_vol, xrange(nrefs))
    nx = vol[0].get_xsize()

    ## make sure box sizes are the same
    if myid == main_node:
        im = EMData.read_images(stack, [0])
        bx = im[0].get_xsize()
        if bx != nx:
            print_msg(
                "Error: Stack box size (%i) differs from initial model (%i)\n"
                % (bx, nx))
            sys.exit()
        del im, bx

    # for helical processing:
    helicalrecon = False
    if protos is not None or hpars != "-1" or findseam is True:
        helicalrecon = True
        # if no out-of-plane param set, use 5 degrees
        if oplane is None:
            oplane = 5.0
    if protos is not None:
        proto = get_input_from_string(protos)
        if len(proto) != nrefs:
            print_msg("Error: insufficient protofilament numbers supplied")
            sys.exit()
    if hpars != "-1":
        hpars = get_input_from_string(hpars)
        if len(hpars) != 2 * nrefs:
            print_msg("Error: insufficient helical parameters supplied")
            sys.exit()
    ## create helical parameter file for helical reconstruction
    if helicalrecon is True and myid == main_node:
        from hfunctions import createHpar
        # create initial helical parameter files
        dp = [0] * nrefs
        dphi = [0] * nrefs
        vdp = [0] * nrefs
        vdphi = [0] * nrefs
        for iref in xrange(nrefs):
            hpar = os.path.join(outdir, "hpar%02d.spi" % (iref))
            params = False
            if hpars != "-1":
                # if helical parameters explicitly given, set twist & rise
                params = [float(hpars[iref * 2]), float(hpars[(iref * 2) + 1])]
            dp[iref], dphi[iref], vdp[iref], vdphi[iref] = createHpar(
                hpar, proto[iref], params, vertstep)

    # get values for helical search parameters
    hsearch = get_input_from_string(hsearch)
    if len(hsearch) != 2:
        print_msg("Error: specify outer and inner radii for helical search")
        sys.exit()

    if last_ring < 0 or last_ring > int(nx / 2) - 2:
        last_ring = int(nx / 2) - 2

    if myid == main_node:
        #	import user_functions
        #	user_func = user_functions.factory[user_func_name]

        print_begin_msg("ali3d_MPI")
        print_msg("Input stack		 : %s\n" % (stack))
        print_msg("Reference volume	    : %s\n" % (ref_vol))
        print_msg("Output directory	    : %s\n" % (outdir))
        if nmasks > 0:
            print_msg("Maskfile (number of masks)  : %s (%i)\n" %
                      (maskfile, nmasks))
        print_msg("Inner radius		: %i\n" % (first_ring))
        print_msg("Outer radius		: %i\n" % (last_ring))
        print_msg("Ring step		   : %i\n" % (rstep))
        print_msg("X search range	      : %s\n" % (xrng))
        print_msg("Y search range	      : %s\n" % (yrng))
        print_msg("Translational step	  : %s\n" % (step))
        print_msg("Angular step		: %s\n" % (delta))
        print_msg("Angular search range	: %s\n" % (an))
        print_msg("Maximum iteration	   : %i\n" % (max_iter))
        print_msg("Center type		 : %i\n" % (center))
        print_msg("CTF correction	      : %s\n" % (CTF))
        print_msg("Signal-to-Noise Ratio       : %f\n" % (snr))
        print_msg("Reference projection method : %s\n" % (ref_a))
        print_msg("Symmetry group	      : %s\n" % (sym))
        print_msg("Fourier padding for 3D      : %i\n" % (recon_pad))
        print_msg("Number of reference models  : %i\n" % (nrefs))
        print_msg("Sort images between models  : %s\n" % (sort))
        print_msg("Allow images to jump	: %s\n" % (mjump))
        print_msg("CC cutoff standard dev      : %f\n" % (cutoff))
        print_msg("Two tail cutoff	     : %s\n" % (two_tail))
        print_msg("Termination pix error       : %f\n" % (term))
        print_msg("Pixel error cutoff	  : %s\n" % (pix_cutoff))
        print_msg("Restart		     : %s\n" % (restart))
        print_msg("Full output		 : %s\n" % (full_output))
        print_msg("Compare reprojections       : %s\n" % (compare_repro))
        print_msg("Compare ref free class avgs : %s\n" % (compare_ref_free))
        print_msg("Use cutoff from ref free    : %s\n" % (ref_free_cutoff))
        if protos:
            print_msg("Protofilament numbers	: %s\n" % (proto))
            print_msg("Using helical search range   : %s\n" % hsearch)
        if findseam is True:
            print_msg("Using seam-based reconstruction\n")
        if hpars != "-1":
            print_msg("Using hpars		  : %s\n" % hpars)
        if vertstep != None:
            print_msg("Using vertical step    : %.2f\n" % vertstep)
        if save_half is True:
            print_msg("Saving even/odd halves\n")
        for i in xrange(100):
            print_msg("*")
        print_msg("\n\n")
    if maskfile:
        if type(maskfile) is types.StringType: mask3D = get_image(maskfile)
        else: mask3D = maskfile
    else: mask3D = model_circle(last_ring, nx, nx, nx)

    numr = Numrinit(first_ring, last_ring, rstep, "F")
    mask2D = model_circle(last_ring, nx, nx) - model_circle(first_ring, nx, nx)

    fscmask = model_circle(last_ring, nx, nx, nx)
    if CTF:
        from filter import filt_ctf
    from reconstruction_rjh import rec3D_MPI_noCTF

    if myid == main_node:
        active = EMUtil.get_all_attributes(stack, 'active')
        list_of_particles = []
        for im in xrange(len(active)):
            if active[im]: list_of_particles.append(im)
        del active
        nima = len(list_of_particles)
    else:
        nima = 0
    total_nima = bcast_number_to_all(nima, source_node=main_node)

    if myid != main_node:
        list_of_particles = [-1] * total_nima
    list_of_particles = bcast_list_to_all(list_of_particles,
                                          source_node=main_node)

    image_start, image_end = MPI_start_end(total_nima, number_of_proc, myid)

    # create a list of images for each node
    list_of_particles = list_of_particles[image_start:image_end]
    nima = len(list_of_particles)
    if debug:
        finfo.write("image_start, image_end: %d %d\n" %
                    (image_start, image_end))
        finfo.flush()

    data = EMData.read_images(stack, list_of_particles)

    t_zero = Transform({
        "type": "spider",
        "phi": 0,
        "theta": 0,
        "psi": 0,
        "tx": 0,
        "ty": 0
    })
    transmulti = [[t_zero for i in xrange(nrefs)] for j in xrange(nima)]

    for iref, im in ((iref, im) for iref in xrange(nrefs)
                     for im in xrange(nima)):
        if nrefs == 1:
            transmulti[im][iref] = data[im].get_attr("xform.projection")
        else:
            # if multi models, keep track of eulers for all models
            try:
                transmulti[im][iref] = data[im].get_attr("eulers_txty.%i" %
                                                         iref)
            except:
                data[im].set_attr("eulers_txty.%i" % iref, t_zero)

    scoremulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)]
    pixelmulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)]
    ref_res = [0.0 for x in xrange(nrefs)]
    apix = data[0].get_attr('apix_x')

    # for oplane parameter, create cylindrical mask
    if oplane is not None and myid == main_node:
        from hfunctions import createCylMask
        cmaskf = os.path.join(outdir, "mask3D_cyl.mrc")
        mask3D = createCylMask(data, olmask, lmask, ilmask, cmaskf)
        # if finding seam of helix, create wedge masks
        if findseam is True:
            wedgemask = []
            for pf in xrange(nrefs):
                wedgemask.append(EMData())
            # wedgemask option
            if wcmask is not None:
                wcmask = get_input_from_string(wcmask)
                if len(wcmask) != 3:
                    print_msg(
                        "Error: wcmask option requires 3 values: x y radius")
                    sys.exit()

    # determine if particles have helix info:
    try:
        data[0].get_attr('h_angle')
        original_data = []
        boxmask = True
        from hfunctions import createBoxMask
    except:
        boxmask = False

    # prepare particles
    for im in xrange(nima):
        data[im].set_attr('ID', list_of_particles[im])
        data[im].set_attr('pix_score', int(0))
        if CTF:
            # only phaseflip particles, not full CTF correction
            ctf_params = data[im].get_attr("ctf")
            st = Util.infomask(data[im], mask2D, False)
            data[im] -= st[0]
            data[im] = filt_ctf(data[im], ctf_params, sign=-1, binary=1)
            data[im].set_attr('ctf_applied', 1)
        # for window mask:
        if boxmask is True:
            h_angle = data[im].get_attr("h_angle")
            original_data.append(data[im].copy())
            bmask = createBoxMask(nx, apix, ou, lmask, h_angle)
            data[im] *= bmask
            del bmask
    if debug:
        finfo.write('%d loaded  \n' % nima)
        finfo.flush()
    if myid == main_node:
        # initialize data for the reference preparation function
        ref_data = [mask3D, max(center, 0), None, None, None, None]
        # for method -1, switch off centering in user function

    from time import time

    #  this is needed for gathering of pixel errors
    disps = []
    recvcount = []
    disps_score = []
    recvcount_score = []
    for im in xrange(number_of_proc):
        if (im == main_node):
            disps.append(0)
            disps_score.append(0)
        else:
            disps.append(disps[im - 1] + recvcount[im - 1])
            disps_score.append(disps_score[im - 1] + recvcount_score[im - 1])
        ib, ie = MPI_start_end(total_nima, number_of_proc, im)
        recvcount.append(ie - ib)
        recvcount_score.append((ie - ib) * nrefs)

    pixer = [0.0] * nima
    cs = [0.0] * 3
    total_iter = 0
    volodd = EMData.read_images(ref_vol, xrange(nrefs))
    voleve = EMData.read_images(ref_vol, xrange(nrefs))

    if restart:
        # recreate initial volumes from alignments stored in header
        itout = "000_00"
        for iref in xrange(nrefs):
            if (nrefs == 1):
                modout = ""
            else:
                modout = "_model_%02d" % (iref)

            if (sort):
                group = iref
                for im in xrange(nima):
                    imgroup = data[im].get_attr('group')
                    if imgroup == iref:
                        data[im].set_attr('xform.projection',
                                          transmulti[im][iref])
            else:
                group = int(999)
                for im in xrange(nima):
                    data[im].set_attr('xform.projection', transmulti[im][iref])

            fscfile = os.path.join(outdir, "fsc_%s%s" % (itout, modout))

            vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(
                data,
                sym,
                fscmask,
                fscfile,
                myid,
                main_node,
                index=group,
                npad=recon_pad)

            if myid == main_node:
                if helicalrecon:
                    from hfunctions import processHelicalVol
                    vstep = None
                    if vertstep is not None:
                        vstep = (vdp[iref], vdphi[iref])
                    print_msg(
                        "Old rise and twist for model %i     : %8.3f, %8.3f\n"
                        % (iref, dp[iref], dphi[iref]))
                    hvals = processHelicalVol(vol[iref], voleve[iref],
                                              volodd[iref], iref, outdir,
                                              itout, dp[iref], dphi[iref],
                                              apix, hsearch, findseam, vstep,
                                              wcmask)
                    (vol[iref], voleve[iref], volodd[iref], dp[iref],
                     dphi[iref], vdp[iref], vdphi[iref]) = hvals
                    print_msg(
                        "New rise and twist for model %i     : %8.3f, %8.3f\n"
                        % (iref, dp[iref], dphi[iref]))
                    # get new FSC from symmetrized half volumes
                    fscc = fsc_mask(volodd[iref], voleve[iref], mask3D, rstep,
                                    fscfile)
                else:
                    vol[iref].write_image(
                        os.path.join(outdir, "vol_%s.hdf" % itout), -1)

                if save_half is True:
                    volodd[iref].write_image(
                        os.path.join(outdir, "volodd_%s.hdf" % itout), -1)
                    voleve[iref].write_image(
                        os.path.join(outdir, "voleve_%s.hdf" % itout), -1)

                if nmasks > 1:
                    # Read mask for multiplying
                    ref_data[0] = maskF[iref]
                ref_data[2] = vol[iref]
                ref_data[3] = fscc
                #  call user-supplied function to prepare reference image, i.e., center and filter it
                vol[iref], cs, fl = ref_ali3d(ref_data)
                vol[iref].write_image(
                    os.path.join(outdir, "volf_%s.hdf" % (itout)), -1)
                if (apix == 1):
                    res_msg = "Models filtered at spatial frequency of:\t"
                    res = fl
                else:
                    res_msg = "Models filtered at resolution of:       \t"
                    res = apix / fl
                ares = array2string(array(res), precision=2)
                print_msg("%s%s\n\n" % (res_msg, ares))

            bcast_EMData_to_all(vol[iref], myid, main_node)
            # write out headers, under MPI writing has to be done sequentially
            mpi_barrier(MPI_COMM_WORLD)

    # projection matching
    for N_step in xrange(lstp):
        terminate = 0
        Iter = -1
        while (Iter < max_iter - 1 and terminate == 0):
            Iter += 1
            total_iter += 1
            itout = "%03g_%02d" % (delta[N_step], Iter)
            if myid == main_node:
                print_msg(
                    "ITERATION #%3d, inner iteration #%3d\nDelta = %4.1f, an = %5.2f, xrange = %5.2f, yrange = %5.2f, step = %5.2f\n\n"
                    % (N_step, Iter, delta[N_step], an[N_step], xrng[N_step],
                       yrng[N_step], step[N_step]))

            for iref in xrange(nrefs):
                if myid == main_node: start_time = time()
                volft, kb = prep_vol(vol[iref])

                ## constrain projections to out of plane parameter
                theta1 = None
                theta2 = None
                if oplane is not None:
                    theta1 = 90 - oplane
                    theta2 = 90 + oplane
                refrings = prepare_refrings(volft,
                                            kb,
                                            nx,
                                            delta[N_step],
                                            ref_a,
                                            sym,
                                            numr,
                                            MPI=True,
                                            phiEqpsi="Minus",
                                            initial_theta=theta1,
                                            delta_theta=theta2)

                del volft, kb

                if myid == main_node:
                    print_msg(
                        "Time to prepare projections for model %i: %s\n" %
                        (iref, legibleTime(time() - start_time)))
                    start_time = time()

                for im in xrange(nima):
                    data[im].set_attr("xform.projection", transmulti[im][iref])
                    if an[N_step] == -1:
                        t1, peak, pixer[im] = proj_ali_incore(
                            data[im], refrings, numr, xrng[N_step],
                            yrng[N_step], step[N_step], finfo)
                    else:
                        t1, peak, pixer[im] = proj_ali_incore_local(
                            data[im], refrings, numr, xrng[N_step],
                            yrng[N_step], step[N_step], an[N_step], finfo)
                    #data[im].set_attr("xform.projection"%iref, t1)
                    if nrefs > 1:
                        data[im].set_attr("eulers_txty.%i" % iref, t1)
                    scoremulti[im][iref] = peak
                    from pixel_error import max_3D_pixel_error
                    # t1 is the current param, t2 is old
                    t2 = transmulti[im][iref]
                    pixelmulti[im][iref] = max_3D_pixel_error(t1, t2, numr[-3])
                    transmulti[im][iref] = t1

                if myid == main_node:
                    print_msg("Time of alignment for model %i: %s\n" %
                              (iref, legibleTime(time() - start_time)))
                    start_time = time()

            # gather scoring data from all processors
            from mpi import mpi_gatherv
            scoremultisend = sum(scoremulti, [])
            pixelmultisend = sum(pixelmulti, [])
            tmp = mpi_gatherv(scoremultisend, len(scoremultisend), MPI_FLOAT,
                              recvcount_score, disps_score, MPI_FLOAT,
                              main_node, MPI_COMM_WORLD)
            tmp1 = mpi_gatherv(pixelmultisend, len(pixelmultisend), MPI_FLOAT,
                               recvcount_score, disps_score, MPI_FLOAT,
                               main_node, MPI_COMM_WORLD)
            tmp = mpi_bcast(tmp, (total_nima * nrefs), MPI_FLOAT, 0,
                            MPI_COMM_WORLD)
            tmp1 = mpi_bcast(tmp1, (total_nima * nrefs), MPI_FLOAT, 0,
                             MPI_COMM_WORLD)
            tmp = map(float, tmp)
            tmp1 = map(float, tmp1)
            score = array(tmp).reshape(-1, nrefs)
            pixelerror = array(tmp1).reshape(-1, nrefs)
            score_local = array(scoremulti)
            mean_score = score.mean(axis=0)
            std_score = score.std(axis=0)
            cut = mean_score - (cutoff * std_score)
            cut2 = mean_score + (cutoff * std_score)
            res_max = score_local.argmax(axis=1)
            minus_cc = [0.0 for x in xrange(nrefs)]
            minus_pix = [0.0 for x in xrange(nrefs)]
            minus_ref = [0.0 for x in xrange(nrefs)]

            #output pixel errors
            if (myid == main_node):
                from statistics import hist_list
                lhist = 20
                pixmin = pixelerror.min(axis=1)
                region, histo = hist_list(pixmin, lhist)
                if (region[0] < 0.0): region[0] = 0.0
                print_msg(
                    "Histogram of pixel errors\n      ERROR       number of particles\n"
                )
                for lhx in xrange(lhist):
                    print_msg(" %10.3f     %7d\n" % (region[lhx], histo[lhx]))
                # Terminate if 95% within 1 pixel error
                im = 0
                for lhx in xrange(lhist):
                    if (region[lhx] > 1.0): break
                    im += histo[lhx]
                print_msg("Percent of particles with pixel error < 1: %f\n\n" %
                          (im / float(total_nima) * 100))
                term_cond = float(term) / 100
                if (im / float(total_nima) > term_cond):
                    terminate = 1
                    print_msg("Terminating internal loop\n")
                del region, histo
            terminate = mpi_bcast(terminate, 1, MPI_INT, 0, MPI_COMM_WORLD)
            terminate = int(terminate[0])

            for im in xrange(nima):
                if (sort == False):
                    data[im].set_attr('group', 999)
                elif (mjump[N_step] == 1):
                    data[im].set_attr('group', int(res_max[im]))

                pix_run = data[im].get_attr('pix_score')
                if (pix_cutoff[N_step] == 1
                        and (terminate == 1 or Iter == max_iter - 1)):
                    if (pixelmulti[im][int(res_max[im])] > 1):
                        data[im].set_attr('pix_score', int(777))

                if (score_local[im][int(res_max[im])] < cut[int(
                        res_max[im])]) or (two_tail and score_local[im][int(
                            res_max[im])] > cut2[int(res_max[im])]):
                    data[im].set_attr('group', int(888))
                    minus_cc[int(res_max[im])] = minus_cc[int(res_max[im])] + 1

                if (pix_run == 777):
                    data[im].set_attr('group', int(777))
                    minus_pix[int(
                        res_max[im])] = minus_pix[int(res_max[im])] + 1

                if (compare_ref_free != "-1") and (ref_free_cutoff[N_step] !=
                                                   -1) and (total_iter > 1):
                    id = data[im].get_attr('ID')
                    if id in rejects:
                        data[im].set_attr('group', int(666))
                        minus_ref[int(
                            res_max[im])] = minus_ref[int(res_max[im])] + 1

            minus_cc_tot = mpi_reduce(minus_cc, nrefs, MPI_FLOAT, MPI_SUM, 0,
                                      MPI_COMM_WORLD)
            minus_pix_tot = mpi_reduce(minus_pix, nrefs, MPI_FLOAT, MPI_SUM, 0,
                                       MPI_COMM_WORLD)
            minus_ref_tot = mpi_reduce(minus_ref, nrefs, MPI_FLOAT, MPI_SUM, 0,
                                       MPI_COMM_WORLD)
            if (myid == main_node):
                if (sort):
                    tot_max = score.argmax(axis=1)
                    res = bincount(tot_max)
                else:
                    res = ones(nrefs) * total_nima
                print_msg("Particle distribution:	     \t\t%s\n" % (res * 1.0))
                afcut1 = res - minus_cc_tot
                afcut2 = afcut1 - minus_pix_tot
                afcut3 = afcut2 - minus_ref_tot
                print_msg("Particle distribution after cc cutoff:\t\t%s\n" %
                          (afcut1))
                print_msg("Particle distribution after pix cutoff:\t\t%s\n" %
                          (afcut2))
                print_msg("Particle distribution after ref cutoff:\t\t%s\n\n" %
                          (afcut3))

            res = [0.0 for i in xrange(nrefs)]
            for iref in xrange(nrefs):
                if (center == -1):
                    from utilities import estimate_3D_center_MPI, rotate_3D_shift
                    dummy = EMData()
                    cs[0], cs[1], cs[2], dummy, dummy = estimate_3D_center_MPI(
                        data, total_nima, myid, number_of_proc, main_node)
                    cs = mpi_bcast(cs, 3, MPI_FLOAT, main_node, MPI_COMM_WORLD)
                    cs = [-float(cs[0]), -float(cs[1]), -float(cs[2])]
                    rotate_3D_shift(data, cs)

                if (sort):
                    group = iref
                    for im in xrange(nima):
                        imgroup = data[im].get_attr('group')
                        if imgroup == iref:
                            data[im].set_attr('xform.projection',
                                              transmulti[im][iref])
                else:
                    group = int(999)
                    for im in xrange(nima):
                        data[im].set_attr('xform.projection',
                                          transmulti[im][iref])
                if (nrefs == 1):
                    modout = ""
                else:
                    modout = "_model_%02d" % (iref)

                fscfile = os.path.join(outdir, "fsc_%s%s" % (itout, modout))
                vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(
                    data,
                    sym,
                    fscmask,
                    fscfile,
                    myid,
                    main_node,
                    index=group,
                    npad=recon_pad)

                if myid == main_node:
                    print_msg("3D reconstruction time for model %i: %s\n" %
                              (iref, legibleTime(time() - start_time)))
                    start_time = time()

                # Compute Fourier variance
                if fourvar:
                    outvar = os.path.join(outdir, "volVar_%s.hdf" % (itout))
                    ssnr_file = os.path.join(outdir, "ssnr_%s" % (itout))
                    varf = varf3d_MPI(data,
                                      ssnr_text_file=ssnr_file,
                                      mask2D=None,
                                      reference_structure=vol[iref],
                                      ou=last_ring,
                                      rw=1.0,
                                      npad=1,
                                      CTF=None,
                                      sign=1,
                                      sym=sym,
                                      myid=myid)
                    if myid == main_node:
                        print_msg(
                            "Time to calculate 3D Fourier variance for model %i: %s\n"
                            % (iref, legibleTime(time() - start_time)))
                        start_time = time()
                        varf = 1.0 / varf
                        varf.write_image(outvar, -1)
                else:
                    varf = None

                if myid == main_node:
                    if helicalrecon:
                        from hfunctions import processHelicalVol

                        vstep = None
                        if vertstep is not None:
                            vstep = (vdp[iref], vdphi[iref])
                        print_msg(
                            "Old rise and twist for model %i     : %8.3f, %8.3f\n"
                            % (iref, dp[iref], dphi[iref]))
                        hvals = processHelicalVol(vol[iref], voleve[iref],
                                                  volodd[iref], iref, outdir,
                                                  itout, dp[iref], dphi[iref],
                                                  apix, hsearch, findseam,
                                                  vstep, wcmask)
                        (vol[iref], voleve[iref], volodd[iref], dp[iref],
                         dphi[iref], vdp[iref], vdphi[iref]) = hvals
                        print_msg(
                            "New rise and twist for model %i     : %8.3f, %8.3f\n"
                            % (iref, dp[iref], dphi[iref]))
                        # get new FSC from symmetrized half volumes
                        fscc = fsc_mask(volodd[iref], voleve[iref], mask3D,
                                        rstep, fscfile)

                        print_msg(
                            "Time to search and apply helical symmetry for model %i: %s\n\n"
                            % (iref, legibleTime(time() - start_time)))
                        start_time = time()
                    else:
                        vol[iref].write_image(
                            os.path.join(outdir, "vol_%s.hdf" % (itout)), -1)

                    if save_half is True:
                        volodd[iref].write_image(
                            os.path.join(outdir, "volodd_%s.hdf" % (itout)),
                            -1)
                        voleve[iref].write_image(
                            os.path.join(outdir, "voleve_%s.hdf" % (itout)),
                            -1)

                    if nmasks > 1:
                        # Read mask for multiplying
                        ref_data[0] = maskF[iref]
                    ref_data[2] = vol[iref]
                    ref_data[3] = fscc
                    ref_data[4] = varf
                    #  call user-supplied function to prepare reference image, i.e., center and filter it
                    vol[iref], cs, fl = ref_ali3d(ref_data)
                    vol[iref].write_image(
                        os.path.join(outdir, "volf_%s.hdf" % (itout)), -1)
                    if (apix == 1):
                        res_msg = "Models filtered at spatial frequency of:\t"
                        res[iref] = fl
                    else:
                        res_msg = "Models filtered at resolution of:       \t"
                        res[iref] = apix / fl

                del varf
                bcast_EMData_to_all(vol[iref], myid, main_node)

                if compare_ref_free != "-1": compare_repro = True
                if compare_repro:
                    outfile_repro = comp_rep(refrings, data, itout, modout,
                                             vol[iref], group, nima, nx, myid,
                                             main_node, outdir)
                    mpi_barrier(MPI_COMM_WORLD)
                    if compare_ref_free != "-1":
                        ref_free_output = os.path.join(
                            outdir, "ref_free_%s%s" % (itout, modout))
                        rejects = compare(compare_ref_free, outfile_repro,
                                          ref_free_output, yrng[N_step],
                                          xrng[N_step], rstep, nx, apix,
                                          ref_free_cutoff[N_step],
                                          number_of_proc, myid, main_node)

            # retrieve alignment params from all processors
            par_str = ['xform.projection', 'ID', 'group']
            if nrefs > 1:
                for iref in xrange(nrefs):
                    par_str.append('eulers_txty.%i' % iref)

            if myid == main_node:
                from utilities import recv_attr_dict
                recv_attr_dict(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)

            else:
                send_attr_dict(main_node, data, par_str, image_start,
                               image_end)

            if myid == main_node:
                ares = array2string(array(res), precision=2)
                print_msg("%s%s\n\n" % (res_msg, ares))
                dummy = EMData()
                if full_output:
                    nimat = EMUtil.get_image_count(stack)
                    output_file = os.path.join(outdir, "paramout_%s" % itout)
                    foutput = open(output_file, 'w')
                    for im in xrange(nimat):
                        # save the parameters for each of the models
                        outstring = ""
                        dummy.read_image(stack, im, True)
                        param3d = dummy.get_attr('xform.projection')
                        g = dummy.get_attr("group")
                        # retrieve alignments in EMAN-format
                        pE = param3d.get_params('eman')
                        outstring += "%f\t%f\t%f\t%f\t%f\t%i\n" % (
                            pE["az"], pE["alt"], pE["phi"], pE["tx"], pE["ty"],
                            g)
                        foutput.write(outstring)
                    foutput.close()
                del dummy
            mpi_barrier(MPI_COMM_WORLD)


#	mpi_finalize()

    if myid == main_node: print_end_msg("ali3d_MPI")
Beispiel #17
0
def do_volume_mrk02(ref_data):
    """Reconstruct (if needed) and post-process a 3D reference volume under MPI.

    ref_data is a sequence unpacked as [data, Tracker, iter, mpi_comm]:
      data     - either a list of projections (scattered between cpus), which
                 is reconstructed with recons3d_4nn_ctf_MPI (CTF) or
                 recons3d_4nn_MPI, or an already reconstructed volume that is
                 only post-processed.
      Tracker  - settings, the same on all cpus.  Keys read here:
                 "local_filter" (optional), "PWadjustment", "upscale",
                 "lowpass", "falloff", "smearstep" (CTF only) and, under
                 "constants": "CTF", "snr", "sym", "npad", "mask3D",
                 "radius", "nnxo", "sausage", "pixel_size".
      iter     - unpacked but not used by the code (only referenced in a
                 commented-out debug write).
      mpi_comm - communicator; None means MPI_COMM_WORLD.

    The post-processing runs on rank 0 and covers three steps.
      * 3D mask: a sphere, an adaptive mask, or a user mask rescaled to nx.
      * Normalization and thresholding inside the mask.
      * Filtering: an optional power-spectrum adjustment, the optional
        "sausage" weighting, then either a low-pass filter or the local
        filter.

    Returns the processed volume, broadcast so it is the same on all cpus
    of mpi_comm.
    """
    from EMAN2 import Util
    from mpi import mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
    from filter import filt_table
    from reconstruction import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
    from utilities import bcast_EMData_to_all, bcast_number_to_all, model_blank
    from fundamentals import rops_table, fftip, fft
    import types

    # Retrieve the function specific input arguments from ref_data
    data = ref_data[0]
    Tracker = ref_data[1]
    iter = ref_data[2]
    mpi_comm = ref_data[3]

    # # For DEBUG
    # print "Type of data %s" % (type(data))
    # print "Type of Tracker %s" % (type(Tracker))
    # print "Type of iter %s" % (type(iter))
    # print "Type of mpi_comm %s" % (type(mpi_comm))

    if (mpi_comm == None): mpi_comm = MPI_COMM_WORLD
    myid = mpi_comm_rank(mpi_comm)
    nproc = mpi_comm_size(mpi_comm)

    # Local filtering is enabled by a truthy Tracker["local_filter"] (passed
    # to get_im below, so presumably a volume file name).  The bare except
    # turns ANY lookup failure, not only a missing key, into "disabled".
    try:
        local_filter = Tracker["local_filter"]
    except:
        local_filter = False
    #=========================================================================
    # volume reconstruction
    if (type(data) == types.ListType):
        if Tracker["constants"]["CTF"]:
            vol = recons3d_4nn_ctf_MPI(myid, data, Tracker["constants"]["snr"], \
              symmetry=Tracker["constants"]["sym"], npad=Tracker["constants"]["npad"], mpi_comm=mpi_comm, smearstep = Tracker["smearstep"])
        else:
            vol = recons3d_4nn_MPI    (myid, data,\
              symmetry=Tracker["constants"]["sym"], npad=Tracker["constants"]["npad"], mpi_comm=mpi_comm)
    else:
        vol = data

    # Everything below up to the local-filter section runs on rank 0 only.
    if myid == 0:
        from morphology import threshold
        from filter import filt_tanl, filt_btwl
        from utilities import model_circle, get_im
        import types
        nx = vol.get_xsize()
        # Build mask3D, always nx^3 in the end:
        #   None     -> sphere whose radius is constants["radius"] rescaled from
        #               the original box size nnxo to the current nx
        #   "auto"   -> adaptive_mask computed from the volume itself
        #   str      -> mask read from file; EMData -> copied
        #               (a user mask is rescaled/windowed to nx if sizes differ)
        if (Tracker["constants"]["mask3D"] == None):
            mask3D = model_circle(
                int(Tracker["constants"]["radius"] * float(nx) /
                    float(Tracker["constants"]["nnxo"]) + 0.5), nx, nx, nx)
        elif (Tracker["constants"]["mask3D"] == "auto"):
            from utilities import adaptive_mask
            mask3D = adaptive_mask(vol)
        else:
            if (type(Tracker["constants"]["mask3D"]) == types.StringType):
                mask3D = get_im(Tracker["constants"]["mask3D"])
            else:
                mask3D = (Tracker["constants"]["mask3D"]).copy()
            nxm = mask3D.get_xsize()
            if (nx != nxm):
                from fundamentals import rot_shift3D
                mask3D = Util.window(
                    rot_shift3D(mask3D, scale=float(nx) / float(nxm)), nx, nx,
                    nx)
                nxm = mask3D.get_xsize()
                assert (nx == nxm)

        # Normalize inside the mask: subtract stat[0] and divide by stat[1]
        # (presumably the in-mask mean and standard deviation returned by
        # Util.infomask -- confirm), threshold, then apply the mask.
        stat = Util.infomask(vol, mask3D, False)
        vol -= stat[0]
        Util.mul_scalar(vol, 1.0 / stat[1])
        vol = threshold(vol)
        Util.mul_img(vol, mask3D)
        if (Tracker["PWadjustment"]):
            # Power-spectrum adjustment: filter coefficients are
            # (reference / current rotational power) ** Tracker["upscale"].
            from utilities import read_text_file, write_text_file
            rt = read_text_file(Tracker["PWadjustment"])
            fftip(vol)
            ro = rops_table(vol)
            #  Here unless I am mistaken it is enough to take the beginning of the reference pw.
            # NOTE(review): the loop starts at 1, so ro[0] keeps the raw
            # rotational power value and is used unchanged as the filter
            # coefficient at zero frequency -- confirm this is intended.
            for i in xrange(1, len(ro)):
                ro[i] = (rt[i] / ro[i])**Tracker["upscale"]
            #write_text_file(rops_table(filt_table( vol, ro),1),"foo.txt")
            if Tracker["constants"]["sausage"]:
                # Multiply the filter by 1 + two Gaussian bumps centred at
                # i/(ny*pixel_size) = 0.10 and 0.215 (width 0.025); the units
                # are presumably 1/Angstrom if pixel_size is in Angstrom.
                ny = vol.get_ysize()
                y = float(ny)
                from math import exp
                for i in xrange(len(ro)):                    ro[i] *= \
(1.0+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.10)/0.025)**2)+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.215)/0.025)**2))

            if local_filter:
                # skip low-pass filtration
                vol = fft(filt_table(vol, ro))
            else:
                # lowpass is either a filter table (list) or a tanl cutoff
                # used together with Tracker["falloff"]
                if (type(Tracker["lowpass"]) == types.ListType):
                    vol = fft(
                        filt_table(filt_table(vol, Tracker["lowpass"]), ro))
                else:
                    vol = fft(
                        filt_table(
                            filt_tanl(vol, Tracker["lowpass"],
                                      Tracker["falloff"]), ro))
            del ro
        else:
            if Tracker["constants"]["sausage"]:
                # Same two-bump weighting as above, applied on its own
                # (vol stays in Fourier space until the fft at the end).
                ny = vol.get_ysize()
                y = float(ny)
                ro = [0.0] * (ny // 2 + 2)
                from math import exp
                for i in xrange(len(ro)):                    ro[i] = \
(1.0+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.10)/0.025)**2)+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.215)/0.025)**2))
                fftip(vol)
                filt_table(vol, ro)
                del ro
            if not local_filter:
                if (type(Tracker["lowpass"]) == types.ListType):
                    vol = filt_table(vol, Tracker["lowpass"])
                else:
                    vol = filt_tanl(vol, Tracker["lowpass"],
                                    Tracker["falloff"])
            if Tracker["constants"]["sausage"]: vol = fft(vol)

    if local_filter:
        # Local filtering is done cooperatively by all ranks of mpi_comm.
        from morphology import binarize
        # NOTE(review): the two bcast_number_to_all calls below do not pass
        # mpi_comm, unlike the EMData broadcasts; if mpi_comm is not
        # MPI_COMM_WORLD this looks inconsistent -- confirm against the
        # helper's signature.
        if (myid == 0): nx = mask3D.get_xsize()
        else: nx = 0
        nx = bcast_number_to_all(nx, source_node=0)
        #  only main processor needs the two input volumes
        if (myid == 0):
            mask = binarize(mask3D, 0.5)
            locres = get_im(Tracker["local_filter"])
            lx = locres.get_xsize()
            if (lx != nx):
                if (lx < nx):
                    # bring mask and volume down to the local-resolution
                    # volume's size lx
                    from fundamentals import fdecimate, rot_shift3D
                    mask = Util.window(
                        rot_shift3D(mask, scale=float(lx) / float(nx)), lx, lx,
                        lx)
                    vol = fdecimate(vol, lx, lx, lx)
                else:
                    ERROR("local filter cannot be larger than input volume",
                          "user function", 1)
            stat = Util.infomask(vol, mask, False)
            vol -= stat[0]
            Util.mul_scalar(vol, 1.0 / stat[1])
        else:
            # placeholders on non-root ranks
            lx = 0
            locres = model_blank(1, 1, 1)
            vol = model_blank(1, 1, 1)
        lx = bcast_number_to_all(lx, source_node=0)
        if (myid != 0): mask = model_blank(lx, lx, lx)
        bcast_EMData_to_all(mask, myid, 0, comm=mpi_comm)
        from filter import filterlocal
        vol = filterlocal(locres, vol, mask, Tracker["falloff"], myid, 0,
                          nproc)

        if myid == 0:
            if (lx < nx):
                # restore the original box size nx
                from fundamentals import fpol
                vol = fpol(vol, nx, nx, nx)
            vol = threshold(vol)
            vol = filt_btwl(vol, 0.38, 0.5)  #  This will have to be corrected.
            Util.mul_img(vol, mask3D)
            del mask3D
            # vol.write_image('toto%03d.hdf'%iter)
        else:
            # receive buffer for the final broadcast
            vol = model_blank(nx, nx, nx)
    else:
        if myid == 0:
            #from utilities import write_text_file
            #write_text_file(rops_table(vol,1),"goo.txt")
            stat = Util.infomask(vol, mask3D, False)
            vol -= stat[0]
            Util.mul_scalar(vol, 1.0 / stat[1])
            vol = threshold(vol)
            vol = filt_btwl(vol, 0.38, 0.5)  #  This will have to be corrected.
            Util.mul_img(vol, mask3D)
            del mask3D
            # vol.write_image('toto%03d.hdf'%iter)
    # broadcast volume (non-root ranks pass their current vol as the receive
    # buffer; in this branch it is whatever the reconstruction returned or
    # the input data -- presumably of matching size, confirm)
    bcast_EMData_to_all(vol, myid, 0, comm=mpi_comm)
    #=========================================================================
    return vol
Beispiel #18
0
def main():
    """Command-line driver for SPARX 2D reference-free alignment.

    Positional arguments: ``stack outdir [maskfile]``.  ``outdir`` may be the
    literal string ``None``.

    * ``--rotational``: hand off to ``applications.ali2d_rotationaltop``.
    * ``--MPI``: the main node creates ``outdir``.  If creation fails (for
      example because the directory already exists), every rank aborts
      through ``ERROR``.  Otherwise all ranks run
      ``applications.ali2d_base`` and MPI is finalized at the end.
    * otherwise: only a notice is printed; the serial ``ali2d`` path is
      retired.
    """
    progname = os.path.basename(sys.argv[0])
    usage = progname + " stack outdir <maskfile> --ir=inner_radius --ou=outer_radius --rs=ring_step --xr=x_range --yr=y_range --ts=translation_step --dst=delta --center=center --maxit=max_iteration --CTF --snr=SNR --Fourvar=Fourier_variance --Ng=group_number --Function=user_function_name --CUDA --GPUID --MPI"
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option("--ir",       type="float",  default=1,             help="inner radius for rotational correlation > 0 (set to 1)")
    parser.add_option("--ou",       type="float",  default=-1,            help="outer radius for rotational correlation < nx/2-1 (set to the radius of the particle)")
    parser.add_option("--rs",       type="float",  default=1,             help="step between rings in rotational correlation > 0 (set to 1)")
    parser.add_option("--xr",       type="string", default="4 2 1 1",     help="range for translation search in x direction, search is +/xr ")
    parser.add_option("--yr",       type="string", default="-1",          help="range for translation search in y direction, search is +/yr ")
    parser.add_option("--ts",       type="string", default="2 1 0.5 0.25", help="step of translation search in both directions")
    parser.add_option("--nomirror", action="store_true", default=False,   help="Disable checking mirror orientations of images (default False)")
    parser.add_option("--dst",      type="float",  default=0.0,           help="delta")
    parser.add_option("--center",   type="float",  default=-1,            help="-1.average center method; 0.not centered; 1.phase approximation; 2.cc with Gaussian function; 3.cc with donut-shaped image 4.cc with user-defined reference 5.cc with self-rotated average")
    parser.add_option("--maxit",    type="float",  default=0,             help="maximum number of iterations (0 means the maximum iterations is 10, but it will automatically stop should the criterion falls")
    parser.add_option("--CTF",      action="store_true", default=False,   help="use CTF correction during alignment")
    parser.add_option("--snr",      type="float",  default=1.0,           help="signal-to-noise ratio of the data (set to 1.0)")
    parser.add_option("--Fourvar",  action="store_true", default=False,   help="compute Fourier variance")
    #parser.add_option("--Ng",       type="int",          default=-1,      help="number of groups in the new CTF filteration")
    parser.add_option("--function", type="string",       default="ref_ali2d",  help="name of the reference preparation function (default ref_ali2d)")
    #parser.add_option("--CUDA",     action="store_true", default=False,   help="use CUDA program")
    #parser.add_option("--GPUID",    type="string",    default="",         help="ID of GPUs available")
    parser.add_option("--MPI",      action="store_true", default=False,   help="use MPI version ")
    parser.add_option("--rotational", action="store_true", default=False, help="rotational alignment with optional limited in-plane angle, the parameters are: ir, ou, rs, psi_max, mode(F or H), maxit, orient, randomize")
    parser.add_option("--psi_max",  type="float",        default=180.0,   help="psi_max")
    parser.add_option("--mode",     type="string",       default="F",     help="Full or Half rings, default F")
    parser.add_option("--randomize", action="store_true", default=False,  help="randomize initial rotations (suboption of friedel, default False)")
    parser.add_option("--orient",   action="store_true", default=False,   help="orient images such that the average is symmetric about x-axis, for layer lines (suboption of friedel, default False)")
    parser.add_option("--template", type="string",       default=None,    help="2D alignment will be initialized using the template provided (only non-MPI version, default None)")
    parser.add_option("--random_method",   type="string", default="",   help="use SHC or SCF (default standard method)")

    (options, args) = parser.parse_args()

    if len(args) < 2 or len(args) > 3:
        print("usage: " + usage)
        print("Please run '" + progname + " -h' for detailed options")
        return

    if options.rotational:
        from applications import ali2d_rotationaltop
        global_def.BATCH = True
        ali2d_rotationaltop(args[1], args[0], options.randomize, options.orient, options.ir, options.ou, options.rs, options.psi_max, options.mode, options.maxit)
        return

    # the literal "None" means "no output directory"
    outdir = None if args[1] == 'None' else args[1]
    mask = args[2] if len(args) == 3 else None

    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    global_def.BATCH = True
    if options.MPI:
        from applications import ali2d_base
        from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
        sys.argv = mpi_init(len(sys.argv), sys.argv)

        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        main_node = 0

        if myid == main_node:
            from logger import Logger, BaseLogger_Files
            #  Create output directory
            log = Logger(BaseLogger_Files())
            log.prefix = os.path.join(outdir)
            # Create the directory directly rather than through "mkdir" in a
            # shell: the unquoted shell command broke on paths containing
            # spaces or shell metacharacters.  Any failure (e.g. the directory
            # already exists) is reported as outcome 1, as mkdir's exit code was.
            try:
                os.mkdir(log.prefix)
                outcome = 0
            except OSError:
                outcome = 1
            log.prefix += "/"
        else:
            outcome = 0
            log = None
        from utilities import bcast_number_to_all
        # every rank learns whether the directory could be created
        outcome = bcast_number_to_all(outcome, source_node=main_node)
        if outcome == 1:
            ERROR('Output directory exists, please change the name and restart the program', "ali2d_MPI", 1, myid)

        ali2d_base(args[0], outdir, mask, options.ir, options.ou, options.rs, options.xr, options.yr, \
            options.ts, options.nomirror, options.dst, \
            options.center, options.maxit, options.CTF, options.snr, options.Fourvar, \
            options.function, random_method=options.random_method, log=log, \
            number_of_proc=number_of_proc, myid=myid, main_node=main_node, mpi_comm=MPI_COMM_WORLD, \
            write_headers=True)
    else:
        print(" Non-MPI is no more in use, try MPI option, please.")
        # Retired serial path, kept for reference:
        #   from applications import ali2d
        #   ali2d(args[0], outdir, mask, options.ir, options.ou, options.rs, options.xr, options.yr,
        #         options.ts, options.nomirror, options.dst,
        #         options.center, options.maxit, options.CTF, options.snr, options.Fourvar,
        #         -1, options.function, False, "", options.MPI,
        #         options.template, random_method=options.random_method)
    global_def.BATCH = False

    if options.MPI:
        from mpi import mpi_finalize
        mpi_finalize()
Beispiel #19
0
def shiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1, oneDx=False, search_rng_y=-1):
    """Iteratively center a 2D image stack by translation against its average (MPI).

    The images are split between the ranks of MPI_COMM_WORLD.  Each image is
    first transformed by its ``xform.align2d`` header parameters, has its
    in-mask mean subtracted and is Fourier transformed (and, if ``CTF``,
    filtered with its CTF via ``filt_ctf``).  Each iteration then does:

      * the main node builds the average (divided by the regularized sum of
        CTF^2 if ``CTF``, else by the image count), masks it in real space
        and broadcasts it;
      * every image is cross-correlated with the average and the integer
        peak gives its shift; the integer-rounded mean shift over the whole
        stack is subtracted so the stack does not drift;
      * the loop stops after ``maxit`` iterations or when no image moved.

    The accumulated shifts are combined with the initial ``xform.align2d``
    and written back to the headers of ``stack``.

    Parameters
    ----------
    stack : str
        Input stack (file or bdb).
    maskfile : str or None
        Mask image; default is a circle of radius min(nx, ny)//2 - 2.
    maxit : int
        Maximum number of iterations.
    CTF : bool
        Use CTF information; the data must not be ctf-applied (ERROR otherwise).
    snr : float
        Value added to the real part of every Fourier coefficient of the
        summed CTF^2 before it is used as the averaging denominator.
    Fourvar : bool
        Divide the average by the Fourier variance from ``varf2d_MPI``.
    search_rng, search_rng_y : int
        Half-width of the translational search in x / y; <= 0 means the
        whole image.
    oneDx : bool
        Search along x only.
    """
    from applications import MPI_start_end
    from utilities    import model_circle, peak_search, get_im
    from utilities    import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all
    from statistics   import varf2d_MPI
    from fundamentals import fft, ccf, rot_shift2D
    from utilities    import print_msg, print_begin_msg, print_end_msg
    from mpi          import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi          import mpi_reduce, mpi_barrier
    from mpi          import MPI_SUM, MPI_INT
    from EMAN2        import Processor
    from time         import time

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)

    # contiguous slice of particle indices handled by this rank
    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list(range(image_start, image_end))

    # read nx, ny (and the ctf_applied flag if CTF) on the main node and broadcast
    if myid == main_node:
        ima = EMData()
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF:
            # a missing flag defaults to 2, i.e. treated as "applied"
            ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF:
            ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", "shiftali_MPI", 1, myid)

    if maskfile is None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from filter import filt_ctf
        from morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # ranks read in turn; for bdb stacks they are serialized by barriers
        for i in xrange(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb":
                mpi_barrier(MPI_COMM_WORLD)

    for im in xrange(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add snr to the real part of every Fourier coefficient of the
            # CTF^2 sum.  Iterate over the image's actual stored width: a
            # complex EMData(nx, ...) holds nx+2-nx%2 floats per row, so
            # stopping at nx (as before) skipped the Nyquist column for even nx.
            temp = EMData(nx, ny, 1, False)
            for i in xrange(0, temp.get_xsize(), 2):
                for j in xrange(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in xrange(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], sx=p['tx'], sy=p['ty'], mirror=p['mirror'], scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in xrange(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum_total = 0
    sy_sum_total = 0
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    # size of the ccf window searched for the peak (loop invariant)
    nwx = 2 * search_rng + 1 if search_rng > 0 else nx
    nwy = 2 * search_rng_y + 1 if search_rng_y > 0 else ny

    for Iter in xrange(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))

            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)
            tavg = fft(tavg)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        sx_sum = 0
        sy_sum = 0
        not_zero = 0
        for im in xrange(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if ishift_x[im] != 0.0 or ishift_y[im] != 0.0:
                not_zero = 1

        # global mean shift: reduce on the main node, then broadcast
        sx_sum = mpi_reduce(sx_sum, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
        if not oneDx:
            sy_sum = mpi_reduce(sy_sum, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)
        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total, source_node=main_node)

        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in xrange(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {"filter_type" : Processor.fourier_filter_types.SHIFT, "x_shift" : p1_x, "y_shift" : p1_y, "z_shift" : 0.0}
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y

        # stop if all shifts are zero
        not_zero = mpi_reduce(not_zero, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))

        if not_zero_all == 0:
            break

    # combine shifts found with the original parameters
    for im in xrange(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({"type":"2D","alpha":0,"scale":t0.get_scale(),"mirror":0,"tx":shift_x[im],"ty":shift_y[im]})
        data[im].set_attr("xform.align2d", t1 * t0)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi_barrier(MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
        else:
            from utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("shiftali_MPI")
Beispiel #20
0
def helicalshiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1):
    """Center helical segments in x using a straight-line model per filament (MPI).

    Segments are grouped into filaments (``ordersegments`` on the
    ``filament`` and ``ptcl_source_coord`` header attributes).  Whole
    filaments are distributed over the ranks of MPI_COMM_WORLD, and there
    must be at least as many filaments as ranks.  After the initial
    ``xform.align2d`` is applied, each iteration does:

      * the main node forms the average of the currently shifted segments
        (divided by the CTF^2 sum plus 1/snr if ``CTF``, else by the
        segment count), masks it in real space and broadcasts it;
      * for every filament, 1D x-cross-correlations of its segments with the
        average are passed to ``Util.helixshiftali``, which returns a shift
        ``sib`` at the central segment and a slope ``bang``; each segment's
        x-shift is then ``sib +/- dist * bang``, where ``dist`` is its
        source-coordinate distance to the central segment.

    The final x-shifts are prepended to the initial ``xform.align2d`` and
    written back to the headers of ``stack``.  Only x is changed.

    Parameters
    ----------
    stack : str
        Input stack of helical segments.
    maskfile : str or None
        Mask image; default is a vertical band of width
        2*(min(nx, ny)//2 - 2) + 1 padded to nx by ny.
    maxit : int
        Number of iterations (there is no early stop).
    CTF : bool
        Use CTF information; segments with ``ctf_applied == 0`` are CTF-filtered here.
    snr : float
        1/snr is added to the real part of the CTF^2 sum.
    Fourvar : bool
        Divide the average by the Fourier variance from ``varf2d_MPI``.
    search_rng : int
        Maximum x-shift searched at the central segment.
    """
    from utilities    import model_blank, get_im, pad
    from utilities    import reduce_EMData_to_root, bcast_EMData_to_all, send_attr_dict, file_type, bcast_number_to_all, bcast_list_to_all
    from statistics   import varf2d_MPI
    from fundamentals import fft, ccf, rot_shift2D, fshift
    from utilities    import chunks_distribution
    from utilities    import print_msg, print_begin_msg, print_end_msg
    from mpi          import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi          import mpi_reduce, mpi_barrier
    from mpi          import MPI_SUM, MPI_FLOAT
    from time         import time
    from pixel_error  import ordersegments
    from math         import sqrt, atan2, tan

    nproc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)

    # The main node orders segments into filaments.  The filament lengths
    # (inidl) and the flattened list of segment indices (tfilaments) are
    # broadcast, and every rank rebuilds the list of filaments from them.
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [len(fil) for fil in filaments]
        linidl = sum(inidl)
        nima = linidl  # only used on the main node
        tfilaments = []
        for fil in filaments:
            tfilaments += fil
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    filaments = []
    iendi = 0
    for i in xrange(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR('number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used' % (nproc, total_nfils), "ehelix_MPI", 1, myid)

    #  balanced load: this rank keeps the filaments of its chunk
    chunk = chunks_distribution([[len(filaments[i]), i] for i in xrange(len(filaments))], nproc)[myid:myid+1][0]
    filaments = [filaments[entry[1]] for entry in chunk]
    nfils = len(filaments)

    # local particle list; indcs[f] = [first, last+1] of filament f in data
    list_of_particles = []
    indcs = []
    k = 0
    for fil in filaments:
        list_of_particles += fil
        k1 = k + len(fil)
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    print("ldata= %s" % ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in xrange(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'], p['mirror'], p['scale'])

    if CTF:
        from filter import filt_ctf
        from morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # subtract in-mask mean, move every segment to Fourier space
    for im in xrange(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf == 1:
                # Already ctf-applied: still needs the Fourier transform that
                # every other path performs (it was previously left real).
                data[im] = fft(data[im])
            else:
                ERROR('Incorrectly set qctf flag', "helicalshiftali_MPI", 1, myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to the real part of every Fourier coefficient.  Use
            # the stored width (nx+2-nx%2): "nx+2" overran it for odd nx.
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in xrange(0, temp.get_xsize(), 2):
                for j in xrange(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata
    nxc = nx // 2

    for Iter in xrange(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in xrange(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # NOTE(review): debug output -- writes the current average into
            # "tavg.hdf" in the working directory every iteration.
            tavg.write_image("tavg.hdf", Iter)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        for ifil in xrange(nfils):
            ibeg, iend = indcs[ifil]
            nsegms = iend - ibeg
            # 1D ccf (x only) between the average and each segment, plus the
            # segment's source coordinates
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in xrange(ibeg, iend):
                ctx[im - ibeg] = Util.window(ccf(tavg, data[im]), nx, 1)
                pcoords[im - ibeg] = data[im].get_attr('ptcl_source_coord')

            cents = nsegms // 2
            # largest distance from the central segment to either end
            dst = sqrt(max((pcoords[cents][0] - pcoords[0][0])**2 + (pcoords[cents][1] - pcoords[0][1])**2, (pcoords[cents][0] - pcoords[-1][0])**2 + (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # Util.helixshiftali is the C implementation of the search
            # ("C code for alignment").  An earlier Python reference version
            # (see version control) scanned shifts six in
            # [-search_rng, search_rng] and kang+1 incline steps.  It summed
            # linearly interpolated ctx values along the line
            # six +/- dist*tan(incline) and kept the best (sib, bang, score).
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline, kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]

            for im in xrange(ibeg, iend):
                kim = im - ibeg
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
                if kim < cents:
                    xl = -dst * bang + sib
                else:
                    xl = dst * bang + sib
                shift_x[im] = xl

            # shift of the central segment, used for the average shift
            sx_sum += shift_x[ibeg + cents]

        sx_sum = mpi_reduce(sx_sum, 1, MPI_FLOAT, MPI_SUM, main_node, MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): this unconditional reset disables removal of the
        # average shift (only the message above reports it) -- confirm that
        # it is intentional.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in xrange(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in xrange(ldata):
        t1 = Transform()
        t1.set_params({"type":"2D","tx":shift_x[im]})
        data[im].set_attr("xform.align2d", t1 * init_params[im])

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi_barrier(MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata, nproc)
        else:
            from utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node:
        print_end_msg("helical-shiftali_MPI")
Beispiel #21
0
def do_volume_mrk03(ref_data):
	"""
	Reconstruct (if needed) and post-process a 3-D volume, then broadcast it.

	ref_data - sequence [data, Tracker, iter, mpi_comm]:
		data     - list of projections scattered between processes (the volume
		           is reconstructed from them), or an already computed volume
		           (only the post-processing below is applied).
		Tracker  - options dictionary, the same on all processes.
		iter     - iteration number; not used by the computation.
		mpi_comm - MPI communicator; None selects MPI_COMM_WORLD.
	return - the processed volume, identical on all processes of mpi_comm.

	Post-processing on process 0:
	- build the 3-D mask from Tracker["constants"]["mask3D"]: None gives a sphere,
	  "auto" gives adaptive_mask, a string is read from file, otherwise a copy of
	  the given image is used; it is rescaled if its size differs;
	- subtract stat[0] and divide by stat[1] from Util.infomask (presumably the
	  mean and standard deviation in the mask), threshold, multiply by the mask;
	- low-pass filter globally (filt_table / filt_tanl), or, when
	  Tracker["local_filter"] names an image file, run filterlocal on all processes.
	"""
	from EMAN2          import Util
	from mpi            import mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
	from filter         import filt_table
	from reconstruction import recons3d_4nn_MPI, recons3d_4nnw_MPI  #  recons3d_4nn_ctf_MPI
	from utilities      import bcast_EMData_to_all, bcast_number_to_all, model_blank

	# Retrieve the function specific input arguments from ref_data
	# (ref_data[2], the iteration number, is not needed here).
	data     = ref_data[0]
	Tracker  = ref_data[1]
	mpi_comm = ref_data[3]

	if mpi_comm is None:  mpi_comm = MPI_COMM_WORLD
	myid  = mpi_comm_rank(mpi_comm)
	nproc = mpi_comm_size(mpi_comm)

	# A missing "local_filter" entry means the volume is filtered globally.
	try:              local_filter = Tracker["local_filter"]
	except KeyError:  local_filter = False

	#=========================================================================
	# volume reconstruction
	if isinstance(data, list):
		if Tracker["constants"]["CTF"]:
			vol = recons3d_4nnw_MPI(myid, data, Tracker["bckgnoise"], Tracker["constants"]["snr"], \
				symmetry=Tracker["constants"]["sym"], npad=Tracker["constants"]["npad"], mpi_comm=mpi_comm, smearstep = Tracker["smearstep"])
		else:
			vol = recons3d_4nn_MPI(myid, data, \
				symmetry=Tracker["constants"]["sym"], npad=Tracker["constants"]["npad"], mpi_comm=mpi_comm)
	else:
		vol = data

	if myid == 0:
		from morphology import threshold
		from filter     import filt_tanl
		from utilities  import model_circle, get_im
		nx = vol.get_xsize()
		mask_spec = Tracker["constants"]["mask3D"]
		if mask_spec is None:
			# Spherical mask; the radius refers to size nnxo and is rescaled to nx.
			mask3D = model_circle(int(Tracker["constants"]["radius"]*float(nx)/float(Tracker["constants"]["nnxo"])+0.5), nx, nx, nx)
		elif mask_spec == "auto":
			from utilities import adaptive_mask
			mask3D = adaptive_mask(vol)
		else:
			if isinstance(mask_spec, str):  mask3D = get_im(mask_spec)
			else:                           mask3D = mask_spec.copy()
			nxm = mask3D.get_xsize()
			if nx != nxm:
				# Resample the mask to the size of the current volume.
				from fundamentals import rot_shift3D
				mask3D = Util.window(rot_shift3D(mask3D, scale=float(nx)/float(nxm)), nx, nx, nx)
				nxm = mask3D.get_xsize()
				assert(nx == nxm)

		stat = Util.infomask(vol, mask3D, False)
		vol -= stat[0]
		Util.mul_scalar(vol, 1.0/stat[1])
		vol = threshold(vol)
		Util.mul_img(vol, mask3D)
		if not local_filter:
			if isinstance(Tracker["lowpass"], list):
				vol = filt_table(vol, Tracker["lowpass"])
			else:
				vol = filt_tanl(vol, Tracker["lowpass"], Tracker["falloff"])

	if local_filter:
		from morphology import binarize
		from filter     import filterlocal
		# All processes need the volume size, but only process 0 holds the mask.
		# The broadcast must run on the same communicator as the rest of this function.
		if myid == 0:  nx = mask3D.get_xsize()
		else:          nx = 0
		nx = bcast_number_to_all(nx, source_node = 0, mpi_comm = mpi_comm)
		#  only main processor needs the two input volumes
		if myid == 0:
			mask = binarize(mask3D, 0.5)
			locres = get_im(Tracker["local_filter"])
			lx = locres.get_xsize()
			if lx != nx:
				if lx < nx:
					# Bring mask and volume down to the size of the local-resolution image.
					from fundamentals import fdecimate, rot_shift3D
					mask = Util.window(rot_shift3D(mask, scale=float(lx)/float(nx)), lx, lx, lx)
					vol = fdecimate(vol, lx, lx, lx)
				else:  ERROR("local filter cannot be larger than input volume","user function",1)
			stat = Util.infomask(vol, mask, False)
			vol -= stat[0]
			Util.mul_scalar(vol, 1.0/stat[1])
		else:
			lx = 0
			locres = model_blank(1,1,1)
			vol = model_blank(1,1,1)
		lx = bcast_number_to_all(lx, source_node = 0, mpi_comm = mpi_comm)
		if myid != 0:  mask = model_blank(lx,lx,lx)
		bcast_EMData_to_all(mask, myid, 0, comm=mpi_comm)
		# NOTE(review): filterlocal takes no communicator argument; confirm it is
		# correct when mpi_comm is not MPI_COMM_WORLD.
		vol = filterlocal(locres, vol, mask, Tracker["falloff"], myid, 0, nproc)

		if myid == 0:
			if lx < nx:
				# Resample the filtered volume back to size nx.
				from fundamentals import fpol
				vol = fpol(vol, nx, nx, nx)
			vol = threshold(vol)
			Util.mul_img(vol, mask3D)
			del mask3D
		else:
			vol = model_blank(nx, nx, nx)

	# broadcast volume
	bcast_EMData_to_all(vol, myid, 0, comm=mpi_comm)
	#=========================================================================
	return vol
Beispiel #22
0
def prepare_refringsHelical( volft, kb, nx, delta, ref_a, oplane, sym, numr, MPI=False):
	"""
	Prepare polar reference rings (Fourier representation) for helical processing.

	Reference directions [phi, theta, psi] are:
	- phi   = j*delta for j = 0 .. int((179.99/sym)/delta), i.e. an in-plane range
	  just under 180/sym degrees;
	- theta = 90 + t, where t steps by delta from 0 up to (excluding) +oplane and
	  from -delta down to (excluding) -oplane;
	- psi   = 90.

	volft  - volume prepared for projection, passed to prgs
	kb     - interpolation kernel, passed to prgs
	nx     - image size; the polar center is nx//2 + 1
	delta  - angular step in degrees
	ref_a  - unused, kept for interface compatibility
	oplane - out-of-plane tilt limit in degrees (exclusive)
	sym    - symmetry string such as "c3"; only its digits are used
	numr   - ring specification passed to Polar2Dm / Frngs / Applyws
	MPI    - if True, references are split over MPI_COMM_WORLD and then broadcast,
	         so every process returns the full list
	return - list of reference ring images with attributes n1, n2, n3 (direction
	         vector built from phi and theta), phi, theta and psi

	Raises ValueError if sym contains no digits.
	"""
	from alignment    import ringwe, Applyws
	from projection   import prgs
	from math         import sin, cos, pi
	from applications import MPI_start_end
	from utilities    import bcast_EMData_to_all
	import re

	# Convert the symmetry string (e.g. "c3") to its integer order.
	sym_digits = re.sub(r"\D", "", sym)
	if not sym_digits:
		raise ValueError("symmetry %r does not contain a rotational order" % (sym,))
	sym = int(sym_digits)

	# generate list of Eulerian angles for reference projections
	#  phi, theta, psi
	mode = "F"
	ref_angles = []
	inplane = int((179.99/sym)/delta) + 1
	# first create 0 and positive out-of-plane tilts
	# NOTE(review): the tilt is accumulated by repeated addition of delta, so
	# floating-point rounding decides whether a tilt very close to +/-oplane is kept.
	i = 0
	while i < oplane:
		for j in xrange(inplane):
			ref_angles.append([j*delta, 90.0+i, 90.0])
		i += delta
	# negative out of plane rotation
	i = -(delta)
	while i > -(oplane):
		for j in xrange(inplane):
			ref_angles.append([j*delta, 90.0+i, 90.0])
		i -= delta

	wr_four = ringwe(numr, mode)
	cnx = nx//2 + 1
	cny = nx//2 + 1
	qv = pi/180.
	num_ref = len(ref_angles)

	if MPI:
		from mpi import mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
		myid = mpi_comm_rank( MPI_COMM_WORLD )
		ncpu = mpi_comm_size( MPI_COMM_WORLD )
	else:
		ncpu = 1
		myid = 0
	ref_start, ref_end = MPI_start_end(num_ref, ncpu, myid)

	# Placeholder images with the ring-buffer length; this process overwrites
	# only its own share [ref_start, ref_end), the rest arrive by broadcast.
	sizex = numr[len(numr)-2] + numr[len(numr)-1] - 1
	refrings = []     # list of (image objects) reference projections in Fourier representation
	for i in xrange(num_ref):
		prjref = EMData()
		prjref.set_size(sizex, 1, 1)
		refrings.append(prjref)

	for i in xrange(ref_start, ref_end):
		prjref = prgs(volft, kb, [ref_angles[i][0], ref_angles[i][1], ref_angles[i][2], 0.0, 0.0])
		cimage = Util.Polar2Dm(prjref, cnx, cny, numr, mode)  # currently set to quadratic....
		Util.Normalize_ring(cimage, numr)
		Util.Frngs(cimage, numr)
		Applyws(cimage, numr, wr_four)
		refrings[i] = cimage

	if MPI:
		for i in xrange(num_ref):
			# Find the rank that computed reference i and broadcast from it.
			for j in xrange(ncpu):
				j_start, j_end = MPI_start_end(num_ref, ncpu, j)
				if j_start <= i < j_end: rootid = j
			bcast_EMData_to_all(refrings[i], myid, rootid)

	for i in xrange(num_ref):
		phi, theta, psi = ref_angles[i]
		n1 = sin(theta*qv)*cos(phi*qv)
		n2 = sin(theta*qv)*sin(phi*qv)
		n3 = cos(theta*qv)
		refrings[i].set_attr_dict( {"n1":n1, "n2":n2, "n3":n3} )
		refrings[i].set_attr("phi", phi)
		refrings[i].set_attr("theta", theta)
		refrings[i].set_attr("psi", psi)

	return refrings
Beispiel #23
0
def main():
    """
    Entry point of sxproj_stability: 2-D alignment stability of projection groups.

    Usage: proj_stack output_averages --MPI   (only the MPI version is supported)

    Projections of ``proj_stack`` are grouped by projection direction:
      GRP - groups of about --img_per_group images built on the main node by
            group_proj_by_phitheta and distributed round-robin over processes;
      GEV - for each reference direction of even_angles(--delta), the
            --img_per_group nearest projections (groups may overlap);
      PPR - nearest --img_per_group projections around individual projections.
    Each group is re-aligned --num_ali times and passed to
    multi_align_stability with threshold --thld_err.  One average per group is
    written to ``output_averages`` with header attributes:
      members    - stable member ids first, then unstable ones ([-1] if the
                   group has 5 or fewer stable members)
      pixerr     - pixel error per member (99999.99 for unstable members)
      refprojdir - reference direction of the group
    """
    progname = os.path.basename(sys.argv[0])
    usage = progname + " proj_stack output_averages --MPI"
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--img_per_group",
                      type="int",
                      default=100,
                      help="number of images per group")
    parser.add_option("--radius",
                      type="int",
                      default=-1,
                      help="radius for alignment")
    parser.add_option(
        "--xr",
        type="string",
        default="2 1",
        help="range for translation search in x direction, search is +/-xr")
    parser.add_option(
        "--yr",
        type="string",
        default="-1",
        help=
        "range for translation search in y direction, search is +/-yr (default = same as xr)"
    )
    parser.add_option(
        "--ts",
        type="string",
        default="1 0.5",
        help=
        "step size of the translation search in both directions, search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional"
    )
    parser.add_option(
        "--iter",
        type="int",
        default=30,
        help="number of iterations within alignment (default = 30)")
    parser.add_option(
        "--num_ali",
        type="int",
        default=5,
        help="number of alignments performed for stability (default = 5)")
    parser.add_option("--thld_err",
                      type="float",
                      default=1.0,
                      help="threshold of pixel error (default = 1.0)")
    parser.add_option(
        "--grouping",
        type="string",
        default="GRP",
        help=
        "do grouping of projections: PPR - per projection, GRP - different size groups, exclusive (default), GEV - grouping equal size"
    )
    parser.add_option(
        "--delta",
        type="float",
        default=-1.0,
        help="angular step for reference projections (required for GEV method)"
    )
    parser.add_option(
        "--fl",
        type="float",
        default=0.3,
        help="cut-off frequency of hyperbolic tangent low-pass Fourier filter")
    parser.add_option(
        "--aa",
        type="float",
        default=0.2,
        help="fall-off of hyperbolic tangent low-pass Fourier filter")
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="Consider CTF correction during the alignment ")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args()

    myid = mpi.mpi_comm_rank(MPI_COMM_WORLD)
    number_of_proc = mpi.mpi_comm_size(MPI_COMM_WORLD)
    main_node = 0

    if len(args) == 2:
        stack = args[0]
        outdir = args[1]
    else:
        sp_global_def.ERROR("Incomplete list of arguments",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return
    if not options.MPI:
        sp_global_def.ERROR("Non-MPI not supported!",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return

    if sp_global_def.CACHE_DISABLE:
        from sp_utilities import disable_bdb_cache
        disable_bdb_cache()
    sp_global_def.BATCH = True

    img_per_grp = options.img_per_group
    radius = options.radius
    ite = options.iter
    num_ali = options.num_ali
    thld_err = options.thld_err

    xrng = get_input_from_string(options.xr)
    if options.yr == "-1":
        yrng = xrng
    else:
        yrng = get_input_from_string(options.yr)

    step = get_input_from_string(options.ts)

    if myid == main_node:
        nima = EMUtil.get_image_count(stack)
        img = get_image(stack)
        nx = img.get_xsize()
        ny = img.get_ysize()
    else:
        nima = 0
        nx = 0
        ny = 0
    nima = bcast_number_to_all(nima)
    nx = bcast_number_to_all(nx)
    ny = bcast_number_to_all(ny)
    # Floor division keeps the radius an integer pixel count (as --radius is).
    if radius == -1: radius = nx // 2 - 2
    mask = model_circle(radius, nx, nx)

    st = time()
    if options.grouping == "GRP":
        if myid == main_node:
            sxprint("  A  ", myid, "  ", time() - st)
            proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
            proj_params = []
            for i in range(nima):
                dp = proj_attr[i].get_params("spider")
                phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp[
                    "psi"], -dp["tx"], -dp["ty"]
                proj_params.append([phi, theta, psi, s2x, s2y])

            # group_proj_by_phitheta returns:
            # proj_list  : a list of lists of particle numbers, each containing img_per_grp particle numbers
            #              except for the last one. Depending on the number of particles left, they either form a
            #              group or are appended to the last group.
            # angle_list : a list of lists (phi, theta, delta); (phi, theta) is the projection angle of the
            #              center of the group, delta is the angular range of the group.
            # mirror_list: a list of lists of img_per_grp True/False values telling whether the mirror position
            #              should be taken.
            # In this program angle_list and mirror_list are not of interest.
            proj_list_all, angle_list, mirror_list = group_proj_by_phitheta(
                proj_params, img_per_grp=img_per_grp)
            del proj_params
            sxprint("  B  number of groups  ", myid, "  ", len(proj_list_all),
                    time() - st)
        mpi_barrier(MPI_COMM_WORLD)

        # Number of groups; there could actually be one or two more groups, since the size of the
        # remaining group varies - those are simply assigned to the main node.
        # Floor division: the count must be an int for range() (plain "/" is a float in Python 3).
        n_grp = nima // img_per_grp - 1

        # Divide proj_list_all equally between all nodes; each node's share becomes proj_list
        proj_list = []
        for i in range(n_grp):
            proc_to_stay = i % number_of_proc
            if proc_to_stay == main_node:
                if myid == main_node: proj_list.append(proj_list_all[i])
            elif myid == main_node:
                mpi_send(len(proj_list_all[i]), 1, MPI_INT, proc_to_stay,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                mpi_send(proj_list_all[i], len(proj_list_all[i]), MPI_INT,
                         proc_to_stay, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            elif myid == proc_to_stay:
                group_size = mpi_recv(1, MPI_INT, main_node,
                                      SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                group_size = int(group_size[0])
                temp = mpi_recv(group_size, MPI_INT, main_node,
                                SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                proj_list.append(list(map(int, temp)))
                del temp
            mpi_barrier(MPI_COMM_WORLD)
        sxprint("  C  ", myid, "  ", time() - st)
        if myid == main_node:
            # Assign the remaining groups to main_node
            for i in range(n_grp, len(proj_list_all)):
                proj_list.append(proj_list_all[i])
            del proj_list_all, angle_list, mirror_list

    #   Compute stability per projection projection direction, equal number assigned, thus overlaps
    elif options.grouping == "GEV":

        if options.delta == -1.0:
            sp_global_def.ERROR(
                "Angular step for reference projections is required for GEV method",
                "sxproj_stability.main",
                1,
                myid=myid)
            return

        from sp_utilities import even_angles, nearestk_to_refdir, getvec
        refproj = even_angles(options.delta)
        img_begin, img_end = MPI_start_end(len(refproj), number_of_proc, myid)
        # Now each processor keeps its own share of reference projections
        refprojdir = refproj[img_begin:img_end]
        del refproj

        # NOTE(review): every entry uses refprojdir[0] with theta offset by i*0.1,
        # which looks like leftover test code (refprojdir[i] seems intended); confirm.
        ref_ang = [0.0] * (len(refprojdir) * 2)
        for i in range(len(refprojdir)):
            ref_ang[i * 2] = refprojdir[0][0]
            ref_ang[i * 2 + 1] = refprojdir[0][1] + i * 0.1

        sxprint("  A  ", myid, "  ", time() - st)
        # All processes read the attributes concurrently; reading them one process at a time
        # (with barriers in between) is very slow - only do that if there is a problem with the I/O.
        proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
        sxprint("  B  ", myid, "  ", time() - st)

        proj_ang = [0.0] * (nima * 2)
        for i in range(nima):
            dp = proj_attr[i].get_params("spider")
            proj_ang[i * 2] = dp["phi"]
            proj_ang[i * 2 + 1] = dp["theta"]
        sxprint("  C  ", myid, "  ", time() - st)
        asi = Util.nearestk_to_refdir(proj_ang, ref_ang, img_per_grp)
        del proj_ang, ref_ang
        proj_list = []
        for i in range(len(refprojdir)):
            proj_list.append(asi[i * img_per_grp:(i + 1) * img_per_grp])
        del asi
        sxprint("  D  ", myid, "  ", time() - st)

    #   Compute stability per projection
    elif options.grouping == "PPR":
        sxprint("  A  ", myid, "  ", time() - st)
        proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
        sxprint("  B  ", myid, "  ", time() - st)
        proj_params = []
        for i in range(nima):
            dp = proj_attr[i].get_params("spider")
            phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp[
                "psi"], -dp["tx"], -dp["ty"]
            proj_params.append([phi, theta, psi, s2x, s2y])
        img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
        sxprint("  C  ", myid, "  ", time() - st)
        from sp_utilities import nearest_proj
        # NOTE(review): only the first projection of this process' range is used;
        # the full range(img_begin, img_end) was commented out in the original code.
        proj_list, mirror_list = nearest_proj(
            proj_params, img_per_grp,
            list(range(img_begin, img_begin + 1)))  #range(img_begin, img_end))
        refprojdir = proj_params[img_begin:img_end]
        del proj_params, mirror_list
        sxprint("  D  ", myid, "  ", time() - st)

    else:
        sp_global_def.ERROR("Incorrect projection grouping option",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return

    ###########################################################################################################
    # Begin stability test
    from sp_utilities import get_params_proj, model_blank
    from sp_fundamentals import rot_shift2D
    from math import sqrt
    if options.CTF:
        from sp_filter import filt_ctf

    # One independent image per group: "[img] * n" would make every entry the same
    # object, so header attributes set on empty groups would overwrite each other.
    aveList = [model_blank(nx, ny) for _ in range(len(proj_list))]
    if options.grouping == "GRP":
        refprojdir = [[0.0, 0.0, -1.0] for _ in range(len(proj_list))]
    class_data = None  # defined even if this process received no group
    for i in range(len(proj_list)):
        sxprint("  E  ", myid, "  ", time() - st)
        class_data = EMData.read_images(stack, proj_list[i])
        if options.CTF:
            for im in range(len(class_data)):  #  MEM LEAK!!
                atemp = class_data[im].copy()
                class_data[im] = filt_ctf(atemp, atemp.get_attr("ctf"), binary=1)
        # Initialise the 2-D alignment parameters of images that have none.
        for im in class_data:
            try:
                t = im.get_attr(
                    "xform.align2d")  # if they are there, no need to set them!
            except Exception:
                try:
                    t = im.get_attr("xform.projection")
                    d = t.get_params("spider")
                    set_params2D(im, [0.0, -d["tx"], -d["ty"], 0, 1.0])
                except Exception:
                    set_params2D(im, [0.0, 0.0, 0.0, 0, 1.0])
        # Here, we perform realignment num_ali times
        all_ali_params = []
        for j in range(num_ali):
            if (xrng[0] == 0.0 and yrng[0] == 0.0):
                avet = ali2d_ras(class_data,
                                 randomize=True,
                                 ir=1,
                                 ou=radius,
                                 rs=1,
                                 step=1.0,
                                 dst=90.0,
                                 maxit=ite,
                                 check_mirror=True,
                                 FH=options.fl,
                                 FF=options.aa)
            else:
                avet = within_group_refinement(class_data, mask, True, 1,
                                               radius, 1, xrng, yrng, step,
                                               90.0, ite, options.fl,
                                               options.aa)
            ali_params = []
            for im in range(len(class_data)):
                alpha, sx, sy, mirror, scale = get_params2D(class_data[im])
                ali_params.extend([alpha, sx, sy, mirror])
            all_ali_params.append(ali_params)
        # We determine the stability of this group here.
        # stable_set contains all particles deemed stable, it is a list of lists;
        # each list has the pixel error, the image number in this group, and the
        # alignment parameters used for averaging.  stable_set is sorted by pixel error.
        stable_set, mir_stab_rate, average_pix_err = multi_align_stability(
            all_ali_params, 0.0, 10000.0, thld_err, False, 2 * radius + 1)
        if (len(stable_set) > 5):
            members = []
            pix_err = []
            # First put the stable members into attr 'members' and 'pix_err'
            for s in stable_set:
                # s[1] - number in this subset; members holds the original image number
                members.append(proj_list[i][s[1]])
                pix_err.append(s[0])
            # Look up alignment parameters by image index rather than by position in
            # stable_set, whose order follows the pixel error, not the image index.
            stable_params = dict((s[1], s[2]) for s in stable_set)
            # Then put the unstable members into attr 'members' and 'pix_err'
            avet.to_zero()
            if options.grouping == "GRP":
                aphi = 0.0
                atht = 0.0
                vphi = 0.0
                vtht = 0.0
            n_stable = 0
            for j in range(len(proj_list[i])):
                if j in stable_params:
                    n_stable += 1
                    par = stable_params[j]
                    avet += rot_shift2D(class_data[j], par[0], par[1], par[2],
                                        par[3])
                    if options.grouping == "GRP":
                        phi, theta, psi, sxs, sy_s = get_params_proj(
                            class_data[j])
                        # Fold directions with theta > 90 onto the upper hemisphere.
                        if (theta > 90.0):
                            phi = (phi + 540.0) % 360.0
                            theta = 180.0 - theta
                        aphi += phi
                        atht += theta
                        vphi += phi * phi
                        vtht += theta * theta
                else:
                    members.append(proj_list[i][j])
                    pix_err.append(99999.99)
            aveList[i] = avet.copy()
            if n_stable > 2:
                aveList[i] /= n_stable
                if options.grouping == "GRP":
                    # Mean direction and the average of the phi/theta standard deviations.
                    aphi /= n_stable
                    atht /= n_stable
                    vphi = (vphi - n_stable * aphi * aphi) / n_stable
                    vtht = (vtht - n_stable * atht * atht) / n_stable
                    refprojdir[i] = [
                        aphi, atht,
                        (sqrt(max(vphi, 0.0)) + sqrt(max(vtht, 0.0))) / 2.0
                    ]

            # Here more information has to be stored, PARTICULARLY WHAT IS THE REFERENCE DIRECTION
            aveList[i].set_attr('members', members)
            aveList[i].set_attr('refprojdir', refprojdir[i])
            aveList[i].set_attr('pixerr', pix_err)
        else:
            sxprint(" empty group ", i, refprojdir[i])
            aveList[i].set_attr('members', [-1])
            aveList[i].set_attr('refprojdir', refprojdir[i])
            aveList[i].set_attr('pixerr', [99999.])

    del class_data

    # The main node writes its own averages, then receives and writes those of the
    # other processes in rank order; the other processes send theirs.
    if myid == main_node:
        km = 0
        for i in range(number_of_proc):
            if i == main_node:
                for im in range(len(aveList)):
                    aveList[im].write_image(outdir, km)
                    km += 1
            else:
                nl = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                              MPI_COMM_WORLD)
                nl = int(nl[0])
                for im in range(nl):
                    ave = recv_EMData(i, im + i + 70000)
                    nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                                  MPI_COMM_WORLD)
                    nm = int(nm[0])
                    members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                                       MPI_COMM_WORLD)
                    ave.set_attr('members', list(map(int, members)))
                    members = mpi_recv(nm, MPI_FLOAT, i,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    ave.set_attr('pixerr', list(map(float, members)))
                    members = mpi_recv(3, MPI_FLOAT, i,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    ave.set_attr('refprojdir', list(map(float, members)))
                    ave.write_image(outdir, km)
                    km += 1
    else:
        mpi_send(len(aveList), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL,
                 MPI_COMM_WORLD)
        for im in range(len(aveList)):
            send_EMData(aveList[im], main_node, im + myid + 70000)
            members = aveList[im].get_attr('members')
            mpi_send(len(members), 1, MPI_INT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            mpi_send(members, len(members), MPI_INT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            members = aveList[im].get_attr('pixerr')
            mpi_send(members, len(members), MPI_FLOAT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            try:
                members = aveList[im].get_attr('refprojdir')
                mpi_send(members, 3, MPI_FLOAT, main_node,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            except Exception:
                # Fallback on any failure to fetch or send 'refprojdir';
                # keeps the message sequence expected by the main node intact.
                mpi_send([-999.0, -999.0, -999.0], 3, MPI_FLOAT, main_node,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)

    sp_global_def.BATCH = False
    mpi_barrier(MPI_COMM_WORLD)
Beispiel #24
0
def main():
    """Post-process ISAC 2-D class averages in parallel (MPI).

    For every class average of the ISAC run given by ``--isac_dir`` the
    average is recomputed from the original particle stack (``--stack``)
    using the combined ISAC alignment parameters, optionally with local
    2-D refinement (``--local_alignment``).  The power spectrum of each
    average is then adjusted according to ``--pw_adjustment``
    ("analytical_model", "no_adjustment", "bfactor", or the path of a text
    file with a 1-D rotationally averaged PW), an optional low-pass filter
    (``--fl``) is applied, and the main rank writes all averages to
    ``<masterdir>/class_averages.hdf`` together with parameter text files.
    Finally the main rank runs ``sp_chains.py`` to produce
    ``ordered_class_averages.hdf``.

    Side effects: rebinds the module globals ``Tracker`` and ``Blockdata``,
    sets ``sp_global_def.BATCH`` and ``sp_global_def.MPI`` to True, creates
    the output directory tree.  It uses collective operations on
    ``MPI_COMM_WORLD`` and therefore must be entered by every rank.
    """
    global Tracker, Blockdata
    progname = os.path.basename(sys.argv[0])
    usage = progname + " --output_dir=output_dir  --isac_dir=output_dir_of_isac "
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--pw_adjustment",
        type="string",
        default="analytical_model",
        help=
        "adjust power spectrum of 2-D averages to an analytic model. Other opions: no_adjustment; bfactor; a text file of 1D rotationally averaged PW",
    )
    #### Four options for --pw_adjustment:
    # 1> analytical_model(default);
    # 2> no_adjustment;
    # 3> bfactor;
    # 4> adjust_to_given_pw2(user has to provide a text file that contains 1D rotationally averaged PW)

    # options in common
    parser.add_option(
        "--isac_dir",
        type="string",
        default="",
        help="ISAC run output directory, input directory for this command",
    )
    parser.add_option(
        "--output_dir",
        type="string",
        default="",
        help="output directory where computed averages are saved",
    )
    parser.add_option(
        "--pixel_size",
        type="float",
        default=-1.0,
        help=
        "pixel_size of raw images. one can put 1.0 in case of negative stain data",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=-1.0,
        help=
        "low pass filter, = -1.0, not applied; =0.0, using FH1 (initial resolution), = 1.0 using FH2 (resolution after local alignment), or user provided value in absolute freqency [0.0:0.5]",
    )
    parser.add_option("--stack",
                      type="string",
                      default="",
                      help="data stack used in ISAC")
    parser.add_option("--radius", type="int", default=-1, help="radius")
    parser.add_option("--xr",
                      type="float",
                      default=-1.0,
                      help="local alignment search range")
    parser.add_option(
        "--fh",
        type="float",
        default=-1.0,
        help="local alignment high frequencies limit",
    )
    parser.add_option("--navg",
                      type="int",
                      default=1000000,
                      help="number of aveages")
    parser.add_option(
        "--local_alignment",
        action="store_true",
        default=False,
        help="do local alignment",
    )
    parser.add_option(
        "--noctf",
        action="store_true",
        default=False,
        help=
        "no ctf correction, useful for negative stained data. always ctf for cryo data",
    )
    parser.add_option(
        "--B_start",
        type="float",
        default=45.0,
        help=
        "start frequency (Angstrom) of power spectrum for B_factor estimation",
    )
    parser.add_option(
        "--Bfactor",
        type="float",
        default=-1.0,
        help=
        "User defined bactors (e.g. 25.0[A^2]). By default, the program automatically estimates B-factor. ",
    )

    (options, args) = parser.parse_args(sys.argv[1:])

    # Exactly one of the four PW modes is active.  Any --pw_adjustment value
    # that is not one of the three keywords is taken as the path of a text
    # file holding a user-provided 1-D rotationally averaged power spectrum.
    adjust_to_analytic_model = options.pw_adjustment == "analytical_model"
    no_adjustment = options.pw_adjustment == "no_adjustment"
    B_enhance = options.pw_adjustment == "bfactor"
    adjust_to_given_pw2 = not (adjust_to_analytic_model or no_adjustment
                               or B_enhance)

    # mpi
    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)

    Blockdata = {}
    Blockdata["nproc"] = nproc
    Blockdata["myid"] = myid
    Blockdata["main_node"] = 0
    # Communicator of the ranks that share memory (i.e. run on the same node).
    Blockdata["shared_comm"] = mpi.mpi_comm_split_type(
        mpi.MPI_COMM_WORLD, mpi.MPI_COMM_TYPE_SHARED, 0, mpi.MPI_INFO_NULL)
    Blockdata["myid_on_node"] = mpi.mpi_comm_rank(Blockdata["shared_comm"])
    Blockdata["no_of_processes_per_group"] = mpi.mpi_comm_size(
        Blockdata["shared_comm"])
    masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
        mpi.MPI_COMM_WORLD,
        Blockdata["main_node"] == Blockdata["myid_on_node"],
        Blockdata["myid_on_node"],
    )
    Blockdata["color"], Blockdata[
        "no_of_groups"], balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            Blockdata["main_node"],
            mpi.MPI_COMM_WORLD,
            Blockdata["myid"],
            Blockdata["shared_comm"],
            Blockdata["myid_on_node"],
            masters_from_groups_vs_everything_else_comm,
        )
    # For 3-D work the last three node groups are reserved.
    Blockdata["node_volume"] = [
        Blockdata["no_of_groups"] - 3,
        Blockdata["no_of_groups"] - 2,
        Blockdata["no_of_groups"] - 1,
    ]  # For 3D stuff take three last nodes
    # The first rank of each reserved group is its "main" CPU.
    Blockdata["nodes"] = [
        Blockdata["node_volume"][0] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][1] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][2] * Blockdata["no_of_processes_per_group"],
    ]
    # End of Blockdata
    sp_global_def.BATCH = True
    sp_global_def.MPI = True

    if adjust_to_given_pw2:
        # Only the main rank checks the file; the verdict is broadcast so that
        # all ranks fail together.
        checking_flag = 0
        if Blockdata["myid"] == Blockdata["main_node"]:
            if not os.path.exists(options.pw_adjustment):
                checking_flag = 1
        checking_flag = sp_utilities.bcast_number_to_all(
            checking_flag, Blockdata["main_node"], mpi.MPI_COMM_WORLD)

        if checking_flag == 1:
            sp_global_def.ERROR("User provided power spectrum does not exist",
                                myid=Blockdata["myid"])

    Tracker = {}
    Constants = {}
    Constants["isac_dir"] = options.isac_dir
    Constants["masterdir"] = options.output_dir
    Constants["pixel_size"] = options.pixel_size
    Constants["orgstack"] = options.stack
    Constants["radius"] = options.radius
    Constants["xrange"] = options.xr
    Constants["FH"] = options.fh
    Constants["low_pass_filter"] = options.fl
    Constants["navg"] = options.navg
    Constants["B_start"] = options.B_start
    Constants["Bfactor"] = options.Bfactor

    if adjust_to_given_pw2:
        Constants["modelpw"] = options.pw_adjustment
    Tracker["constants"] = Constants

    ####-----------------------------------------------------------
    # Create Master directory and associated subdirectories
    line = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>"
    if Tracker["constants"]["masterdir"] == Tracker["constants"]["isac_dir"]:
        masterdir = os.path.join(Tracker["constants"]["isac_dir"], "sharpen")
    else:
        masterdir = Tracker["constants"]["masterdir"]

    if Blockdata["myid"] == Blockdata["main_node"]:
        sp_global_def.sxprint(line, "Postprocessing ISAC 2D averages starts")
        if not masterdir:
            masterdir = "sharpen_" + Tracker["constants"]["isac_dir"]
            os.makedirs(masterdir)
        else:
            if os.path.exists(masterdir):
                sp_global_def.sxprint("%s already exists" % masterdir)
            else:
                os.makedirs(masterdir)
        sp_global_def.write_command(masterdir)
        subdir_path = os.path.join(masterdir, "ali2d_local_params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        subdir_path = os.path.join(masterdir, "params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        li = len(masterdir)
    else:
        li = 0
    # Broadcast the (possibly main-rank-generated) directory name: first its
    # length, then its characters.
    li = mpi.mpi_bcast(li, 1, mpi.MPI_INT, Blockdata["main_node"],
                       mpi.MPI_COMM_WORLD)[0]
    masterdir = mpi.mpi_bcast(masterdir, li, mpi.MPI_CHAR,
                              Blockdata["main_node"], mpi.MPI_COMM_WORLD)
    masterdir = b"".join(masterdir).decode('latin1')
    Tracker["constants"]["masterdir"] = masterdir
    log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
    log_main.prefix = Tracker["constants"]["masterdir"] + "/"

    # Wait until the directory created by the main rank is visible here
    # (relevant for shared file systems).
    while not os.path.exists(Tracker["constants"]["masterdir"]):
        sp_global_def.sxprint(
            "Node ",
            Blockdata["myid"],
            "  waiting...",
            Tracker["constants"]["masterdir"],
        )
        time.sleep(1)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    # Initial 2-D parameters of the ISAC run, indexed by particle id.
    if Blockdata["myid"] == Blockdata["main_node"]:
        init_dict = {}
        sp_global_def.sxprint(Tracker["constants"]["isac_dir"])
        Tracker["directory"] = os.path.join(Tracker["constants"]["isac_dir"],
                                            "2dalignment")
        core = sp_utilities.read_text_row(
            os.path.join(Tracker["directory"], "initial2Dparams.txt"))
        for im in range(len(core)):
            init_dict[im] = core[im]
        del core
    else:
        init_dict = 0
    init_dict = sp_utilities.wrap_mpi_bcast(init_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    ###
    do_ctf = not options.noctf
    if Blockdata["myid"] == Blockdata["main_node"]:
        if do_ctf:
            sp_global_def.sxprint("CTF correction is on")
        else:
            sp_global_def.sxprint("CTF correction is off")
        if options.local_alignment:
            sp_global_def.sxprint("local refinement is on")
        else:
            sp_global_def.sxprint("local refinement is off")
        if B_enhance:
            sp_global_def.sxprint("Bfactor is to be applied on averages")
        elif adjust_to_given_pw2:
            sp_global_def.sxprint(
                "PW of averages is adjusted to a given 1D PW curve")
        elif adjust_to_analytic_model:
            sp_global_def.sxprint(
                "PW of averages is adjusted to analytical model")
        else:
            sp_global_def.sxprint("PW of averages is not adjusted")
        image = sp_utilities.get_im(Tracker["constants"]["orgstack"], 0)
        Tracker["constants"]["nnxo"] = image.get_xsize()
        if Tracker["constants"]["pixel_size"] == -1.0:
            sp_global_def.sxprint(
                "Pixel size value is not provided by user. extracting it from ctf header entry of the original stack."
            )
            try:
                ctf_params = image.get_attr("ctf")
                Tracker["constants"]["pixel_size"] = ctf_params.apix
            except Exception:
                sp_global_def.ERROR(
                    "Pixel size could not be extracted from the original stack.",
                    myid=Blockdata["myid"],
                )
        ## Now fill in low-pass filter

        isac_shrink_path = os.path.join(Tracker["constants"]["isac_dir"],
                                        "README_shrink_ratio.txt")

        if not os.path.exists(isac_shrink_path):
            sp_global_def.ERROR(
                "%s does not exist in the specified ISAC run output directory"
                % (isac_shrink_path),
                myid=Blockdata["myid"],
            )

        # Close the file even if reading fails.
        with open(isac_shrink_path, "r") as isac_shrink_file:
            isac_shrink_lines = isac_shrink_file.readlines()
        # 6th line: shrink ratio (= [target particle radius]/[particle radius]) used in the ISAC run
        isac_shrink_ratio = float(isac_shrink_lines[5])
        # 7th line: particle radius at original pixel size used in the ISAC run
        isac_radius = float(isac_shrink_lines[6])
        print("Extracted parameter values")
        print("ISAC shrink ratio    : {0}".format(isac_shrink_ratio))
        print("ISAC particle radius : {0}".format(isac_radius))
        Tracker["ini_shrink"] = isac_shrink_ratio
    else:
        Tracker["ini_shrink"] = 0.0
    Tracker = sp_utilities.wrap_mpi_bcast(Tracker,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)

    # Search range is at least ceil(1/shrink) pixels at original scale.
    x_range = max(
        Tracker["constants"]["xrange"],
        int(old_div(1.0, Tracker["ini_shrink"]) + 0.99999),
    )
    a_range = y_range = x_range

    if Blockdata["myid"] == Blockdata["main_node"]:
        parameters = sp_utilities.read_text_row(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "all_parameters.txt"))
    else:
        parameters = 0
    parameters = sp_utilities.wrap_mpi_bcast(parameters,
                                             Blockdata["main_node"],
                                             communicator=mpi.MPI_COMM_WORLD)
    params_dict = {}
    list_dict = {}

    navg = min(
        Tracker["constants"]["navg"],
        EMAN2_cppwrap.EMUtil.get_image_count(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "class_averages.hdf")),
    )
    global_dict = {}  # particle id -> [average index, position in average]
    ptl_list = []
    memlist = []
    if Blockdata["myid"] == Blockdata["main_node"]:
        sp_global_def.sxprint("Number of averages computed in this run is %d" %
                              navg)
        for iavg in range(navg):
            params_of_this_average = []
            image = sp_utilities.get_im(
                os.path.join(Tracker["constants"]["isac_dir"],
                             "class_averages.hdf"),
                iavg,
            )
            members = sorted(image.get_attr("members"))
            memlist.append(members)
            for im in range(len(members)):
                abs_id = members[im]
                global_dict[abs_id] = [iavg, im]
                # Combine initial parameters with ISAC ones; ISAC shifts are
                # in shrunk pixels, hence the division by ini_shrink.
                P = sp_utilities.combine_params2(
                    init_dict[abs_id][0],
                    init_dict[abs_id][1],
                    init_dict[abs_id][2],
                    init_dict[abs_id][3],
                    parameters[abs_id][0],
                    old_div(parameters[abs_id][1], Tracker["ini_shrink"]),
                    old_div(parameters[abs_id][2], Tracker["ini_shrink"]),
                    parameters[abs_id][3],
                )
                if parameters[abs_id][3] == -1:
                    sp_global_def.sxprint(
                        "WARNING: Image #{0} is an unaccounted particle with invalid 2D alignment parameters and should not be the member of any classes. Please check the consitency of input dataset."
                        .format(abs_id)
                    )  # How to check what is wrong about mirror = -1 (Toshio 2018/01/11)
                params_of_this_average.append([P[0], P[1], P[2], P[3], 1.0])
                ptl_list.append(abs_id)
            params_dict[iavg] = params_of_this_average
            list_dict[iavg] = members
            sp_utilities.write_text_row(
                params_of_this_average,
                os.path.join(
                    Tracker["constants"]["masterdir"],
                    "params_avg",
                    "params_avg_%03d.txt" % iavg,
                ),
            )
        ptl_list.sort()
        init_params = [None for im in range(len(ptl_list))]
        for im in range(len(ptl_list)):
            init_params[im] = [ptl_list[im]] + params_dict[global_dict[
                ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
        sp_utilities.write_text_row(
            init_params,
            os.path.join(Tracker["constants"]["masterdir"],
                         "init_isac_params.txt"),
        )
    else:
        params_dict = 0
        list_dict = 0
        memlist = 0
    params_dict = sp_utilities.wrap_mpi_bcast(params_dict,
                                              Blockdata["main_node"],
                                              communicator=mpi.MPI_COMM_WORLD)
    list_dict = sp_utilities.wrap_mpi_bcast(list_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    memlist = sp_utilities.wrap_mpi_bcast(memlist,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)
    # Now computing!
    del init_dict
    tag_sharpen_avg = 1000
    ## always apply low pass filter to B_enhanced images to suppress noise in high frequencies
    enforced_to_H1 = False
    if B_enhance:
        if Tracker["constants"]["low_pass_filter"] == -1.0:
            enforced_to_H1 = True

    # distribute workload among mpi processes
    image_start, image_end = sp_applications.MPI_start_end(
        navg, Blockdata["nproc"], Blockdata["myid"])

    # cpu_dict: average index -> rank that computes it.
    if Blockdata["myid"] == Blockdata["main_node"]:
        cpu_dict = {}
        for iproc in range(Blockdata["nproc"]):
            local_image_start, local_image_end = sp_applications.MPI_start_end(
                navg, Blockdata["nproc"], iproc)
            for im in range(local_image_start, local_image_end):
                cpu_dict[im] = iproc
    else:
        cpu_dict = 0

    cpu_dict = sp_utilities.wrap_mpi_bcast(cpu_dict,
                                           Blockdata["main_node"],
                                           communicator=mpi.MPI_COMM_WORLD)

    slist = [None for im in range(navg)]
    plist_dict = {}

    if Blockdata["myid"] == Blockdata["main_node"]:
        if B_enhance:
            sp_global_def.sxprint(
                "Avg ID   B-factor  FH1(Res before ali) FH2(Res after ali)")
        else:
            sp_global_def.sxprint(
                "Avg ID   FH1(Res before ali)  FH2(Res after ali)")

    # User-provided 1-D PW; loaded at most once, on first use.
    model_pw = None
    FH_list = [[0, 0.0, 0.0] for im in range(navg)]
    for iavg in range(image_start, image_end):

        mlist = EMAN2_cppwrap.EMData.read_images(
            Tracker["constants"]["orgstack"], list_dict[iavg])

        for im in range(len(mlist)):
            sp_utilities.set_params2D(mlist[im],
                                      params_dict[iavg][im],
                                      xform="xform.align2d")

        if options.local_alignment:
            new_avg, plist, FH2 = sp_applications.refinement_2d_local(
                mlist,
                Tracker["constants"]["radius"],
                a_range,
                x_range,
                y_range,
                CTF=do_ctf,
                SNR=1.0e10,
            )
            plist_dict[iavg] = plist
            FH1 = -1.0

        else:
            new_avg, frc, plist = compute_average(
                mlist, Tracker["constants"]["radius"], do_ctf)
            FH1 = get_optimistic_res(frc)
            FH2 = -1.0

        FH_list[iavg] = [iavg, FH1, FH2]

        if B_enhance:
            new_avg, gb = apply_enhancement(
                new_avg,
                Tracker["constants"]["B_start"],
                Tracker["constants"]["pixel_size"],
                Tracker["constants"]["Bfactor"],
            )
            sp_global_def.sxprint("  %6d      %6.3f  %4.3f  %4.3f" %
                                  (iavg, gb, FH1, FH2))

        elif adjust_to_given_pw2:
            if model_pw is None:
                model_pw = sp_utilities.read_text_file(
                    Tracker["constants"]["modelpw"], -1)
                model_pw = model_pw[0]  # always on the first column
            # Pass a copy so a callee that modifies its argument cannot
            # affect later averages (the original re-read the file each time).
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         list(model_pw))
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f  " %
                                  (iavg, FH1, FH2))

        elif adjust_to_analytic_model:
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         None)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f   " %
                                  (iavg, FH1, FH2))

        elif no_adjustment:
            pass

        # NOTE(review): with --local_alignment FH1 is -1.0, yet it is still
        # used below for --fl=0.0 and for enforced_to_H1 -- confirm intent.
        if Tracker["constants"]["low_pass_filter"] != -1.0:
            if Tracker["constants"]["low_pass_filter"] == 0.0:
                low_pass_filter = FH1
            elif Tracker["constants"]["low_pass_filter"] == 1.0:
                low_pass_filter = FH2
                if not options.local_alignment:
                    low_pass_filter = FH1
            else:
                low_pass_filter = Tracker["constants"]["low_pass_filter"]
                if low_pass_filter >= 0.45:
                    low_pass_filter = 0.45
            new_avg = sp_filter.filt_tanl(new_avg, low_pass_filter, 0.02)
        else:  # No low pass filter but if enforced
            if enforced_to_H1:
                new_avg = sp_filter.filt_tanl(new_avg, FH1, 0.02)
        if B_enhance:
            new_avg = sp_fundamentals.fft(new_avg)

        new_avg.set_attr("members", list_dict[iavg])
        new_avg.set_attr("n_objects", len(list_dict[iavg]))
        slist[iavg] = new_avg
        sp_global_def.sxprint(
            time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>",
            "Refined average %7d" % iavg,
        )

    ## send to main node to write
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    # Averages are gathered one at a time, in index order, so the main rank
    # writes them to class_averages.hdf in order.  The barrier at the end of
    # each iteration keeps the point-to-point messages of consecutive
    # averages from interleaving.
    for im in range(navg):
        # avg
        if (cpu_dict[im] == Blockdata["myid"]
                and Blockdata["myid"] != Blockdata["main_node"]):
            sp_utilities.send_EMData(slist[im], Blockdata["main_node"],
                                     tag_sharpen_avg)

        elif (cpu_dict[im] == Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            slist[im].set_attr("members", memlist[im])
            slist[im].set_attr("n_objects", len(memlist[im]))
            slist[im].write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        elif (cpu_dict[im] != Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            new_avg_other_cpu = sp_utilities.recv_EMData(
                cpu_dict[im], tag_sharpen_avg)
            new_avg_other_cpu.set_attr("members", memlist[im])
            new_avg_other_cpu.set_attr("n_objects", len(memlist[im]))
            new_avg_other_cpu.write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        if options.local_alignment:
            if cpu_dict[im] == Blockdata["myid"]:
                sp_utilities.write_text_row(
                    plist_dict[im],
                    os.path.join(
                        Tracker["constants"]["masterdir"],
                        "ali2d_local_params_avg",
                        "ali2d_local_params_avg_%03d.txt" % im,
                    ),
                )

            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(plist_dict[im],
                                           Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                plist_dict[im] = dummy
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]
        else:
            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]

        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if options.local_alignment:
        if Blockdata["myid"] == Blockdata["main_node"]:
            ali3d_local_params = [None for im in range(len(ptl_list))]
            for im in range(len(ptl_list)):
                ali3d_local_params[im] = [ptl_list[im]] + plist_dict[
                    global_dict[ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
            sp_utilities.write_text_row(
                ali3d_local_params,
                os.path.join(Tracker["constants"]["masterdir"],
                             "ali2d_local_params.txt"),
            )
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))
    else:
        if Blockdata["myid"] == Blockdata["main_node"]:
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    target_xr = 3
    target_yr = 3
    if Blockdata["myid"] == 0:
        # Order the averages with sp_chains.py; its aligned-copy output
        # junk.hdf is deleted afterwards.
        cmd = "{} {} {} {} {} {} {} {} {} {}".format(
            "sp_chains.py",
            os.path.join(Tracker["constants"]["masterdir"],
                         "class_averages.hdf"),
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"),
            os.path.join(Tracker["constants"]["masterdir"],
                         "ordered_class_averages.hdf"),
            "--circular",
            "--radius=%d" % Tracker["constants"]["radius"],
            "--xr=%d" % (target_xr + 1),
            "--yr=%d" % (target_yr + 1),
            "--align",
            ">/dev/null",
        )
        junk = sp_utilities.cmdexecute(cmd)
        cmd = "{} {}".format(
            "rm -rf",
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"))
        junk = sp_utilities.cmdexecute(cmd)

    return
def rec3D_MPI_noCTF(data, symmetry, mask3D, fsc_curve, myid, main_node = 0, rstep = 1.0, odd_start=0, eve_start=1, finfo=None, index = -1, npad = 4, hparams=None):
	'''
	  Reconstruct a 3-D volume without CTF correction from a dataset kept in
	  memory.  Must be called by every rank of an MPI program (MPI_COMM_WORLD).

	  The projections are split into two halves (starting at odd_start and
	  eve_start, step 2).  Each half is reconstructed separately, the full set
	  is reconstructed as well, and the FSC between the half volumes is
	  computed with fsc_mask (mask3D, rstep, fsc_curve).

	  if index > -1, projections should have attribute group set and only
	  those whose group matches index will be used in the reconstruction
	  (this is for multireference alignment).
	  If hparams is not None, hsymVols is applied to the three volumes before
	  the FSC is computed.

	  Returns (volall, fscdat, volodd, voleve) on rank main_node.  All other
	  ranks return three blank nx^3 volumes and None in place of fscdat.
	'''
	import os
	from statistics import fsc_mask
	from utilities  import model_blank, get_image, send_EMData, recv_EMData
	from mpi        import mpi_comm_size, MPI_COMM_WORLD

	def _remove_files(*paths):
		# Delete temporary files without building a shell command; missing
		# files are ignored, like "rm -f".
		for path in paths:
			try:
				os.remove(path)
			except OSError:
				pass

	nproc = mpi_comm_size(MPI_COMM_WORLD)

	if nproc==1:
		assert main_node==0
		main_node_odd = main_node
		main_node_eve = main_node
		main_node_all = main_node
	elif nproc==2:
		main_node_odd = main_node
		main_node_eve = (main_node+1)%2
		main_node_all = main_node

		tag_voleve     = 1000
		tag_fftvol_eve = 1001
		tag_weight_eve = 1002
	else:
		#spread CPUs between different nodes to save memory
		main_node_odd = main_node
		main_node_eve = (int(main_node)+nproc-1)%int(nproc)
		main_node_all = (int(main_node)+nproc//2)%int(nproc)

		tag_voleve     = 1000
		tag_fftvol_eve = 1001
		tag_weight_eve = 1002

		tag_fftvol_odd = 1003
		tag_weight_odd = 1004
		tag_volall     = 1005

	nx = data[0].get_xsize()

	# Accumulate the Fourier volume and weights of each half on its main node;
	# the returned values are the names of the files holding them.
	fftvol_odd_file,weight_odd_file = prepare_recons(data, symmetry, myid, main_node_odd, odd_start, 2, index, finfo, npad)
	fftvol_eve_file,weight_eve_file = prepare_recons(data, symmetry, myid, main_node_eve, eve_start, 2, index, finfo, npad)

	if nproc == 1:
		fftvol = get_image( fftvol_odd_file )
		weight = get_image( weight_odd_file )
		volodd = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)

		fftvol = get_image( fftvol_eve_file )
		weight = get_image( weight_eve_file )
		voleve = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)

		# Full reconstruction from the summed halves.
		fftvol = get_image( fftvol_odd_file )
		Util.add_img( fftvol, get_image(fftvol_eve_file) )

		weight = get_image( weight_odd_file )
		Util.add_img( weight, get_image(weight_eve_file) )

		volall = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)

		# if helical, find & apply symmetry to volume
		if hparams is not None:
			volodd,voleve,volall = hsymVols(volodd,voleve,volall,hparams)
		fscdat = fsc_mask( volodd, voleve, mask3D, rstep, fsc_curve)

		_remove_files(fftvol_odd_file, weight_odd_file)
		_remove_files(fftvol_eve_file, weight_eve_file)
		return volall,fscdat,volodd,voleve

	if nproc == 2:
		if myid == main_node_odd:
			fftvol = get_image( fftvol_odd_file )
			weight = get_image( weight_odd_file )
			volodd = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)
			voleve = recv_EMData(main_node_eve, tag_voleve)
		else:
			assert myid == main_node_eve
			fftvol = get_image( fftvol_eve_file )
			weight = get_image( weight_eve_file )
			voleve = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)
			send_EMData(voleve, main_node_odd, tag_voleve)

		if myid == main_node_odd:
			fftvol = get_image( fftvol_odd_file )
			fftvol_tmp = recv_EMData( main_node_eve, tag_fftvol_eve )
			Util.add_img( fftvol, fftvol_tmp )
			fftvol_tmp = None

			weight = get_image( weight_odd_file )
			weight_tmp = recv_EMData( main_node_eve, tag_weight_eve )
			Util.add_img( weight, weight_tmp )
			weight_tmp = None
			volall = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)

			# if helical, find & apply symmetry to volume
			if hparams is not None:
				volodd,voleve,volall = hsymVols(volodd,voleve,volall,hparams)
			fscdat = fsc_mask( volodd, voleve, mask3D, rstep, fsc_curve)

			_remove_files(fftvol_odd_file, weight_odd_file)
			_remove_files(fftvol_eve_file, weight_eve_file)
			return volall,fscdat,volodd,voleve
		else:
			assert myid == main_node_eve
			fftvol = get_image( fftvol_eve_file )
			send_EMData(fftvol, main_node_odd, tag_fftvol_eve )

			weight = get_image( weight_eve_file )
			send_EMData(weight, main_node_odd, tag_weight_eve )
			_remove_files(fftvol_eve_file, weight_eve_file)
			return model_blank(nx,nx,nx), None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)

	# cases from all other number of processors situations:
	#   odd  -> eve : odd Fourier volume;   eve -> all : odd+eve Fourier volume
	#   odd  -> all : odd weights;          eve -> all : even weights
	#   eve  -> odd : even volume;          all -> odd : full volume
	if myid == main_node_odd:
		fftvol = get_image( fftvol_odd_file )
		send_EMData(fftvol, main_node_eve, tag_fftvol_odd )

		if not(finfo is None):
			finfo.write("fftvol odd sent\n")
			finfo.flush()

		weight = get_image( weight_odd_file )
		send_EMData(weight, main_node_all, tag_weight_odd )

		if not(finfo is None):
			finfo.write("weight odd sent\n")
			finfo.flush()

		volodd = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)
		del fftvol, weight
		voleve = recv_EMData(main_node_eve, tag_voleve)
		volall = recv_EMData(main_node_all, tag_volall)

		# if helical, find & apply symmetry to volume
		if hparams is not None:
			volodd,voleve,volall = hsymVols(volodd,voleve,volall,hparams)
		fscdat = fsc_mask( volodd, voleve, mask3D, rstep, fsc_curve)

		_remove_files(fftvol_odd_file, weight_odd_file)
		return volall,fscdat,volodd,voleve

	if myid == main_node_eve:
		ftmp = recv_EMData(main_node_odd, tag_fftvol_odd)
		fftvol = get_image( fftvol_eve_file )
		# ftmp becomes odd+eve; fftvol keeps the even half only.
		Util.add_img( ftmp, fftvol )
		send_EMData(ftmp, main_node_all, tag_fftvol_eve )
		del ftmp

		weight = get_image( weight_eve_file )
		send_EMData(weight, main_node_all, tag_weight_eve )

		voleve = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)
		send_EMData(voleve, main_node_odd, tag_voleve)
		_remove_files(fftvol_eve_file, weight_eve_file)

		return model_blank(nx,nx,nx), None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)


	if myid == main_node_all:
		# Summed (odd+eve) Fourier volume arrives from the even main node.
		fftvol = recv_EMData(main_node_eve, tag_fftvol_eve)
		if not(finfo is None):
			finfo.write( "fftvol odd received\n" )
			finfo.flush()

		weight = recv_EMData(main_node_odd, tag_weight_odd)
		weight_tmp = recv_EMData(main_node_eve, tag_weight_eve)
		Util.add_img( weight, weight_tmp )
		weight_tmp = None

		volall = recons_from_fftvol(nx, fftvol, weight, symmetry, npad)
		send_EMData(volall, main_node_odd, tag_volall)

		return model_blank(nx,nx,nx),None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)


	return model_blank(nx,nx,nx), None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)
Beispiel #26
0
def main():
    """Window particle images out of micrographs (sxwindow.py entry point).

    Positional arguments (3 or 4):
        [input_micrograph_list_file]  input_micrograph_pattern
        input_coordinates_pattern     output_directory

    Processing outline:
      1. Parse/validate arguments and options. Errors are collected in
         ``error_status`` and reported through
         ``if_error_then_all_processes_exit_program``.
      2. On the main process, build the micrograph list (from the ``.txt``
         list file, or by globbing the micrograph pattern) and the coordinates
         file list. Extract each micrograph's serial ID from the position of
         the ``*`` wild card. With ``--import_ctf``, load the sxcter table
         into a dict keyed by micrograph base name:
           - micrographs whose relative defocus error exceeds
             ``--defocus_error`` percent are dropped;
           - astigmatism is zeroed when its angular std dev exceeds
             ``--astigmatism_error``.
      3. Keep only micrographs that have a coordinates file (and a CTER entry
         when CTF is imported), then split that list across MPI processes.
      4. For every assigned micrograph:
           - optionally low-pass filter at the CTF limit (``--limit_ctf``);
           - apply ``filt_gaussh`` and resample by ``--resample_ratio``;
           - optionally invert contrast;
           - window each in-bounds coordinate into a ``box_size`` x
             ``box_size`` image;
           - normalize it with the mean/SD under a circular mask;
           - set header attributes;
           - write it to the local bdb stack
             ``bdb:<out_dir>#<mic_baseroot>_ptcls``. Under MPI, ``out_dir``
             gets a per-rank ``%03d`` sub-directory.
      5. Print global statistics (summed across processes under MPI). The
         main process then appends every local stack into
         ``bdb:<output_directory>/data`` using ``e2bdb.py``.

    Runs under MPI when ``OMPI_COMM_WORLD_SIZE`` is in the environment,
    otherwise as a single process. All MPI-only calls are guarded by
    ``RUNNING_UNDER_MPI``. Ends with ``sys.exit(0)``.
    """
    progname = os.path.basename(sys.argv[0])
    usage = progname + """  input_micrograph_list_file  input_micrograph_pattern  input_coordinates_pattern  output_directory  --coordinates_format  --box_size=box_size  --invert  --import_ctf=ctf_file  --limit_ctf  --resample_ratio=resample_ratio  --defocus_error=defocus_error  --astigmatism_error=astigmatism_error
	
Window particles from micrographs in input list file. The coordinates of the particles should be given as input.
Please specify name pattern of input micrographs and coordinates files with a wild card (*). Use the wild card to indicate the place of micrograph ID (e.g. serial number, time stamp, and etc). 
The name patterns must be enclosed by single quotes (') or double quotes ("). (Note: sxgui.py automatically adds single quotes (')). 
BDB files can not be selected as input micrographs.
	
	sxwindow.py  mic_list.txt  ./mic*.hdf  info/mic*_info.json  particles  --coordinates_format=eman2  --box_size=64  --invert  --import_ctf=outdir_cter/partres/partres.txt
	
If micrograph list file name is not provided, all files matched with the micrograph name pattern will be processed.
	
	sxwindow.py  ./mic*.hdf  info/mic*_info.json  particles  --coordinates_format=eman2  --box_size=64  --invert  --import_ctf=outdir_cter/partres/partres.txt
	
"""
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option(
        "--coordinates_format",
        type="string",
        default="eman1",
        help=
        "format of input coordinates files: 'sparx', 'eman1', 'eman2', or 'spider'. the coordinates of sparx, eman2, and spider format is particle center. the coordinates of eman1 format is particle box conner associated with the original box size. (default eman1)"
    )
    parser.add_option(
        "--box_size",
        type="int",
        default=256,
        help=
        "x and y dimension of square area to be windowed (in pixels): pixel size after resampling is assumed when resample_ratio < 1.0 (default 256)"
    )
    parser.add_option(
        "--invert",
        action="store_true",
        default=False,
        help="invert image contrast: recommended for cryo data (default False)"
    )
    parser.add_option(
        "--import_ctf",
        type="string",
        default="",
        help="file name of sxcter output: normally partres.txt (default none)")
    parser.add_option(
        "--limit_ctf",
        action="store_true",
        default=False,
        help=
        "filter micrographs based on the CTF limit: this option requires --import_ctf. (default False)"
    )
    parser.add_option(
        "--resample_ratio",
        type="float",
        default=1.0,
        help=
        "ratio of new to old image size (or old to new pixel size) for resampling: Valid range is 0.0 < resample_ratio <= 1.0. (default 1.0)"
    )
    parser.add_option(
        "--defocus_error",
        type="float",
        default=1000000.0,
        help=
        "defocus errror limit: exclude micrographs whose relative defocus error as estimated by sxcter is larger than defocus_error percent. the error is computed as (std dev defocus)/defocus*100%. (default 1000000.0)"
    )
    parser.add_option(
        "--astigmatism_error",
        type="float",
        default=360.0,
        help=
        "astigmatism error limit: Set to zero astigmatism for micrographs whose astigmatism angular error as estimated by sxcter is larger than astigmatism_error degrees. (default 360.0)"
    )

    ### detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ

    main_node = 0

    if RUNNING_UNDER_MPI:
        from mpi import mpi_init
        from mpi import MPI_COMM_WORLD, mpi_comm_rank, mpi_comm_size, mpi_barrier, mpi_reduce, MPI_INT, MPI_SUM

        mpi_init(0, [])
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        number_of_processes = mpi_comm_size(MPI_COMM_WORLD)
    else:
        number_of_processes = 1
        myid = 0

    (options, args) = parser.parse_args(sys.argv[1:])

    # ------------------------------------------------------------------
    # Validate positional arguments.
    # Not a real loop: "break" jumps to the common error report below.
    # ------------------------------------------------------------------
    mic_list_file_path = None
    mic_pattern = None
    coords_pattern = None
    error_status = None
    while True:
        if len(args) < 3 or len(args) > 4:
            error_status = (
                "Please check usage for number of arguments.\n Usage: " +
                usage + "\n" + "Please run %s -h for help." % (progname),
                getframeinfo(currentframe()))
            break

        if len(args) == 3:
            mic_pattern = args[0]
            coords_pattern = args[1]
            out_dir = args[2]
        else:  # assert(len(args) == 4)
            mic_list_file_path = args[0]
            mic_pattern = args[1]
            coords_pattern = args[2]
            out_dir = args[3]

        if mic_list_file_path != None:
            if os.path.splitext(mic_list_file_path)[1] != ".txt":
                error_status = (
                    "Extension of input micrograph list file must be \".txt\". Please check input_micrograph_list_file argument. Run %s -h for help."
                    % (progname), getframeinfo(currentframe()))
                break

        # Compare the 4-character slice with the full 4-character "bdb:"
        # prefix so that BDB inputs are actually detected.
        if mic_pattern[:len("bdb:")].lower() == "bdb:":
            error_status = (
                "BDB file can not be selected as input micrographs. Please convert the format, and restart the program. Run %s -h for help."
                % (progname), getframeinfo(currentframe()))
            break

        if mic_pattern.find("*") == -1:
            error_status = (
                "Input micrograph file name pattern must contain wild card (*). Please check input_micrograph_pattern argument. Run %s -h for help."
                % (progname), getframeinfo(currentframe()))
            break

        if coords_pattern.find("*") == -1:
            error_status = (
                "Input coordinates file name pattern must contain wild card (*). Please check input_coordinates_pattern argument. Run %s -h for help."
                % (progname), getframeinfo(currentframe()))
            break

        if myid == main_node:
            if os.path.exists(out_dir):
                error_status = (
                    "Output directory exists. Please change the name and restart the program.",
                    getframeinfo(currentframe()))
                break

        break
    if_error_then_all_processes_exit_program(error_status)

    # Check invalid conditions of options
    check_options(options, progname)

    # ------------------------------------------------------------------
    # Build micrograph and coordinates file lists on the main process,
    # then broadcast them.
    # ------------------------------------------------------------------
    mic_name_list = None
    error_status = None
    if myid == main_node:
        if mic_list_file_path != None:
            print("Loading micrograph list from %s file ..." %
                  (mic_list_file_path))
            mic_name_list = read_text_file(mic_list_file_path)
            # Only inspect the first entry when one exists. An empty list is
            # reported as an error right below.
            if len(mic_name_list) > 0:
                print("Directory of first micrograph entry is %s" %
                      (os.path.dirname(mic_name_list[0])))
        else:  # assert (mic_list_file_path == None)
            print("Generating micrograph list in %s directory..." %
                  (os.path.dirname(mic_pattern)))
            mic_name_list = glob.glob(mic_pattern)
        if len(mic_name_list) == 0:
            error_status = (
                "No micrograph file is found. Please check input_micrograph_pattern and/or input_micrograph_list_file argument. Run %s -h for help."
                % (progname), getframeinfo(currentframe()))
        else:
            print("Found %d microgarphs" % len(mic_name_list))

    if_error_then_all_processes_exit_program(error_status)
    if RUNNING_UNDER_MPI:
        mic_name_list = wrap_mpi_bcast(mic_name_list, main_node)

    coords_name_list = None
    error_status = None
    if myid == main_node:
        coords_name_list = glob.glob(coords_pattern)
        if len(coords_name_list) == 0:
            error_status = (
                "No coordinates file is found. Please check input_coordinates_pattern argument. Run %s -h for help."
                % (progname), getframeinfo(currentframe()))
    if_error_then_all_processes_exit_program(error_status)
    if RUNNING_UNDER_MPI:
        coords_name_list = wrap_mpi_bcast(coords_name_list, main_node)

    # ------------------------------------------------------------------
    # Column indices of the sxcter output table.
    # All processes must have access to these indices.
    # ------------------------------------------------------------------
    if options.import_ctf:
        i_enum = -1
        i_enum += 1
        idx_cter_def = i_enum  # defocus [um]; index must be same as ctf object format
        i_enum += 1
        idx_cter_cs = i_enum  # Cs [mm]; index must be same as ctf object format
        i_enum += 1
        idx_cter_vol = i_enum  # voltage[kV]; index must be same as ctf object format
        i_enum += 1
        idx_cter_apix = i_enum  # pixel size [A]; index must be same as ctf object format
        i_enum += 1
        idx_cter_bfactor = i_enum  # B-factor [A^2]; index must be same as ctf object format
        i_enum += 1
        idx_cter_ac = i_enum  # amplitude contrast [%]; index must be same as ctf object format
        i_enum += 1
        idx_cter_astig_amp = i_enum  # astigmatism amplitude [um]; index must be same as ctf object format
        i_enum += 1
        idx_cter_astig_ang = i_enum  # astigmatism angle [degree]; index must be same as ctf object format
        i_enum += 1
        idx_cter_sd_def = i_enum  # std dev of defocus [um]
        i_enum += 1
        idx_cter_sd_astig_amp = i_enum  # std dev of ast amp [A]
        i_enum += 1
        idx_cter_sd_astig_ang = i_enum  # std dev of ast angle [degree]
        i_enum += 1
        idx_cter_cv_def = i_enum  # coefficient of variation of defocus [%]
        i_enum += 1
        idx_cter_cv_astig_amp = i_enum  # coefficient of variation of ast amp [%]
        i_enum += 1
        idx_cter_spectra_diff = i_enum  # average of differences between with- and without-astig. experimental 1D spectra at extrema
        i_enum += 1
        idx_cter_error_def = i_enum  # frequency at which signal drops by 50% due to estimated error of defocus alone [1/A]
        i_enum += 1
        idx_cter_error_astig = i_enum  # frequency at which signal drops by 50% due to estimated error of defocus and astigmatism [1/A]
        i_enum += 1
        idx_cter_error_ctf = i_enum  # limit frequency by CTF error [1/A]
        i_enum += 1
        idx_cter_mic_name = i_enum  # micrograph name
        i_enum += 1
        n_idx_cter = i_enum

    # Prepare loop variables
    mic_basename_pattern = os.path.basename(
        mic_pattern)  # file pattern without path
    mic_baseroot_pattern = os.path.splitext(mic_basename_pattern)[
        0]  # file pattern without path and extension
    coords_format = options.coordinates_format.lower()
    box_size = options.box_size
    box_half = box_size // 2
    mask2d = model_circle(
        box_size // 2, box_size, box_size
    )  # Create circular 2D mask to Util.infomask of particle images
    resample_ratio = options.resample_ratio

    n_mic_process = 0
    n_mic_reject_no_coords = 0
    n_mic_reject_no_cter_entry = 0
    n_global_coords_detect = 0
    n_global_coords_process = 0
    n_global_coords_reject_out_of_boundary = 0

    # CTF-limit cut-off frequencies of filtered micrographs (--limit_ctf).
    # Always defined so the later summary code never hits an unbound name.
    cutoff_histogram = []

    serial_id_list = []
    error_status = None
    ## not a real while, an if with the opportunity to use break when errors need to be reported
    while myid == main_node:
        #
        # NOTE: 2016/05/24 Toshio Moriya
        # The path in mic_pattern and in mic_name_list entries is ignored when
        # creating serial IDs. Only the basename (file name) must match.
        #
        # Split the basename pattern at "*" into prefix and suffix. The serial
        # ID is whatever lies between them.
        mic_basename_tokens = mic_basename_pattern.split('*')
        serial_id_head_index = len(mic_basename_tokens[0])
        for mic_name in mic_name_list:
            mic_basename = os.path.basename(mic_name)
            # rindex locates the trailing suffix. With an empty suffix
            # (pattern ending in "*") it returns len(mic_basename), whereas
            # index("") would return 0 and yield an empty serial ID.
            serial_id_tail_index = mic_basename.rindex(mic_basename_tokens[1])
            serial_id = mic_basename[serial_id_head_index:serial_id_tail_index]
            serial_id_list.append(serial_id)
        del mic_name_list  # Do not need this anymore

        # Load CTFs if necessary
        if options.import_ctf:

            ctf_list = read_text_row(options.import_ctf)

            if len(ctf_list) == 0:
                error_status = (
                    "No CTF entry is found in %s. Please check --import_ctf option. Run %s -h for help."
                    % (options.import_ctf, progname),
                    getframeinfo(currentframe()))
                break

            if (len(ctf_list[0]) != n_idx_cter):
                error_status = (
                    "Number of columns (%d) must be %d in %s. The format might be old. Please run sxcter.py again."
                    % (len(ctf_list[0]), n_idx_cter, options.import_ctf),
                    getframeinfo(currentframe()))
                break

            ctf_dict = {}
            n_reject_defocus_error = 0
            # [relative defocus error limit (fraction), astigmatism angle error limit (degrees)]
            ctf_error_limit = [
                options.defocus_error / 100.0, options.astigmatism_error
            ]
            for ctf_params in ctf_list:
                assert (len(ctf_params) == n_idx_cter)
                # mic_baseroot is name of micrograph minus the path and extension
                mic_baseroot = os.path.splitext(
                    os.path.basename(ctf_params[idx_cter_mic_name]))[0]
                if (ctf_params[idx_cter_sd_def] / ctf_params[idx_cter_def] >
                        ctf_error_limit[0]):
                    print(
                        "Defocus error %f exceeds the threshold. Micrograph %s is rejected."
                        % (ctf_params[idx_cter_sd_def] /
                           ctf_params[idx_cter_def], mic_baseroot))
                    n_reject_defocus_error += 1
                else:
                    # Unreliable astigmatism estimate: treat as no astigmatism.
                    if (ctf_params[idx_cter_sd_astig_ang] >
                            ctf_error_limit[1]):
                        ctf_params[idx_cter_astig_amp] = 0.0
                        ctf_params[idx_cter_astig_ang] = 0.0
                    ctf_dict[mic_baseroot] = ctf_params
            del ctf_list  # Do not need this anymore

        break

    if_error_then_all_processes_exit_program(error_status)

    # ------------------------------------------------------------------
    # Keep only micrographs with a coordinates file (and a CTER entry).
    # ------------------------------------------------------------------
    restricted_serial_id_list = []
    if myid == main_node:
        # Set built once for O(1) membership tests inside the loop.
        coords_name_set = set(coords_name_list)
        for serial_id in serial_id_list:
            # mic_baseroot is name of micrograph minus the path and extension
            mic_baseroot = mic_baseroot_pattern.replace("*", serial_id)
            coords_name = coords_pattern.replace("*", serial_id)

            if coords_name not in coords_name_set:
                print("    Cannot read %s. Skipping %s ..." %
                      (coords_name, mic_baseroot))
                n_mic_reject_no_coords += 1
                continue

            if options.import_ctf and mic_baseroot not in ctf_dict:
                print(
                    "    Is not listed in CTER results. Skipping %s ..." %
                    (mic_baseroot))
                n_mic_reject_no_cter_entry += 1
                continue

            n_mic_process += 1
            restricted_serial_id_list.append(serial_id)

    if myid != main_node:
        if options.import_ctf:
            ctf_dict = None

    # Only the main process holds the populated list at this point. Other
    # ranks have an empty list; presumably the error helper acts on the
    # main process's status (as it does for the main-only checks above).
    error_status = None
    if len(restricted_serial_id_list) < number_of_processes:
        error_status = (
            'Number of processes (%d) supplied by --np in mpirun cannot be greater than %d (number of micrographs that satisfy all criteria to be processed) '
            % (number_of_processes, len(restricted_serial_id_list)),
            getframeinfo(currentframe()))
    if_error_then_all_processes_exit_program(error_status)

    ## keep a copy of the original output directory where the final bdb will be created
    original_out_dir = out_dir
    # Full (unsliced) list of micrographs. It is needed for the summary and
    # the final e2bdb step in both single-process and MPI runs. Under MPI it
    # is replaced by the broadcast copy below.
    restricted_serial_id_list_not_sliced = restricted_serial_id_list
    if RUNNING_UNDER_MPI:
        mpi_barrier(MPI_COMM_WORLD)
        restricted_serial_id_list = wrap_mpi_bcast(restricted_serial_id_list,
                                                   main_node)
        mic_start, mic_end = MPI_start_end(len(restricted_serial_id_list),
                                           number_of_processes, myid)
        restricted_serial_id_list_not_sliced = restricted_serial_id_list
        restricted_serial_id_list = restricted_serial_id_list[
            mic_start:mic_end]

        if options.import_ctf:
            ctf_dict = wrap_mpi_bcast(ctf_dict, main_node)

        # generate subdirectories of out_dir, one for each process
        out_dir = os.path.join(out_dir, "%03d" % myid)

    if myid == main_node:
        print(
            "Micrographs processed by main process (including percent complete):"
        )

    len_processed_by_main_node_divided_by_100 = len(
        restricted_serial_id_list) / 100.0

    # ------------------------------------------------------------------
    # Main (parallel) execution: window particles of assigned micrographs.
    # ------------------------------------------------------------------
    for my_idx, serial_id in enumerate(restricted_serial_id_list):
        mic_baseroot = mic_baseroot_pattern.replace("*", serial_id)
        mic_name = mic_pattern.replace("*", serial_id)
        coords_name = coords_pattern.replace("*", serial_id)

        if myid == main_node:
            print(
                mic_name, " ---> % 2.2f%%" %
                (my_idx / len_processed_by_main_node_divided_by_100))
        mic_img = get_im(mic_name)

        # Read coordinates according to the specified format and
        # make the coordinates the center of particle image
        if coords_format == "sparx":
            coords_list = read_text_row(coords_name)
        elif coords_format == "eman1":
            # eman1 stores the box corner plus box size; shift to the center.
            coords_list = read_text_row(coords_name)
            for i in xrange(len(coords_list)):
                coords_list[i] = [(coords_list[i][0] + coords_list[i][2] // 2),
                                  (coords_list[i][1] + coords_list[i][3] // 2)]
        elif coords_format == "eman2":
            coords_list = js_open_dict(coords_name)["boxes"]
            for i in xrange(len(coords_list)):
                coords_list[i] = [coords_list[i][0], coords_list[i][1]]
        elif coords_format == "spider":
            coords_list = read_text_row(coords_name)
            for i in xrange(len(coords_list)):
                coords_list[i] = [coords_list[i][2], coords_list[i][3]]

        # Calculate the new pixel size
        if options.import_ctf:
            ctf_params = ctf_dict[mic_baseroot]
            pixel_size_origin = ctf_params[idx_cter_apix]

            if resample_ratio < 1.0:
                new_pixel_size = pixel_size_origin / resample_ratio
                print(
                    "Resample micrograph to pixel size %6.4f and window segments from resampled micrograph."
                    % new_pixel_size)
            else:
                new_pixel_size = pixel_size_origin

            # Set ctf along with new pixel size in resampled micrograph
            ctf_params[idx_cter_apix] = new_pixel_size
        else:
            if resample_ratio < 1.0:
                print(
                    "Resample micrograph with ratio %6.4f and window segments from resampled micrograph."
                    % resample_ratio)

        # Apply filters to micrograph
        fftip(mic_img)
        if options.limit_ctf:
            # Requires --import_ctf (ctf_params / new_pixel_size are set above).
            # Cut off frequency components higher than CTF limit
            q1, q2 = ctflimit(box_size, ctf_params[idx_cter_def],
                              ctf_params[idx_cter_cs],
                              ctf_params[idx_cter_vol], new_pixel_size)

            # This is absolute frequency of CTF limit in scale of original micrograph
            if resample_ratio < 1.0:
                # q1 = (pixel_size_origin / new_pixel_size) * q1/float(box_size)
                q1 = resample_ratio * q1 / float(box_size)
            else:
                # pixel_size_origin == new_pixel_size
                q1 = q1 / float(box_size)

            if q1 < 0.5:
                mic_img = filt_tanl(mic_img, q1, 0.01)
                cutoff_histogram.append(q1)

        # Cut off frequency components lower than the box size can express
        mic_img = fft(filt_gaussh(mic_img, resample_ratio / box_size))

        # Resample micrograph, map coordinates, and window segments from the
        # resampled micrograph using new coordinates. After resampling, the
        # pixel size is pixel_size/resample_ratio = new_pixel_size.
        # NOTE: 2015/04/13 Toshio Moriya
        # resample() efficiently takes care of the case resample_ratio = 1.0 but
        # it does not set apix_*. Even though it sets apix_* when resample_ratio < 1.0 ...
        mic_img = resample(mic_img, resample_ratio)

        if options.invert:
            mic_stats = Util.infomask(
                mic_img, None, True)  # mic_stat[0:mean, 1:SD, 2:min, 3:max]
            # x -> 2*mean - x: flips contrast while keeping the mean unchanged.
            Util.mul_scalar(mic_img, -1.0)
            mic_img += 2 * mic_stats[0]

        if options.import_ctf:
            from utilities import generate_ctf
            ctf_obj = generate_ctf(
                ctf_params
            )  # indexes 0 to 7 (idx_cter_def to idx_cter_astig_ang) must be same in cter format & ctf object format.

        # Prepare loop variables
        nx = mic_img.get_xsize()
        ny = mic_img.get_ysize()
        x0 = nx // 2
        y0 = ny // 2

        n_coords_reject_out_of_boundary = 0
        local_stack_name = "bdb:%s#" % out_dir + mic_baseroot + '_ptcls'
        local_particle_id = 0  # can be different from coordinates_id
        # Loop over coordinates
        for coords_id in xrange(len(coords_list)):

            x = int(coords_list[coords_id][0])
            y = int(coords_list[coords_id][1])

            if resample_ratio < 1.0:
                x = int(x * resample_ratio)
                y = int(y * resample_ratio)

            if ((0 <= x - box_half) and (x + box_half <= nx)
                    and (0 <= y - box_half) and (y + box_half <= ny)):
                # Util.window takes an offset relative to the image center.
                particle_img = Util.window(mic_img, box_size, box_size, 1,
                                           x - x0, y - y0)
            else:
                print(
                    "In %s, coordinates ID = %04d (x = %4d, y = %4d, box_size = %4d) is out of micrograph bound, skipping ..."
                    % (mic_baseroot, coords_id, x, y, box_size))
                n_coords_reject_out_of_boundary += 1
                continue

            particle_img = ramp(particle_img)
            particle_stats = Util.infomask(
                particle_img, mask2d,
                False)  # particle_stats[0:mean, 1:SD, 2:min, 3:max]
            particle_img -= particle_stats[0]
            particle_img /= particle_stats[1]

            # NOTE: 2015/04/09 Toshio Moriya
            # ptcl_source_image might be redundant information ...
            # Consider re-organizing header entries...
            particle_img.set_attr("ptcl_source_image", mic_name)
            particle_img.set_attr("ptcl_source_coord_id", coords_id)
            particle_img.set_attr("ptcl_source_coord", [
                int(coords_list[coords_id][0]),
                int(coords_list[coords_id][1])
            ])
            particle_img.set_attr("resample_ratio", resample_ratio)

            # NOTE: 2015/04/13 Toshio Moriya
            # apix_* attributes are updated by resample() only when resample_ratio != 1.0
            # Let's make sure header info is consistent by setting apix_* = 1.0
            # regardless of options, so it is not passed down the processing line
            particle_img.set_attr("apix_x", 1.0)
            particle_img.set_attr("apix_y", 1.0)
            particle_img.set_attr("apix_z", 1.0)
            if options.import_ctf:
                particle_img.set_attr("ctf", ctf_obj)
                particle_img.set_attr("ctf_applied", 0)
                particle_img.set_attr("pixel_size_origin", pixel_size_origin)
            # NOTE: 2015/04/13 Toshio Moriya
            # Pawel Comment: Micrograph is not supposed to have CTF header info.
            # So, let's assume it does not exist & ignore its presence.

            particle_img.write_image(local_stack_name, local_particle_id)
            local_particle_id += 1

        n_global_coords_detect += len(coords_list)
        n_global_coords_process += local_particle_id
        n_global_coords_reject_out_of_boundary += n_coords_reject_out_of_boundary

        # Release the data base of local stack from this process
        # so that the subprocess can access to the data base
        db_close_dict(local_stack_name)

    # ------------------------------------------------------------------
    # Summaries.
    # ------------------------------------------------------------------
    if RUNNING_UNDER_MPI:
        if options.import_ctf:
            if options.limit_ctf:
                cutoff_histogram = wrap_mpi_gatherv(cutoff_histogram,
                                                    main_node)

    if myid == main_node:
        if options.limit_ctf:
            # Print out the summary of CTF-limit filtering
            print(" ")
            print("Global summary of CTF-limit filtering (--limit_ctf) ...")
            print("Percentage of filtered micrographs: %8.2f\n" %
                  (len(cutoff_histogram) * 100.0 /
                   len(restricted_serial_id_list_not_sliced)))

            n_bins = 10
            if len(cutoff_histogram) >= n_bins:
                from statistics import hist_list
                cutoff_region, cutoff_counts = hist_list(
                    cutoff_histogram, n_bins)
                print("      Histogram of cut-off frequency")
                print("      cut-off       counts")
                for bin_id in xrange(n_bins):
                    print(" %14.7f     %7d" %
                          (cutoff_region[bin_id], cutoff_counts[bin_id]))
            else:
                print(
                    "The number of filtered micrographs (%d) is less than the number of bins (%d). No histogram is produced."
                    % (len(cutoff_histogram), n_bins))

    # Sum per-process counters on the main process. The MPI names are only
    # imported under MPI; in a single-process run the local counts are
    # already global.
    if RUNNING_UNDER_MPI:
        n_mic_process = mpi_reduce(n_mic_process, 1, MPI_INT, MPI_SUM,
                                   main_node, MPI_COMM_WORLD)
        n_mic_reject_no_coords = mpi_reduce(n_mic_reject_no_coords, 1,
                                            MPI_INT, MPI_SUM, main_node,
                                            MPI_COMM_WORLD)
        n_mic_reject_no_cter_entry = mpi_reduce(n_mic_reject_no_cter_entry, 1,
                                                MPI_INT, MPI_SUM, main_node,
                                                MPI_COMM_WORLD)
        n_global_coords_detect = mpi_reduce(n_global_coords_detect, 1,
                                            MPI_INT, MPI_SUM, main_node,
                                            MPI_COMM_WORLD)
        n_global_coords_process = mpi_reduce(n_global_coords_process, 1,
                                             MPI_INT, MPI_SUM, main_node,
                                             MPI_COMM_WORLD)
        n_global_coords_reject_out_of_boundary = mpi_reduce(
            n_global_coords_reject_out_of_boundary, 1, MPI_INT, MPI_SUM,
            main_node, MPI_COMM_WORLD)

    # Print out the summary of all micrographs
    if main_node == myid:
        print(" ")
        print("Global summary of micrographs ...")
        print("Detected                        : %6d" %
              (len(restricted_serial_id_list_not_sliced)))
        print("Processed                       : %6d" % (n_mic_process))
        print("Rejected by no coordinates file : %6d" %
              (n_mic_reject_no_coords))
        print("Rejected by no CTER entry       : %6d" %
              (n_mic_reject_no_cter_entry))
        print(" ")
        print("Global summary of coordinates ...")
        print("Detected                        : %6d" %
              (n_global_coords_detect))
        print("Processed                       : %6d" %
              (n_global_coords_process))
        print("Rejected by out of boundary     : %6d" %
              (n_global_coords_reject_out_of_boundary))

    # All local stacks must be written and closed before merging them.
    if RUNNING_UNDER_MPI:
        mpi_barrier(MPI_COMM_WORLD)

    # ------------------------------------------------------------------
    # Merge all local stacks into bdb:<original_out_dir>/data.
    # ------------------------------------------------------------------
    if main_node == myid:

        import time
        time.sleep(1)
        print("\n Creating bdb:%s/data\n" % original_out_dir)
        for proc_i in range(number_of_processes):
            # Recompute the slice that process proc_i handled, so its
            # sub-directory can be addressed.
            mic_start, mic_end = MPI_start_end(
                len(restricted_serial_id_list_not_sliced), number_of_processes,
                proc_i)
            for serial_id in restricted_serial_id_list_not_sliced[
                    mic_start:mic_end]:
                e2bdb_command = "e2bdb.py "
                mic_baseroot = mic_baseroot_pattern.replace("*", serial_id)
                if RUNNING_UNDER_MPI:
                    e2bdb_command += "bdb:" + os.path.join(
                        original_out_dir,
                        "%03d/" % proc_i) + mic_baseroot + "_ptcls "
                else:
                    e2bdb_command += "bdb:" + os.path.join(
                        original_out_dir, mic_baseroot + "_ptcls ")

                e2bdb_command += " --appendvstack=bdb:%s/data  1>/dev/null" % original_out_dir
                cmdexecute(e2bdb_command, printing_on_success=False)

        print("Done!\n")

    if RUNNING_UNDER_MPI:
        mpi_barrier(MPI_COMM_WORLD)
        from mpi import mpi_finalize
        mpi_finalize()

    sys.stdout.flush()
    sys.exit(0)
Beispiel #27
0
def main():
	"""Window particle images from micrographs and collect them into bdb:<output_directory>/data.

	Command line (see `usage` below for the full description):
	  [input_micrograph_list_file] input_micrograph_pattern input_coordinates_pattern output_directory [options]

	Outline:
	  1. Parse and validate arguments/options; all processes exit together on error.
	  2. On the main process, derive micrograph serial IDs from the micrograph name pattern,
	     optionally load sxcter CTF parameters (--import_ctf) and drop micrographs whose
	     relative defocus error exceeds --defocus_error.
	  3. Keep only micrographs that have a coordinates file (and a CTF entry with --import_ctf).
	  4. Each process windows --box_size particles from its share of the micrographs and writes
	     one bdb stack per micrograph.
	  5. The main process prints a summary and appends every stack to bdb:<output_directory>/data
	     with e2bdb.py.

	Runs with or without MPI. MPI is detected through the OMPI_COMM_WORLD_SIZE environment
	variable. Under MPI each process writes into its own <output_directory>/NNN subdirectory.

	Does not return normally: it terminates with sys.exit(0).
	"""
	progname = os.path.basename(sys.argv[0])
	usage = progname + """  input_micrograph_list_file  input_micrograph_pattern  input_coordinates_pattern  output_directory  --coordinates_format  --box_size=box_size  --invert  --import_ctf=ctf_file  --limit_ctf  --resample_ratio=resample_ratio  --defocus_error=defocus_error  --astigmatism_error=astigmatism_error
	
Window particles from micrographs in input list file. The coordinates of the particles should be given as input.
Please specify name pattern of input micrographs and coordinates files with a wild card (*). Use the wild card to indicate the place of micrograph ID (e.g. serial number, time stamp, and etc). 
The name patterns must be enclosed by single quotes (') or double quotes ("). (Note: sxgui.py automatically adds single quotes (')). 
BDB files can not be selected as input micrographs.
	
	sxwindow.py  mic_list.txt  ./mic*.hdf  info/mic*_info.json  particles  --coordinates_format=eman2  --box_size=64  --invert  --import_ctf=outdir_cter/partres/partres.txt
	
If micrograph list file name is not provided, all files matched with the micrograph name pattern will be processed.
	
	sxwindow.py  ./mic*.hdf  info/mic*_info.json  particles  --coordinates_format=eman2  --box_size=64  --invert  --import_ctf=outdir_cter/partres/partres.txt
	
"""
	parser = OptionParser(usage, version=SPARXVERSION)
	parser.add_option("--coordinates_format",  type="string",        default="eman1",   help="format of input coordinates files: 'sparx', 'eman1', 'eman2', or 'spider'. the coordinates of sparx, eman2, and spider format is particle center. the coordinates of eman1 format is particle box conner associated with the original box size. (default eman1)")
	parser.add_option("--box_size",            type="int",           default=256,       help="x and y dimension of square area to be windowed (in pixels): pixel size after resampling is assumed when resample_ratio < 1.0 (default 256)")
	parser.add_option("--invert",              action="store_true",  default=False,     help="invert image contrast: recommended for cryo data (default False)")
	parser.add_option("--import_ctf",          type="string",        default="",        help="file name of sxcter output: normally partres.txt (default none)")
	parser.add_option("--limit_ctf",           action="store_true",  default=False,     help="filter micrographs based on the CTF limit: this option requires --import_ctf. (default False)")
	parser.add_option("--resample_ratio",      type="float",         default=1.0,       help="ratio of new to old image size (or old to new pixel size) for resampling: Valid range is 0.0 < resample_ratio <= 1.0. (default 1.0)")
	parser.add_option("--defocus_error",       type="float",         default=1000000.0, help="defocus errror limit: exclude micrographs whose relative defocus error as estimated by sxcter is larger than defocus_error percent. the error is computed as (std dev defocus)/defocus*100%. (default 1000000.0)" )
	parser.add_option("--astigmatism_error",   type="float",         default=360.0,     help="astigmatism error limit: Set to zero astigmatism for micrographs whose astigmatism angular error as estimated by sxcter is larger than astigmatism_error degrees. (default 360.0)")

	### detect if program is running under MPI
	RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ
	
	main_node = 0
	
	if RUNNING_UNDER_MPI:
		# NOTE: every mpi_* name used later in this function comes from this import,
		# so all such calls must be guarded by RUNNING_UNDER_MPI.
		from mpi import mpi_init
		from mpi import MPI_COMM_WORLD, mpi_comm_rank, mpi_comm_size, mpi_barrier, mpi_reduce, MPI_INT, MPI_SUM
		
		mpi_init(0, [])
		myid = mpi_comm_rank(MPI_COMM_WORLD)
		number_of_processes = mpi_comm_size(MPI_COMM_WORLD)
	else:
		number_of_processes = 1
		myid = 0
	
	(options, args) = parser.parse_args(sys.argv[1:])
	
	mic_list_file_path = None
	mic_pattern = None
	coords_pattern = None
	error_status = None
	# Not a real loop: `while True ... break` lets each failed check stop with an error.
	while True:
		if len(args) < 3 or len(args) > 4:
			error_status = ("Please check usage for number of arguments.\n Usage: " + usage + "\n" + "Please run %s -h for help." % (progname), getframeinfo(currentframe()))
			break
		
		if len(args) == 3:
			mic_pattern = args[0]
			coords_pattern = args[1]
			out_dir = args[2]
		else: # assert(len(args) == 4)
			mic_list_file_path = args[0]
			mic_pattern = args[1]
			coords_pattern = args[2]
			out_dir = args[3]
		
		if mic_list_file_path is not None:
			if os.path.splitext(mic_list_file_path)[1] != ".txt":
				error_status = ("Extension of input micrograph list file must be \".txt\". Please check input_micrograph_list_file argument. Run %s -h for help." % (progname), getframeinfo(currentframe()))
				break
		
		# Compare the full 4-character prefix; comparing against "bdb" (3 chars) could never match.
		if mic_pattern[:len("bdb:")].lower() == "bdb:":
			error_status = ("BDB file can not be selected as input micrographs. Please convert the format, and restart the program. Run %s -h for help." % (progname), getframeinfo(currentframe()))
			break
		
		if mic_pattern.find("*") == -1:
			error_status = ("Input micrograph file name pattern must contain wild card (*). Please check input_micrograph_pattern argument. Run %s -h for help." % (progname), getframeinfo(currentframe()))
			break
		
		if coords_pattern.find("*") == -1:
			error_status = ("Input coordinates file name pattern must contain wild card (*). Please check input_coordinates_pattern argument. Run %s -h for help." % (progname), getframeinfo(currentframe()))
			break
		
		if myid == main_node:
			if os.path.exists(out_dir):
				error_status = ("Output directory exists. Please change the name and restart the program.", getframeinfo(currentframe()))
				break

		break
	if_error_then_all_processes_exit_program(error_status)
	
	# Check invalid conditions of options
	check_options(options, progname)
	
	mic_name_list = None
	error_status = None
	if myid == main_node:
		if mic_list_file_path is not None:
			print("Loading micrograph list from %s file ..." % (mic_list_file_path))
			mic_name_list = read_text_file(mic_list_file_path)
			# Only report the directory for a non-empty list; an empty list is reported as an error below.
			if len(mic_name_list) > 0:
				print("Directory of first micrograph entry is %s" % (os.path.dirname(mic_name_list[0])))
		else: # assert (mic_list_file_path == None)
			print("Generating micrograph list in %s directory..." % (os.path.dirname(mic_pattern)))
			mic_name_list = glob.glob(mic_pattern)
		if len(mic_name_list) == 0:
			error_status = ("No micrograph file is found. Please check input_micrograph_pattern and/or input_micrograph_list_file argument. Run %s -h for help." % (progname), getframeinfo(currentframe()))
		else:
			print("Found %d microgarphs" % len(mic_name_list))
			
	if_error_then_all_processes_exit_program(error_status)
	if RUNNING_UNDER_MPI:
		mic_name_list = wrap_mpi_bcast(mic_name_list, main_node)
	
	coords_name_list = None
	error_status = None
	if myid == main_node:
		coords_name_list = glob.glob(coords_pattern)
		if len(coords_name_list) == 0:
			error_status = ("No coordinates file is found. Please check input_coordinates_pattern argument. Run %s -h for help." % (progname), getframeinfo(currentframe()))
	if_error_then_all_processes_exit_program(error_status)
	if RUNNING_UNDER_MPI:
		coords_name_list = wrap_mpi_bcast(coords_name_list, main_node)
	
	# Column indices of the sxcter output; all processes must have access to them.
	if options.import_ctf:
		i_enum = -1
		i_enum += 1; idx_cter_def          = i_enum # defocus [um]; index must be same as ctf object format
		i_enum += 1; idx_cter_cs           = i_enum # Cs [mm]; index must be same as ctf object format
		i_enum += 1; idx_cter_vol          = i_enum # voltage[kV]; index must be same as ctf object format
		i_enum += 1; idx_cter_apix         = i_enum # pixel size [A]; index must be same as ctf object format
		i_enum += 1; idx_cter_bfactor      = i_enum # B-factor [A^2]; index must be same as ctf object format
		i_enum += 1; idx_cter_ac           = i_enum # amplitude contrast [%]; index must be same as ctf object format
		i_enum += 1; idx_cter_astig_amp    = i_enum # astigmatism amplitude [um]; index must be same as ctf object format
		i_enum += 1; idx_cter_astig_ang    = i_enum # astigmatism angle [degree]; index must be same as ctf object format
		i_enum += 1; idx_cter_sd_def       = i_enum # std dev of defocus [um]
		i_enum += 1; idx_cter_sd_astig_amp = i_enum # std dev of ast amp [A]
		i_enum += 1; idx_cter_sd_astig_ang = i_enum # std dev of ast angle [degree]
		i_enum += 1; idx_cter_cv_def       = i_enum # coefficient of variation of defocus [%]
		i_enum += 1; idx_cter_cv_astig_amp = i_enum # coefficient of variation of ast amp [%]
		i_enum += 1; idx_cter_spectra_diff = i_enum # average of differences between with- and without-astig. experimental 1D spectra at extrema
		i_enum += 1; idx_cter_error_def    = i_enum # frequency at which signal drops by 50% due to estimated error of defocus alone [1/A]
		i_enum += 1; idx_cter_error_astig  = i_enum # frequency at which signal drops by 50% due to estimated error of defocus and astigmatism [1/A]
		i_enum += 1; idx_cter_error_ctf    = i_enum # limit frequency by CTF error [1/A]
		i_enum += 1; idx_cter_mic_name     = i_enum # micrograph name
		i_enum += 1; n_idx_cter            = i_enum
	
	# Prepare loop variables
	mic_basename_pattern = os.path.basename(mic_pattern)              # file pattern without path
	mic_baseroot_pattern = os.path.splitext(mic_basename_pattern)[0]  # file pattern without path and extension
	coords_format = options.coordinates_format.lower()
	box_size = options.box_size
	box_half = box_size // 2
	mask2d = model_circle(box_size//2, box_size, box_size) # Create circular 2D mask to Util.infomask of particle images
	resample_ratio = options.resample_ratio
	
	n_mic_process = 0
	n_mic_reject_no_coords = 0
	n_mic_reject_no_cter_entry = 0
	n_global_coords_detect = 0
	n_global_coords_process = 0
	n_global_coords_reject_out_of_boundary = 0
	
	serial_id_list = []
	error_status = None
	## not a real while, an if with the opportunity to use break when errors need to be reported
	while myid == main_node:
		# 
		# NOTE: 2016/05/24 Toshio Moriya
		# Now, ignores the path in mic_pattern and entries of mic_name_list to create serial ID
		# Only the basename (file name) in micrograph path must be match
		# 
		# Create list of micrograph serial ID
		# Break micrograph name pattern into prefix and suffix to find the head index of the micrograph serial id
		# 
		mic_basename_tokens = mic_basename_pattern.split('*')
		# assert (len(mic_basename_tokens) == 2)
		serial_id_head_index = len(mic_basename_tokens[0])
		# Loop through micrograph names
		for mic_name in mic_name_list:
			# Find the tail index of the serial id and extract serial id from the micrograph name.
			# rindex() locates the trailing suffix; index() would return 0 for an empty suffix
			# (pattern ending with '*') and produce an empty serial id.
			mic_basename = os.path.basename(mic_name)
			serial_id_tail_index = mic_basename.rindex(mic_basename_tokens[1])
			serial_id = mic_basename[serial_id_head_index:serial_id_tail_index]
			serial_id_list.append(serial_id)
		# assert (len(serial_id_list) == len(mic_name))
		del mic_name_list # Do not need this anymore
		
		# Load CTFs if necessary
		if options.import_ctf:
			
			ctf_list = read_text_row(options.import_ctf)
			
			if len(ctf_list) == 0:
				error_status = ("No CTF entry is found in %s. Please check --import_ctf option. Run %s -h for help." % (options.import_ctf, progname), getframeinfo(currentframe()))
				break
			
			if (len(ctf_list[0]) != n_idx_cter):
				error_status = ("Number of columns (%d) must be %d in %s. The format might be old. Please run sxcter.py again." % (len(ctf_list[0]), n_idx_cter, options.import_ctf), getframeinfo(currentframe()))
				break
			
			ctf_dict={}
			n_reject_defocus_error = 0
			ctf_error_limit = [options.defocus_error/100.0, options.astigmatism_error]
			for ctf_params in ctf_list:
				assert(len(ctf_params) == n_idx_cter)
				# mic_baseroot is name of micrograph minus the path and extension
				mic_baseroot = os.path.splitext(os.path.basename(ctf_params[idx_cter_mic_name]))[0]
				if(ctf_params[idx_cter_sd_def] / ctf_params[idx_cter_def] > ctf_error_limit[0]):
					print("Defocus error %f exceeds the threshold. Micrograph %s is rejected." % (ctf_params[idx_cter_sd_def] / ctf_params[idx_cter_def], mic_baseroot))
					n_reject_defocus_error += 1
				else:
					# Too uncertain astigmatism angle: keep the micrograph but zero its astigmatism.
					if(ctf_params[idx_cter_sd_astig_ang] > ctf_error_limit[1]):
						ctf_params[idx_cter_astig_amp] = 0.0
						ctf_params[idx_cter_astig_ang] = 0.0
					ctf_dict[mic_baseroot] = ctf_params
			del ctf_list # Do not need this anymore
		
		break
		
	if_error_then_all_processes_exit_program(error_status)

	if options.import_ctf:
		if options.limit_ctf:
			cutoff_histogram = []  #@ming compute the histogram for micrographs cut of by ctf_params limit.
	
	# Keep only micrographs that have a coordinates file (and a CTER entry when importing CTF).
	restricted_serial_id_list = []
	if myid == main_node:
		for serial_id in serial_id_list:
			# mic_baseroot is name of micrograph minus the path and extension
			mic_baseroot = mic_baseroot_pattern.replace("*", serial_id)
			coords_name = coords_pattern.replace("*", serial_id)
			
			########### # CHECKS: BEGIN
			if coords_name not in coords_name_list:
				print("    Cannot read %s. Skipping %s ..." % (coords_name, mic_baseroot))
				n_mic_reject_no_coords += 1
				continue
			
			# IF mic is in CTER results
			if options.import_ctf:
				if mic_baseroot not in ctf_dict:
					print("    Is not listed in CTER results. Skipping %s ..." % (mic_baseroot))
					n_mic_reject_no_cter_entry += 1
					continue
			# CHECKS: END
			
			n_mic_process += 1
			
			restricted_serial_id_list.append(serial_id)

	if myid != main_node:
		if options.import_ctf:
			ctf_dict = None

	error_status = None
	if len(restricted_serial_id_list) < number_of_processes:
		error_status = ('Number of processes (%d) supplied by --np in mpirun cannot be greater than %d (number of micrographs that satisfy all criteria to be processed) ' % (number_of_processes, len(restricted_serial_id_list)), getframeinfo(currentframe()))
	if_error_then_all_processes_exit_program(error_status)

	## keep a copy of the original output directory where the final bdb will be created
	original_out_dir = out_dir
	# Full (unsliced) list of micrographs to process. It must exist without MPI as well,
	# because the summary and the final e2bdb step below use it.
	restricted_serial_id_list_not_sliced = restricted_serial_id_list
	if RUNNING_UNDER_MPI:
		mpi_barrier(MPI_COMM_WORLD)
		restricted_serial_id_list = wrap_mpi_bcast(restricted_serial_id_list, main_node)
		mic_start, mic_end = MPI_start_end(len(restricted_serial_id_list), number_of_processes, myid)
		restricted_serial_id_list_not_sliced = restricted_serial_id_list
		restricted_serial_id_list = restricted_serial_id_list[mic_start:mic_end]
	
		if options.import_ctf:
			ctf_dict = wrap_mpi_bcast(ctf_dict, main_node)

		# generate subdirectories of out_dir, one for each process
		out_dir = os.path.join(out_dir,"%03d"%myid)
	
	if myid == main_node:
		print("Micrographs processed by main process (including percent complete):")

	len_processed_by_main_node_divided_by_100 = len(restricted_serial_id_list)/100.0

	#####  Starting main parallel execution

	for my_idx, serial_id in enumerate(restricted_serial_id_list):
		mic_baseroot = mic_baseroot_pattern.replace("*", serial_id)
		mic_name = mic_pattern.replace("*", serial_id)
		coords_name = coords_pattern.replace("*", serial_id)

		if myid == main_node:
			print(mic_name, " ---> % 2.2f%%"%(my_idx/len_processed_by_main_node_divided_by_100))
		mic_img = get_im(mic_name)

		# Read coordinates according to the specified format and 
		# make the coordinates the center of particle image 
		if coords_format == "sparx":
			coords_list = read_text_row(coords_name)
		elif coords_format == "eman1":
			coords_list = read_text_row(coords_name)
			for i in xrange(len(coords_list)):
				coords_list[i] = [(coords_list[i][0] + coords_list[i][2] // 2), (coords_list[i][1] + coords_list[i][3] // 2)]
		elif coords_format == "eman2":
			coords_list = js_open_dict(coords_name)["boxes"]
			for i in xrange(len(coords_list)):
				coords_list[i] = [coords_list[i][0], coords_list[i][1]]
		elif coords_format == "spider":
			coords_list = read_text_row(coords_name)
			for i in xrange(len(coords_list)):
				coords_list[i] = [coords_list[i][2], coords_list[i][3]]
			# else: assert (False) # Unreachable code
		
		# Calculate the new pixel size
		if options.import_ctf:
			ctf_params = ctf_dict[mic_baseroot]
			pixel_size_origin = ctf_params[idx_cter_apix]
			
			if resample_ratio < 1.0:
				# assert (resample_ratio > 0.0)
				new_pixel_size = pixel_size_origin / resample_ratio
				print("Resample micrograph to pixel size %6.4f and window segments from resampled micrograph." % new_pixel_size)
			else:
				# assert (resample_ratio == 1.0)
				new_pixel_size = pixel_size_origin
		
			# Set ctf along with new pixel size in resampled micrograph
			ctf_params[idx_cter_apix] = new_pixel_size
		else:
			# assert (not options.import_ctf)
			if resample_ratio < 1.0:
				# assert (resample_ratio > 0.0)
				print("Resample micrograph with ratio %6.4f and window segments from resampled micrograph." % resample_ratio)
		
		# Apply filters to micrograph
		fftip(mic_img)
		if options.limit_ctf:
			# NOTE(review): relies on --limit_ctf implying --import_ctf (ctf_params, new_pixel_size
			# and cutoff_histogram only exist then); presumably enforced by check_options — confirm.
			# Cut off frequency components higher than CTF limit 
			q1, q2 = ctflimit(box_size, ctf_params[idx_cter_def], ctf_params[idx_cter_cs], ctf_params[idx_cter_vol], new_pixel_size)
			
			# This is absolute frequency of CTF limit in scale of original micrograph
			if resample_ratio < 1.0:
				# assert (resample_ratio > 0.0)
				q1 = resample_ratio * q1 / float(box_size) # q1 = (pixel_size_origin / new_pixel_size) * q1/float(box_size)
			else:
				# assert (resample_ratio == 1.0) -> pixel_size_origin == new_pixel_size -> pixel_size_origin / new_pixel_size == 1.0
				q1 = q1 / float(box_size)
			
			if q1 < 0.5:
				mic_img = filt_tanl(mic_img, q1, 0.01)
				cutoff_histogram.append(q1)
		
		# Cut off frequency components lower than the box size can express 
		mic_img = fft(filt_gaussh(mic_img, resample_ratio / box_size))
		
		# Resample micrograph, map coordinates, and window segments from resampled micrograph using new coordinates
		# after resampling by resample_ratio, new pixel size will be pixel_size/resample_ratio = new_pixel_size
		# NOTE: 2015/04/13 Toshio Moriya
		# resample() efficiently takes care of the case resample_ratio = 1.0 but
		# it does not set apix_*. Even though it sets apix_* when resample_ratio < 1.0 ...
		mic_img = resample(mic_img, resample_ratio)
		
		if options.invert:
			# Flip contrast around the mean: x -> 2*mean - x
			mic_stats = Util.infomask(mic_img, None, True) # mic_stat[0:mean, 1:SD, 2:min, 3:max]
			Util.mul_scalar(mic_img, -1.0)
			mic_img += 2 * mic_stats[0]
		
		if options.import_ctf:
			from utilities import generate_ctf
			ctf_obj = generate_ctf(ctf_params) # indexes 0 to 7 (idx_cter_def to idx_cter_astig_ang) must be same in cter format & ctf object format.
		
		# Prepare loop variables
		nx = mic_img.get_xsize() 
		ny = mic_img.get_ysize()
		x0 = nx//2
		y0 = ny//2

		n_coords_reject_out_of_boundary = 0
		local_stack_name  = "bdb:%s#" % out_dir + mic_baseroot + '_ptcls'
		local_particle_id = 0 # can be different from coordinates_id
		# Loop over coordinates
		for coords_id in xrange(len(coords_list)):
			
			x = int(coords_list[coords_id][0])
			y = int(coords_list[coords_id][1])
			
			if resample_ratio < 1.0:
				# assert (resample_ratio > 0.0)
				x = int(x * resample_ratio)	
				y = int(y * resample_ratio)
				
			# Window only boxes lying completely inside the (resampled) micrograph;
			# Util.window takes offsets relative to the image center.
			if( (0 <= x - box_half) and ( x + box_half <= nx ) and (0 <= y - box_half) and ( y + box_half <= ny ) ):
				particle_img = Util.window(mic_img, box_size, box_size, 1, x-x0, y-y0)
			else:
				print("In %s, coordinates ID = %04d (x = %4d, y = %4d, box_size = %4d) is out of micrograph bound, skipping ..." % (mic_baseroot, coords_id, x, y, box_size))
				n_coords_reject_out_of_boundary += 1
				continue
			
			# Remove linear ramp, then normalize with statistics inside the circular mask
			particle_img = ramp(particle_img)
			particle_stats = Util.infomask(particle_img, mask2d, False) # particle_stats[0:mean, 1:SD, 2:min, 3:max]
			particle_img -= particle_stats[0]
			particle_img /= particle_stats[1]
			
			# NOTE: 2015/04/09 Toshio Moriya
			# ptcl_source_image might be redundant information ...
			# Consider re-organizing header entries...
			particle_img.set_attr("ptcl_source_image", mic_name)
			particle_img.set_attr("ptcl_source_coord_id", coords_id)
			particle_img.set_attr("ptcl_source_coord", [int(coords_list[coords_id][0]), int(coords_list[coords_id][1])])
			particle_img.set_attr("resample_ratio", resample_ratio)
			
			# NOTE: 2015/04/13 Toshio Moriya
			# apix_* attributes are updated by resample() only when resample_ratio != 1.0
			# Let's make sure header info is consistent by setting apix_* = 1.0 
			# regardless of options, so it is not passed down the processing line
			particle_img.set_attr("apix_x", 1.0)
			particle_img.set_attr("apix_y", 1.0)
			particle_img.set_attr("apix_z", 1.0)
			if options.import_ctf:
				particle_img.set_attr("ctf",ctf_obj)
				particle_img.set_attr("ctf_applied", 0)
				particle_img.set_attr("pixel_size_origin", pixel_size_origin)
			# NOTE: 2015/04/13 Toshio Moriya 
			# Pawel Comment: Micrograph is not supposed to have CTF header info.
			# So, let's assume it does not exist & ignore its presence.
			# Note that resample() "correctly" updates pixel size of CTF header info if it exists
			
			particle_img.write_image(local_stack_name, local_particle_id)
			local_particle_id += 1
		
		n_global_coords_detect += len(coords_list)
		n_global_coords_process += local_particle_id
		n_global_coords_reject_out_of_boundary += n_coords_reject_out_of_boundary
		
		# Release the data base of local stack from this process
		# so that the subprocess can access to the data base
		db_close_dict(local_stack_name)
		
	if RUNNING_UNDER_MPI:
		if options.import_ctf:
			if options.limit_ctf:
				cutoff_histogram = wrap_mpi_gatherv(cutoff_histogram, main_node)

	if myid == main_node:
		if options.limit_ctf:
			# Print out the summary of CTF-limit filtering
			print(" ")
			print("Global summary of CTF-limit filtering (--limit_ctf) ...")
			print("Percentage of filtered micrographs: %8.2f\n" % (len(cutoff_histogram) * 100.0 / len(restricted_serial_id_list_not_sliced)))

			n_bins = 10
			if len(cutoff_histogram) >= n_bins:
				from statistics import hist_list
				cutoff_region, cutoff_counts = hist_list(cutoff_histogram, n_bins)
				print("      Histogram of cut-off frequency")
				print("      cut-off       counts")
				for bin_id in xrange(n_bins):
					print(" %14.7f     %7d" % (cutoff_region[bin_id], cutoff_counts[bin_id]))
			else:
				print("The number of filtered micrographs (%d) is less than the number of bins (%d). No histogram is produced." % (len(cutoff_histogram), n_bins))
	
	# Sum the counters over all processes. Without MPI the local counters already are
	# the totals (and the mpi_* names are not imported).
	if RUNNING_UNDER_MPI:
		n_mic_process = mpi_reduce(n_mic_process, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
		n_mic_reject_no_coords = mpi_reduce(n_mic_reject_no_coords, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
		n_mic_reject_no_cter_entry = mpi_reduce(n_mic_reject_no_cter_entry, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
		n_global_coords_detect = mpi_reduce(n_global_coords_detect, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
		n_global_coords_process = mpi_reduce(n_global_coords_process, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
		n_global_coords_reject_out_of_boundary = mpi_reduce(n_global_coords_reject_out_of_boundary, 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)
	
	# Print out the summary of all micrographs
	if main_node == myid:
		print(" ")
		print("Global summary of micrographs ...")
		# Detected = every micrograph found (before the coordinates/CTER checks), so that
		# Detected = Processed + the rejected counts below.
		print("Detected                        : %6d" % (len(serial_id_list)))
		print("Processed                       : %6d" % (n_mic_process))
		print("Rejected by no coordinates file : %6d" % (n_mic_reject_no_coords))
		print("Rejected by no CTER entry       : %6d" % (n_mic_reject_no_cter_entry))
		print(" ")
		print("Global summary of coordinates ...")
		print("Detected                        : %6d" % (n_global_coords_detect))
		print("Processed                       : %6d" % (n_global_coords_process))
		print("Rejected by out of boundary     : %6d" % (n_global_coords_reject_out_of_boundary))
	
	# All per-micrograph stacks must be written before the main process merges them.
	if RUNNING_UNDER_MPI:
		mpi_barrier(MPI_COMM_WORLD)
	
	if main_node == myid:
	
		import time
		time.sleep(1)
		print("\n Creating bdb:%s/data\n"%original_out_dir)
		for proc_i in range(number_of_processes):
			# Recompute which micrographs process proc_i handled (and thus where its stacks live)
			mic_start, mic_end = MPI_start_end(len(restricted_serial_id_list_not_sliced), number_of_processes, proc_i)
			for serial_id in restricted_serial_id_list_not_sliced[mic_start:mic_end]:
				e2bdb_command = "e2bdb.py "
				mic_baseroot = mic_baseroot_pattern.replace("*", serial_id)
				if RUNNING_UNDER_MPI:
					e2bdb_command += "bdb:" + os.path.join(original_out_dir,"%03d/"%proc_i) + mic_baseroot + "_ptcls "
				else:
					e2bdb_command += "bdb:" + os.path.join(original_out_dir, mic_baseroot + "_ptcls ") 
				
				e2bdb_command += " --appendvstack=bdb:%s/data  1>/dev/null"%original_out_dir
				cmdexecute(e2bdb_command, printing_on_success = False)
				
		print("Done!\n")
				
	if RUNNING_UNDER_MPI:
		mpi_barrier(MPI_COMM_WORLD)
		from mpi import mpi_finalize
		mpi_finalize()

	sys.stdout.flush()
	sys.exit(0)
Beispiel #28
0
def ali3d_MPI(stack, ref_vol, outdir, maskfile = None, ir = 1, ou = -1, rs = 1, 
	    xr = "4 2 2 1", yr = "-1", ts = "1 1 0.5 0.25", delta = "10 6 4 4", an = "-1",
	    center = 0, maxit = 5, term = 95, CTF = False, fourvar = False, snr = 1.0,  ref_a = "S", sym = "c1", 
	    sort=True, cutoff=999.99, pix_cutoff="0", two_tail=False, model_jump="1 1 1 1 1", restart=False, save_half=False,
	    protos=None, oplane=None, lmask=-1, ilmask=-1, findseam=False, vertstep=None, hpars="-1", hsearch="73.0 170.0",
	    full_output = False, compare_repro = False, compare_ref_free = "-1", ref_free_cutoff= "-1 -1 -1 -1",
	    wcmask = None, debug = False, recon_pad = 4):

	from alignment      import Numrinit, prepare_refrings
	from utilities      import model_circle, get_image, drop_image, get_input_from_string
	from utilities      import bcast_list_to_all, bcast_number_to_all, reduce_EMData_to_root, bcast_EMData_to_all 
	from utilities      import send_attr_dict
	from utilities      import get_params_proj, file_type
	from fundamentals   import rot_avg_image
	import os
	import types
	from utilities      import print_begin_msg, print_end_msg, print_msg
	from mpi	    import mpi_bcast, mpi_comm_size, mpi_comm_rank, MPI_FLOAT, MPI_COMM_WORLD, mpi_barrier, mpi_reduce
	from mpi	    import mpi_reduce, MPI_INT, MPI_SUM, mpi_finalize
	from filter	 import filt_ctf
	from projection     import prep_vol, prgs
	from statistics     import hist_list, varf3d_MPI, fsc_mask
	from numpy	  import array, bincount, array2string, ones

	number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
	myid	   = mpi_comm_rank(MPI_COMM_WORLD)
	main_node = 0
	if myid == main_node:
		if os.path.exists(outdir):  ERROR('Output directory exists, please change the name and restart the program', "ali3d_MPI", 1)
		os.mkdir(outdir)
	mpi_barrier(MPI_COMM_WORLD)

	if debug:
		from time import sleep
		while not os.path.exists(outdir):
			print  "Node ",myid,"  waiting..."
			sleep(5)

		info_file = os.path.join(outdir, "progress%04d"%myid)
		finfo = open(info_file, 'w')
	else:
		finfo = None
	mjump = get_input_from_string(model_jump)
	xrng	= get_input_from_string(xr)
	if  yr == "-1":  yrng = xrng
	else	  :  yrng = get_input_from_string(yr)
	step	= get_input_from_string(ts)
	delta       = get_input_from_string(delta)
	ref_free_cutoff = get_input_from_string(ref_free_cutoff)	
	pix_cutoff = get_input_from_string(pix_cutoff)
	
	lstp = min(len(xrng), len(yrng), len(step), len(delta))
	if an == "-1":
		an = [-1] * lstp
	else:
		an = get_input_from_string(an)
	# make sure pix_cutoff is set for all iterations
	if len(pix_cutoff)<lstp:
		for i in xrange(len(pix_cutoff),lstp):
			pix_cutoff.append(pix_cutoff[-1])
	# don't waste time on sub-pixel alignment for low-resolution ang incr
	for i in range(len(step)):
		if (delta[i] > 4 or delta[i] == -1) and step[i] < 1:
			step[i] = 1

	first_ring  = int(ir)
	rstep       = int(rs)
	last_ring   = int(ou)
	max_iter    = int(maxit)
	center      = int(center)

	nrefs   = EMUtil.get_image_count( ref_vol )
	nmasks = 0
	if maskfile:
		# read number of masks within each maskfile (mc)
		nmasks   = EMUtil.get_image_count( maskfile )
		# open masks within maskfile (mc)
		maskF   = EMData.read_images(maskfile, xrange(nmasks))
	vol     = EMData.read_images(ref_vol, xrange(nrefs))
	nx      = vol[0].get_xsize()

	## make sure box sizes are the same
	if myid == main_node:
		im=EMData.read_images(stack,[0])
		bx = im[0].get_xsize()
		if bx!=nx:
			print_msg("Error: Stack box size (%i) differs from initial model (%i)\n"%(bx,nx))
			sys.exit()
		del im,bx
	
	# for helical processing:
	helicalrecon = False
	if protos is not None or hpars != "-1" or findseam is True:
		helicalrecon = True
		# if no out-of-plane param set, use 5 degrees
		if oplane is None:
			oplane=5.0
	if protos is not None:
		proto = get_input_from_string(protos)
		if len(proto) != nrefs:
			print_msg("Error: insufficient protofilament numbers supplied")
			sys.exit()
	if hpars != "-1":
		hpars = get_input_from_string(hpars)
		if len(hpars) != 2*nrefs:
			print_msg("Error: insufficient helical parameters supplied")
			sys.exit()
	## create helical parameter file for helical reconstruction
	if helicalrecon is True and myid == main_node:
		from hfunctions import createHpar
		# create initial helical parameter files
		dp=[0]*nrefs
		dphi=[0]*nrefs
		vdp=[0]*nrefs
		vdphi=[0]*nrefs
		for iref in xrange(nrefs):
			hpar = os.path.join(outdir,"hpar%02d.spi"%(iref))
			params = False
			if hpars != "-1":
				# if helical parameters explicitly given, set twist & rise
				params = [float(hpars[iref*2]),float(hpars[(iref*2)+1])]
			dp[iref],dphi[iref],vdp[iref],vdphi[iref] = createHpar(hpar,proto[iref],params,vertstep)

	# get values for helical search parameters
	hsearch = get_input_from_string(hsearch)
	if len(hsearch) != 2:
		print_msg("Error: specify outer and inner radii for helical search")
		sys.exit()

	if last_ring < 0 or last_ring > int(nx/2)-2 :	last_ring = int(nx/2) - 2

	if myid == main_node:
	#	import user_functions
	#	user_func = user_functions.factory[user_func_name]

		print_begin_msg("ali3d_MPI")
		print_msg("Input stack		 : %s\n"%(stack))
		print_msg("Reference volume	    : %s\n"%(ref_vol))	
		print_msg("Output directory	    : %s\n"%(outdir))
		if nmasks > 0:
			print_msg("Maskfile (number of masks)  : %s (%i)\n"%(maskfile,nmasks))
		print_msg("Inner radius		: %i\n"%(first_ring))
		print_msg("Outer radius		: %i\n"%(last_ring))
		print_msg("Ring step		   : %i\n"%(rstep))
		print_msg("X search range	      : %s\n"%(xrng))
		print_msg("Y search range	      : %s\n"%(yrng))
		print_msg("Translational step	  : %s\n"%(step))
		print_msg("Angular step		: %s\n"%(delta))
		print_msg("Angular search range	: %s\n"%(an))
		print_msg("Maximum iteration	   : %i\n"%(max_iter))
		print_msg("Center type		 : %i\n"%(center))
		print_msg("CTF correction	      : %s\n"%(CTF))
		print_msg("Signal-to-Noise Ratio       : %f\n"%(snr))
		print_msg("Reference projection method : %s\n"%(ref_a))
		print_msg("Symmetry group	      : %s\n"%(sym))
		print_msg("Fourier padding for 3D      : %i\n"%(recon_pad))
		print_msg("Number of reference models  : %i\n"%(nrefs))
		print_msg("Sort images between models  : %s\n"%(sort))
		print_msg("Allow images to jump	: %s\n"%(mjump))
		print_msg("CC cutoff standard dev      : %f\n"%(cutoff))
		print_msg("Two tail cutoff	     : %s\n"%(two_tail))
		print_msg("Termination pix error       : %f\n"%(term))
		print_msg("Pixel error cutoff	  : %s\n"%(pix_cutoff))
		print_msg("Restart		     : %s\n"%(restart))
		print_msg("Full output		 : %s\n"%(full_output))
		print_msg("Compare reprojections       : %s\n"%(compare_repro))
		print_msg("Compare ref free class avgs : %s\n"%(compare_ref_free))
		print_msg("Use cutoff from ref free    : %s\n"%(ref_free_cutoff))
		if protos:
			print_msg("Protofilament numbers	: %s\n"%(proto))
			print_msg("Using helical search range   : %s\n"%hsearch) 
		if findseam is True:
			print_msg("Using seam-based reconstruction\n")
		if hpars != "-1":
			print_msg("Using hpars		  : %s\n"%hpars)
		if vertstep != None:
			print_msg("Using vertical step    : %.2f\n"%vertstep)
		if save_half is True:
			print_msg("Saving even/odd halves\n")
		for i in xrange(100) : print_msg("*")
		print_msg("\n\n")
	if maskfile:
		if type(maskfile) is types.StringType: mask3D = get_image(maskfile)
		else:				  mask3D = maskfile
	else: mask3D = model_circle(last_ring, nx, nx, nx)

	numr	= Numrinit(first_ring, last_ring, rstep, "F")
	mask2D  = model_circle(last_ring,nx,nx) - model_circle(first_ring,nx,nx)

	fscmask = model_circle(last_ring,nx,nx,nx)
	if CTF:
		from filter	 import filt_ctf
	from reconstruction_rjh import rec3D_MPI_noCTF

	if myid == main_node:
		active = EMUtil.get_all_attributes(stack, 'active')
		list_of_particles = []
		for im in xrange(len(active)):
			if active[im]:  list_of_particles.append(im)
		del active
		nima = len(list_of_particles)
	else:
		nima = 0
	total_nima = bcast_number_to_all(nima, source_node = main_node)

	if myid != main_node:
		list_of_particles = [-1]*total_nima
	list_of_particles = bcast_list_to_all(list_of_particles, source_node = main_node)

	image_start, image_end = MPI_start_end(total_nima, number_of_proc, myid)

	# create a list of images for each node
	list_of_particles = list_of_particles[image_start: image_end]
	nima = len(list_of_particles)
	if debug:
		finfo.write("image_start, image_end: %d %d\n" %(image_start, image_end))
		finfo.flush()

	data = EMData.read_images(stack, list_of_particles)

	t_zero = Transform({"type":"spider","phi":0,"theta":0,"psi":0,"tx":0,"ty":0})
	transmulti = [[t_zero for i in xrange(nrefs)] for j in xrange(nima)]

	for iref,im in ((iref,im) for iref in xrange(nrefs) for im in xrange(nima)):
		if nrefs == 1:
			transmulti[im][iref] = data[im].get_attr("xform.projection")
		else:
			# if multi models, keep track of eulers for all models
			try:
				transmulti[im][iref] = data[im].get_attr("eulers_txty.%i"%iref)
			except:
				data[im].set_attr("eulers_txty.%i"%iref,t_zero)

	scoremulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)] 
	pixelmulti = [[0.0 for i in xrange(nrefs)] for j in xrange(nima)] 
	ref_res = [0.0 for x in xrange(nrefs)] 
	apix = data[0].get_attr('apix_x')

	# for oplane parameter, create cylindrical mask
	if oplane is not None and myid == main_node:
		from hfunctions import createCylMask
		cmaskf=os.path.join(outdir, "mask3D_cyl.mrc")
		mask3D = createCylMask(data,ou,lmask,ilmask,cmaskf)
		# if finding seam of helix, create wedge masks
		if findseam is True:
			wedgemask=[]
			for pf in xrange(nrefs):
				wedgemask.append(EMData())
			# wedgemask option
			if wcmask is not None:
				wcmask = get_input_from_string(wcmask)
				if len(wcmask) != 3:
					print_msg("Error: wcmask option requires 3 values: x y radius")
					sys.exit()

	# determine if particles have helix info:
	try:
		data[0].get_attr('h_angle')
		original_data = []
		boxmask = True
		from hfunctions import createBoxMask
	except:
		boxmask = False

	# prepare particles
	for im in xrange(nima):
		data[im].set_attr('ID', list_of_particles[im])
		data[im].set_attr('pix_score', int(0))
		if CTF:
			# only phaseflip particles, not full CTF correction
			ctf_params = data[im].get_attr("ctf")
			st = Util.infomask(data[im], mask2D, False)
			data[im] -= st[0]
			data[im] = filt_ctf(data[im], ctf_params, sign = -1, binary=1)
			data[im].set_attr('ctf_applied', 1)
		# for window mask:
		if boxmask is True:
			h_angle = data[im].get_attr("h_angle")
			original_data.append(data[im].copy())
			bmask = createBoxMask(nx,apix,ou,lmask,h_angle)
			data[im]*=bmask
			del bmask
	if debug:
		finfo.write( '%d loaded  \n' % nima )
		finfo.flush()
	if myid == main_node:
		# initialize data for the reference preparation function
		ref_data = [ mask3D, max(center,0), None, None, None, None ]
		# for method -1, switch off centering in user function

	from time import time	

	#  this is needed for gathering of pixel errors
	disps = []
	recvcount = []
	disps_score = []
	recvcount_score = []
	for im in xrange(number_of_proc):
		if( im == main_node ):  
			disps.append(0)
			disps_score.append(0)
		else:		  
			disps.append(disps[im-1] + recvcount[im-1])
			disps_score.append(disps_score[im-1] + recvcount_score[im-1])
		ib, ie = MPI_start_end(total_nima, number_of_proc, im)
		recvcount.append( ie - ib )
		recvcount_score.append((ie-ib)*nrefs)

	pixer = [0.0]*nima
	cs = [0.0]*3
	total_iter = 0
	volodd = EMData.read_images(ref_vol, xrange(nrefs))
	voleve = EMData.read_images(ref_vol, xrange(nrefs))

	if restart:
		# recreate initial volumes from alignments stored in header
		itout = "000_00"
		for iref in xrange(nrefs):
			if(nrefs == 1):
				modout = ""
			else:
				modout = "_model_%02d"%(iref)	
	
			if(sort): 
				group = iref
				for im in xrange(nima):
					imgroup = data[im].get_attr('group')
					if imgroup == iref:
						data[im].set_attr('xform.projection',transmulti[im][iref])
			else: 
				group = int(999) 
				for im in xrange(nima):
					data[im].set_attr('xform.projection',transmulti[im][iref])
			
			fscfile = os.path.join(outdir, "fsc_%s%s"%(itout,modout))

			vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(data, sym, fscmask, fscfile, myid, main_node, index = group, npad = recon_pad)

			if myid == main_node:
				if helicalrecon:
					from hfunctions import processHelicalVol

					vstep=None
					if vertstep is not None:
						vstep=(vdp[iref],vdphi[iref])
					print_msg("Old rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
					hvals=processHelicalVol(vol[iref],voleve[iref],volodd[iref],iref,outdir,itout,
								dp[iref],dphi[iref],apix,hsearch,findseam,vstep,wcmask)
					(vol[iref],voleve[iref],volodd[iref],dp[iref],dphi[iref],vdp[iref],vdphi[iref])=hvals
					print_msg("New rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
					# get new FSC from symmetrized half volumes
					fscc = fsc_mask( volodd[iref], voleve[iref], mask3D, rstep, fscfile)
				else:
					vol[iref].write_image(os.path.join(outdir, "vol_%s.hdf"%itout),-1)

				if save_half is True:
					volodd[iref].write_image(os.path.join(outdir, "volodd_%s.hdf"%itout),-1)
					voleve[iref].write_image(os.path.join(outdir, "voleve_%s.hdf"%itout),-1)

				if nmasks > 1:
					# Read mask for multiplying
					ref_data[0] = maskF[iref]
				ref_data[2] = vol[iref]
				ref_data[3] = fscc
				#  call user-supplied function to prepare reference image, i.e., center and filter it
				vol[iref], cs,fl = ref_ali3d(ref_data)
				vol[iref].write_image(os.path.join(outdir, "volf_%s.hdf"%(itout)),-1)
				if (apix == 1):
					res_msg = "Models filtered at spatial frequency of:\t"
					res = fl
				else:
					res_msg = "Models filtered at resolution of:       \t"
					res = apix / fl	
				ares = array2string(array(res), precision = 2)
				print_msg("%s%s\n\n"%(res_msg,ares))	
			
			bcast_EMData_to_all(vol[iref], myid, main_node)
			# write out headers, under MPI writing has to be done sequentially
			mpi_barrier(MPI_COMM_WORLD)

	# projection matching	
	for N_step in xrange(lstp):
		terminate = 0
		Iter = -1
 		while(Iter < max_iter-1 and terminate == 0):
			Iter += 1
			total_iter += 1
			itout = "%03g_%02d" %(delta[N_step], Iter)
			if myid == main_node:
				print_msg("ITERATION #%3d, inner iteration #%3d\nDelta = %4.1f, an = %5.2f, xrange = %5.2f, yrange = %5.2f, step = %5.2f\n\n"%(N_step, Iter, delta[N_step], an[N_step], xrng[N_step],yrng[N_step],step[N_step]))
	
			for iref in xrange(nrefs):
				if myid == main_node: start_time = time()
				volft,kb = prep_vol( vol[iref] )

				## constrain projections to out of plane parameter
				theta1 = None
				theta2 = None
				if oplane is not None:
					theta1 = 90-oplane
					theta2 = 90+oplane
				refrings = prepare_refrings( volft, kb, nx, delta[N_step], ref_a, sym, numr, MPI=True, phiEqpsi = "Minus", initial_theta=theta1, delta_theta=theta2)
				
				del volft,kb

				if myid== main_node:
					print_msg( "Time to prepare projections for model %i: %s\n" % (iref, legibleTime(time()-start_time)) )
					start_time = time()
	
				for im in xrange( nima ):
					data[im].set_attr("xform.projection", transmulti[im][iref])
					if an[N_step] == -1:
						t1, peak, pixer[im] = proj_ali_incore(data[im],refrings,numr,xrng[N_step],yrng[N_step],step[N_step],finfo)
					else:
						t1, peak, pixer[im] = proj_ali_incore_local(data[im],refrings,numr,xrng[N_step],yrng[N_step],step[N_step],an[N_step],finfo)
					#data[im].set_attr("xform.projection"%iref, t1)
					if nrefs > 1: data[im].set_attr("eulers_txty.%i"%iref,t1)
					scoremulti[im][iref] = peak
					from pixel_error import max_3D_pixel_error
					# t1 is the current param, t2 is old
					t2 = transmulti[im][iref]
					pixelmulti[im][iref] = max_3D_pixel_error(t1,t2,numr[-3])
					transmulti[im][iref] = t1

				if myid == main_node:
					print_msg("Time of alignment for model %i: %s\n"%(iref, legibleTime(time()-start_time)))
					start_time = time()


			# gather scoring data from all processors
			from mpi import mpi_gatherv
			scoremultisend = sum(scoremulti,[])
			pixelmultisend = sum(pixelmulti,[])
			tmp = mpi_gatherv(scoremultisend,len(scoremultisend),MPI_FLOAT, recvcount_score, disps_score, MPI_FLOAT, main_node,MPI_COMM_WORLD)
			tmp1 = mpi_gatherv(pixelmultisend,len(pixelmultisend),MPI_FLOAT, recvcount_score, disps_score, MPI_FLOAT, main_node,MPI_COMM_WORLD)
			tmp = mpi_bcast(tmp,(total_nima * nrefs), MPI_FLOAT,0, MPI_COMM_WORLD)
			tmp1 = mpi_bcast(tmp1,(total_nima * nrefs), MPI_FLOAT,0, MPI_COMM_WORLD)
			tmp = map(float,tmp)
			tmp1 = map(float,tmp1)
			score = array(tmp).reshape(-1,nrefs)
			pixelerror = array(tmp1).reshape(-1,nrefs) 
			score_local = array(scoremulti)
			mean_score = score.mean(axis=0)
			std_score = score.std(axis=0)
			cut = mean_score - (cutoff * std_score)
			cut2 = mean_score + (cutoff * std_score)
			res_max = score_local.argmax(axis=1)
			minus_cc = [0.0 for x in xrange(nrefs)]
			minus_pix = [0.0 for x in xrange(nrefs)]
			minus_ref = [0.0 for x in xrange(nrefs)]
			
			#output pixel errors
			if(myid == main_node):
				from statistics import hist_list
				lhist = 20
				pixmin = pixelerror.min(axis=1)
				region, histo = hist_list(pixmin, lhist)
				if(region[0] < 0.0):  region[0] = 0.0
				print_msg("Histogram of pixel errors\n      ERROR       number of particles\n")
				for lhx in xrange(lhist):
					print_msg(" %10.3f     %7d\n"%(region[lhx], histo[lhx]))
				# Terminate if 95% within 1 pixel error
				im = 0
				for lhx in xrange(lhist):
					if(region[lhx] > 1.0): break
					im += histo[lhx]
				print_msg( "Percent of particles with pixel error < 1: %f\n\n"% (im/float(total_nima)*100))
				term_cond = float(term)/100
				if(im/float(total_nima) > term_cond): 
					terminate = 1
					print_msg("Terminating internal loop\n")
				del region, histo
			terminate = mpi_bcast(terminate, 1, MPI_INT, 0, MPI_COMM_WORLD)
			terminate = int(terminate[0])	
			
			for im in xrange(nima):
				if(sort==False):
					data[im].set_attr('group',999)
				elif (mjump[N_step]==1):
					data[im].set_attr('group',int(res_max[im]))
				
				pix_run = data[im].get_attr('pix_score')			
				if (pix_cutoff[N_step]==1 and (terminate==1 or Iter == max_iter-1)):
					if (pixelmulti[im][int(res_max[im])] > 1):
						data[im].set_attr('pix_score',int(777))

				if (score_local[im][int(res_max[im])]<cut[int(res_max[im])]) or (two_tail and score_local[im][int(res_max[im])]>cut2[int(res_max[im])]):
					data[im].set_attr('group',int(888))
					minus_cc[int(res_max[im])] = minus_cc[int(res_max[im])] + 1

				if(pix_run == 777):
					data[im].set_attr('group',int(777))
					minus_pix[int(res_max[im])] = minus_pix[int(res_max[im])] + 1

				if (compare_ref_free != "-1") and (ref_free_cutoff[N_step] != -1) and (total_iter > 1):
					id = data[im].get_attr('ID')
					if id in rejects:
						data[im].set_attr('group',int(666))
						minus_ref[int(res_max[im])] = minus_ref[int(res_max[im])] + 1	
						
				
			minus_cc_tot = mpi_reduce(minus_cc,nrefs,MPI_FLOAT,MPI_SUM,0,MPI_COMM_WORLD)	
			minus_pix_tot = mpi_reduce(minus_pix,nrefs,MPI_FLOAT,MPI_SUM,0,MPI_COMM_WORLD) 	
			minus_ref_tot = mpi_reduce(minus_ref,nrefs,MPI_FLOAT,MPI_SUM,0,MPI_COMM_WORLD)
			if (myid == main_node):
				if(sort):
					tot_max = score.argmax(axis=1)
					res = bincount(tot_max)
				else:
					res = ones(nrefs) * total_nima
				print_msg("Particle distribution:	     \t\t%s\n"%(res*1.0))
				afcut1 = res - minus_cc_tot
				afcut2 = afcut1 - minus_pix_tot
				afcut3 = afcut2 - minus_ref_tot
				print_msg("Particle distribution after cc cutoff:\t\t%s\n"%(afcut1))
				print_msg("Particle distribution after pix cutoff:\t\t%s\n"%(afcut2)) 
				print_msg("Particle distribution after ref cutoff:\t\t%s\n\n"%(afcut3)) 
					
						
			res = [0.0 for i in xrange(nrefs)]
			for iref in xrange(nrefs):
				if(center == -1):
					from utilities      import estimate_3D_center_MPI, rotate_3D_shift
					dummy=EMData()
					cs[0], cs[1], cs[2], dummy, dummy = estimate_3D_center_MPI(data, total_nima, myid, number_of_proc, main_node)				
					cs = mpi_bcast(cs, 3, MPI_FLOAT, main_node, MPI_COMM_WORLD)
					cs = [-float(cs[0]), -float(cs[1]), -float(cs[2])]
					rotate_3D_shift(data, cs)


				if(sort): 
					group = iref
					for im in xrange(nima):
						imgroup = data[im].get_attr('group')
						if imgroup == iref:
							data[im].set_attr('xform.projection',transmulti[im][iref])
				else: 
					group = int(999) 
					for im in xrange(nima):
						data[im].set_attr('xform.projection',transmulti[im][iref])
				if(nrefs == 1):
					modout = ""
				else:
					modout = "_model_%02d"%(iref)	
				
				fscfile = os.path.join(outdir, "fsc_%s%s"%(itout,modout))
				vol[iref], fscc, volodd[iref], voleve[iref] = rec3D_MPI_noCTF(data, sym, fscmask, fscfile, myid, main_node, index=group, npad=recon_pad)
	
				if myid == main_node:
					print_msg("3D reconstruction time for model %i: %s\n"%(iref, legibleTime(time()-start_time)))
					start_time = time()
	
				# Compute Fourier variance
				if fourvar:
					outvar = os.path.join(outdir, "volVar_%s.hdf"%(itout))
					ssnr_file = os.path.join(outdir, "ssnr_%s"%(itout))
					varf = varf3d_MPI(data, ssnr_text_file=ssnr_file, mask2D=None, reference_structure=vol[iref], ou=last_ring, rw=1.0, npad=1, CTF=None, sign=1, sym=sym, myid=myid)
					if myid == main_node:
						print_msg("Time to calculate 3D Fourier variance for model %i: %s\n"%(iref, legibleTime(time()-start_time)))
						start_time = time()
						varf = 1.0/varf
						varf.write_image(outvar,-1)
				else:  varf = None

				if myid == main_node:
					if helicalrecon:
						from hfunctions import processHelicalVol

						vstep=None
						if vertstep is not None:
							vstep=(vdp[iref],vdphi[iref])
						print_msg("Old rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
						hvals=processHelicalVol(vol[iref],voleve[iref],volodd[iref],iref,outdir,itout,
									dp[iref],dphi[iref],apix,hsearch,findseam,vstep,wcmask)
						(vol[iref],voleve[iref],volodd[iref],dp[iref],dphi[iref],vdp[iref],vdphi[iref])=hvals
						print_msg("New rise and twist for model %i     : %8.3f, %8.3f\n"%(iref,dp[iref],dphi[iref]))
						# get new FSC from symmetrized half volumes
						fscc = fsc_mask( volodd[iref], voleve[iref], mask3D, rstep, fscfile)

						print_msg("Time to search and apply helical symmetry for model %i: %s\n\n"%(iref, legibleTime(time()-start_time)))
						start_time = time()
					else:
						vol[iref].write_image(os.path.join(outdir, "vol_%s.hdf"%(itout)),-1)

					if save_half is True:
						volodd[iref].write_image(os.path.join(outdir, "volodd_%s.hdf"%(itout)),-1)
						voleve[iref].write_image(os.path.join(outdir, "voleve_%s.hdf"%(itout)),-1)

					if nmasks > 1:
						# Read mask for multiplying
						ref_data[0] = maskF[iref]
					ref_data[2] = vol[iref]
					ref_data[3] = fscc
					ref_data[4] = varf
					#  call user-supplied function to prepare reference image, i.e., center and filter it
					vol[iref], cs,fl = ref_ali3d(ref_data)
					vol[iref].write_image(os.path.join(outdir, "volf_%s.hdf"%(itout)),-1)
					if (apix == 1):
						res_msg = "Models filtered at spatial frequency of:\t"
						res[iref] = fl
					else:
						res_msg = "Models filtered at resolution of:       \t"
						res[iref] = apix / fl	
	
				del varf
				bcast_EMData_to_all(vol[iref], myid, main_node)
				
				if compare_ref_free != "-1": compare_repro = True
				if compare_repro:
					outfile_repro = comp_rep(refrings, data, itout, modout, vol[iref], group, nima, nx, myid, main_node, outdir)
					mpi_barrier(MPI_COMM_WORLD)
					if compare_ref_free != "-1":
						ref_free_output = os.path.join(outdir,"ref_free_%s%s"%(itout,modout))
						rejects = compare(compare_ref_free, outfile_repro,ref_free_output,yrng[N_step], xrng[N_step], rstep,nx,apix,ref_free_cutoff[N_step], number_of_proc, myid, main_node)

			# retrieve alignment params from all processors
			par_str = ['xform.projection','ID','group']
			if nrefs > 1:
				for iref in xrange(nrefs):
					par_str.append('eulers_txty.%i'%iref)

			if myid == main_node:
				from utilities import recv_attr_dict
				recv_attr_dict(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
				
			else:	send_attr_dict(main_node, data, par_str, image_start, image_end)

			if myid == main_node:
				ares = array2string(array(res), precision = 2)
				print_msg("%s%s\n\n"%(res_msg,ares))
				dummy = EMData()
				if full_output:
					nimat = EMUtil.get_image_count(stack)
					output_file = os.path.join(outdir, "paramout_%s"%itout)
					foutput = open(output_file, 'w')
					for im in xrange(nimat):
						# save the parameters for each of the models
						outstring = ""
						dummy.read_image(stack,im,True)
						param3d = dummy.get_attr('xform.projection')
						g = dummy.get_attr("group")
						# retrieve alignments in EMAN-format
						pE = param3d.get_params('eman')
						outstring += "%f\t%f\t%f\t%f\t%f\t%i\n" %(pE["az"], pE["alt"], pE["phi"], pE["tx"], pE["ty"],g)
						foutput.write(outstring)
					foutput.close()
				del dummy
			mpi_barrier(MPI_COMM_WORLD)


#	mpi_finalize()	

	if myid == main_node: print_end_msg("ali3d_MPI")
Beispiel #29
0
def main():

	from logger import Logger, BaseLogger_Files
	import user_functions
	from optparse import OptionParser, SUPPRESS_HELP
	from global_def import SPARXVERSION
	from EMAN2 import EMData

	main_node = 0
	mpi_init(0, [])
	mpi_comm = MPI_COMM_WORLD
	myid = mpi_comm_rank(MPI_COMM_WORLD)
	mpi_size = mpi_comm_size(MPI_COMM_WORLD)	# Total number of processes, passed by --np option.

	# mpi_barrier(mpi_comm)
	# from mpi import mpi_finalize
	# mpi_finalize()
	# print "mpi finalize"
	# from sys import exit
	# exit()

	progname = os.path.basename(sys.argv[0])
	usage = progname + " stack  [output_directory] --ir=inner_radius --radius=outer_radius --rs=ring_step --xr=x_range --yr=y_range  --ts=translational_search_step  --delta=angular_step --an=angular_neighborhood  --center=center_type --maxit1=max_iter1 --maxit2=max_iter2 --L2threshold=0.1  --fl --aa --ref_a=S --sym=c1"
	usage += """

stack			2D images in a stack file: (default required string)
output_directory: directory name into which the output files will be written.  If it does not exist, the directory will be created.  If it does exist, the program will continue executing from where it stopped (if it did not already reach the end). The "--use_latest_master_directory" option can be used to choose the most recent directory that starts with "master".
"""

	parser = OptionParser(usage,version=SPARXVERSION)
	parser.add_option("--radius",                type="int",           help="radius of the particle: has to be less than < int(nx/2)-1 (default required int)")

	parser.add_option("--ir",                    type="int",           default=1,          help="inner radius for rotational search: > 0 (default 1)")
	parser.add_option("--rs",                    type="int",           default=1,          help="step between rings in rotational search: >0 (default 1)")
	parser.add_option("--xr",                    type="string",        default='0',        help="range for translation search in x direction: search is +/xr in pixels (default '0')")
	parser.add_option("--yr",                    type="string",        default='0',        help="range for translation search in y direction: if omitted will be set to xr, search is +/yr in pixels (default '0')")
	parser.add_option("--ts",                    type="string",        default='1.0',      help="step size of the translation search in x-y directions: search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional (default '1.0')")
	parser.add_option("--delta",                 type="string",        default='2.0',      help="angular step of reference projections: (default '2.0')")
	#parser.add_option("--an",       type="string", default= "-1",              help="angular neighborhood for local searches (phi and theta)")
	parser.add_option("--center",                type="float",         default=-1.0,       help="centering of 3D template: average shift method; 0: no centering; 1: center of gravity (default -1.0)")
	parser.add_option("--maxit1",                type="int",           default=400,        help="maximum number of iterations performed for the GA part: (default 400)")
	parser.add_option("--maxit2",                type="int",           default=50,         help="maximum number of iterations performed for the finishing up part: (default 50)")
	parser.add_option("--L2threshold",           type="float",         default=0.03,       help="stopping criterion of GA: given as a maximum relative dispersion of volumes' L2 norms: (default 0.03)")
	parser.add_option("--doga",                  type="float",         default=0.1,        help="do GA when fraction of orientation changes less than 1.0 degrees is at least doga: (default 0.1)")
	parser.add_option("--n_shc_runs",            type="int",           default=4,          help="number of quasi-independent shc runs (same as '--nruns' parameter from sxviper.py): (default 4)")
	parser.add_option("--n_rv_runs",             type="int",           default=10,         help="number of rviper iterations: (default 10)")
	parser.add_option("--n_v_runs",              type="int",           default=3,          help="number of viper runs for each r_viper cycle: (default 3)")
	parser.add_option("--outlier_percentile",    type="float",         default=95.0,       help="percentile above which outliers are removed every rviper iteration: (default 95.0)")
	parser.add_option("--iteration_start",       type="int",           default=0,          help="starting iteration for rviper: 0 means go to the most recent one (default 0)")
	#parser.add_option("--CTF",      action="store_true", default=False,        help="NOT IMPLEMENTED Consider CTF correction during the alignment ")
	#parser.add_option("--snr",      type="float",  default= 1.0,               help="Signal-to-Noise Ratio of the data (default 1.0)")
	parser.add_option("--ref_a",                 type="string",        default='S',        help="method for generating the quasi-uniformly distributed projection directions: (default S)")
	parser.add_option("--sym",                   type="string",        default='c1',       help="point-group symmetry of the structure: (default c1)")
	# parser.add_option("--function", type="string", default="ref_ali3d",         help="name of the reference preparation function (ref_ali3d by default)")
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--function", type="string", default="ref_ali3d",         help=SUPPRESS_HELP)
	parser.add_option("--npad",                  type="int",           default=2,          help="padding size for 3D reconstruction: (default 2)")
	# parser.add_option("--npad", type="int",  default= 2,            help="padding size for 3D reconstruction (default 2)")

	#options introduced for the do_volume function
	parser.add_option("--fl",                    type="float",         default=0.25,       help="cut-off frequency applied to the template volume: using a hyperbolic tangent low-pass filter (default 0.25)")
	parser.add_option("--aa",                    type="float",         default=0.1,        help="fall-off of hyperbolic tangent low-pass filter: (default 0.1)")
	parser.add_option("--pwreference",           type="string",        default='',         help="text file with a reference power spectrum: (default none)")
	parser.add_option("--mask3D",                type="string",        default=None,       help="3D mask file: (default sphere)")
	parser.add_option("--moon_elimination",      type="string",        default='',         help="elimination of disconnected pieces: two arguments: mass in KDa and pixel size in px/A separated by comma, no space (default none)")

	# used for debugging, help is supressed with SUPPRESS_HELP
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--my_random_seed",      type="int",  default=123,  help = SUPPRESS_HELP)
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--run_get_already_processed_viper_runs", action="store_true", dest="run_get_already_processed_viper_runs", default=False, help = SUPPRESS_HELP)
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--use_latest_master_directory", action="store_true", dest="use_latest_master_directory", default=False, help = SUPPRESS_HELP)
	
	parser.add_option("--criterion_name",        type="string",        default='80th percentile',help="criterion deciding if volumes have a core set of stable projections: '80th percentile', other options:'fastest increase in the last quartile' (default '80th percentile')")
	parser.add_option("--outlier_index_threshold_method",type="string",        default='discontinuity_in_derivative',help="method that decides which images to keep: discontinuity_in_derivative, other options:percentile, angle_measure (default discontinuity_in_derivative)")
	parser.add_option("--angle_threshold",       type="int",           default=30,         help="angle threshold for projection removal if using 'angle_measure': (default 30)")
	

	required_option_list = ['radius']
	(options, args) = parser.parse_args(sys.argv[1:])

	options.CTF = False
	options.snr = 1.0
	options.an = -1

	if options.moon_elimination == "":
		options.moon_elimination = []
	else:
		options.moon_elimination = map(float, options.moon_elimination.split(","))

	# Making sure all required options appeared.
	for required_option in required_option_list:
		if not options.__dict__[required_option]:
			print "\n ==%s== mandatory option is missing.\n"%required_option
			print "Please run '" + progname + " -h' for detailed options"
			return 1

	mpi_barrier(MPI_COMM_WORLD)
	if(myid == main_node):
		print "****************************************************************"
		Util.version()
		print "****************************************************************"
		sys.stdout.flush()
	mpi_barrier(MPI_COMM_WORLD)

	# this is just for benefiting from a user friendly parameter name
	options.ou = options.radius 
	my_random_seed = options.my_random_seed
	criterion_name = options.criterion_name
	outlier_index_threshold_method = options.outlier_index_threshold_method
	use_latest_master_directory = options.use_latest_master_directory
	iteration_start_default = options.iteration_start
	number_of_rrr_viper_runs = options.n_rv_runs
	no_of_viper_runs_analyzed_together_from_user_options = options.n_v_runs
	no_of_shc_runs_analyzed_together = options.n_shc_runs 
	outlier_percentile = options.outlier_percentile 
	angle_threshold = options.angle_threshold 
	
	run_get_already_processed_viper_runs = options.run_get_already_processed_viper_runs
	get_already_processed_viper_runs(run_get_already_processed_viper_runs)

	import random
	random.seed(my_random_seed)

	if len(args) < 1 or len(args) > 3:
		print "usage: " + usage
		print "Please run '" + progname + " -h' for detailed options"
		return 1

	# if len(args) > 2:
	# 	ref_vol = get_im(args[2])
	# else:
	ref_vol = None
	
	# error_status = None
	# if myid == 0:
	# 	number_of_images = EMUtil.get_image_count(args[0])
	# 	if mpi_size > number_of_images:
	# 		error_status = ('Number of processes supplied by --np in mpirun needs to be less than or equal to %d (total number of images) ' % number_of_images, getframeinfo(currentframe()))
	# if_error_then_all_processes_exit_program(error_status)
	
	bdb_stack_location = ""

	masterdir = ""
	if len(args) == 2:
		masterdir = args[1]
		if masterdir[-1] != DIR_DELIM:
			masterdir += DIR_DELIM
	elif len(args) == 1:
		if use_latest_master_directory:
			all_dirs = [d for d in os.listdir(".") if os.path.isdir(d)]
			import re; r = re.compile("^master.*$")
			all_dirs = filter(r.match, all_dirs)
			if len(all_dirs)>0:
				# all_dirs = max(all_dirs, key=os.path.getctime)
				masterdir = max(all_dirs, key=os.path.getmtime)
				masterdir += DIR_DELIM

	log = Logger(BaseLogger_Files())

	error_status = 0	
	if mpi_size % no_of_shc_runs_analyzed_together != 0:
		ERROR('Number of processes needs to be a multiple of the number of quasi-independent runs (shc) within each viper run. '
		'Total quasi-independent runs by default are 3, you can change it by specifying '
		'--n_shc_runs option (in sxviper this option is called --nruns). Also, to improve communication time it is recommended that '
		'the number of processes divided by the number of quasi-independent runs is a power '
		'of 2 (e.g. 2, 4, 8 or 16 depending on how many physical cores each node has).', 'sxviper', 1)
		error_status = 1
	if_error_then_all_processes_exit_program(error_status)

	#Create folder for all results or check if there is one created already
	if(myid == main_node):
		#cmd = "{}".format("Rmycounter ccc")
		#cmdexecute(cmd)

		if( masterdir == ""):
			timestring = strftime("%Y_%m_%d__%H_%M_%S" + DIR_DELIM, localtime())
			masterdir = "master"+timestring

		if not os.path.exists(masterdir):
			cmd = "{} {}".format("mkdir", masterdir)
			cmdexecute(cmd)

		if ':' in args[0]:
			bdb_stack_location = args[0].split(":")[0] + ":" + masterdir + args[0].split(":")[1]
			org_stack_location = args[0]

			if(not os.path.exists(os.path.join(masterdir,"EMAN2DB" + DIR_DELIM))):
				# cmd = "{} {}".format("cp -rp EMAN2DB", masterdir, "EMAN2DB" DIR_DELIM)
				# cmdexecute(cmd)
				cmd = "{} {} {}".format("e2bdb.py", org_stack_location,"--makevstack=" + bdb_stack_location + "_000")
				cmdexecute(cmd)

				from applications import header
				try:
					header(bdb_stack_location + "_000", params='original_image_index', fprint=True)
					print "Images were already indexed!"
				except KeyError:
					print "Indexing images"
					header(bdb_stack_location + "_000", params='original_image_index', consecutive=True)
		else:
			filename = os.path.basename(args[0])
			bdb_stack_location = "bdb:" + masterdir + os.path.splitext(filename)[0]
			if(not os.path.exists(os.path.join(masterdir,"EMAN2DB" + DIR_DELIM))):
				cmd = "{} {} {}".format("sxcpy.py  ", args[0], bdb_stack_location + "_000")
				cmdexecute(cmd)

				from applications import header
				try:
					header(bdb_stack_location + "_000", params='original_image_index', fprint=True)
					print "Images were already indexed!"
				except KeyError:
					print "Indexing images"
					header(bdb_stack_location + "_000", params='original_image_index', consecutive=True)

	# send masterdir to all processes
	dir_len  = len(masterdir)*int(myid == main_node)
	dir_len = mpi_bcast(dir_len,1,MPI_INT,0,MPI_COMM_WORLD)[0]
	masterdir = mpi_bcast(masterdir,dir_len,MPI_CHAR,main_node,MPI_COMM_WORLD)
	masterdir = string.join(masterdir,"")
	if masterdir[-1] != DIR_DELIM:
		masterdir += DIR_DELIM
		
	global_def.LOGFILE =  os.path.join(masterdir, global_def.LOGFILE)
	print_program_start_information()
	

	# mpi_barrier(mpi_comm)
	# from mpi import mpi_finalize
	# mpi_finalize()
	# print "mpi finalize"
	# from sys import exit
	# exit()
		
	
	# send bdb_stack_location to all processes
	dir_len  = len(bdb_stack_location)*int(myid == main_node)
	dir_len = mpi_bcast(dir_len,1,MPI_INT,0,MPI_COMM_WORLD)[0]
	bdb_stack_location = mpi_bcast(bdb_stack_location,dir_len,MPI_CHAR,main_node,MPI_COMM_WORLD)
	bdb_stack_location = string.join(bdb_stack_location,"")

	iteration_start = get_latest_directory_increment_value(masterdir, "main")

	if (myid == main_node):
		if (iteration_start < iteration_start_default):
			ERROR('Starting iteration provided is greater than last iteration performed. Quiting program', 'sxviper', 1)
			error_status = 1
	if iteration_start_default!=0:
		iteration_start = iteration_start_default
	if (myid == main_node):
		if (number_of_rrr_viper_runs < iteration_start):
			ERROR('Please provide number of rviper runs (--n_rv_runs) greater than number of iterations already performed.', 'sxviper', 1)
			error_status = 1

	if_error_then_all_processes_exit_program(error_status)

	for rviper_iter in range(iteration_start, number_of_rrr_viper_runs + 1):
		if(myid == main_node):
			all_projs = EMData.read_images(bdb_stack_location + "_%03d"%(rviper_iter - 1))
			print "XXXXXXXXXXXXXXXXX"
			print "Number of projections (in loop): " + str(len(all_projs))
			print "XXXXXXXXXXXXXXXXX"
			subset = range(len(all_projs))
		else:
			all_projs = None
			subset = None

		runs_iter = get_latest_directory_increment_value(masterdir + NAME_OF_MAIN_DIR + "%03d"%rviper_iter, DIR_DELIM + NAME_OF_RUN_DIR, start_value=0) - 1
		no_of_viper_runs_analyzed_together = max(runs_iter + 2, no_of_viper_runs_analyzed_together_from_user_options)

		first_time_entering_the_loop_need_to_do_full_check_up = True
		while True:
			runs_iter += 1

			if not first_time_entering_the_loop_need_to_do_full_check_up:
				if runs_iter >= no_of_viper_runs_analyzed_together:
					break
			first_time_entering_the_loop_need_to_do_full_check_up = False

			this_run_is_NOT_complete = 0
			if (myid == main_node):
				independent_run_dir = masterdir + DIR_DELIM + NAME_OF_MAIN_DIR + ('%03d' + DIR_DELIM + NAME_OF_RUN_DIR + "%03d" + DIR_DELIM)%(rviper_iter, runs_iter)
				if run_get_already_processed_viper_runs:
					cmd = "{} {}".format("mkdir -p", masterdir + DIR_DELIM + NAME_OF_MAIN_DIR + ('%03d' + DIR_DELIM)%(rviper_iter)); cmdexecute(cmd)
					cmd = "{} {}".format("rm -rf", independent_run_dir); cmdexecute(cmd)
					cmd = "{} {}".format("cp -r", get_already_processed_viper_runs() + " " +  independent_run_dir); cmdexecute(cmd)

				if os.path.exists(independent_run_dir + "log.txt") and (string_found_in_file("Finish VIPER2", independent_run_dir + "log.txt")):
					this_run_is_NOT_complete = 0
				else:
					this_run_is_NOT_complete = 1
					cmd = "{} {}".format("rm -rf", independent_run_dir); cmdexecute(cmd)
					cmd = "{} {}".format("mkdir -p", independent_run_dir); cmdexecute(cmd)

				this_run_is_NOT_complete = mpi_bcast(this_run_is_NOT_complete,1,MPI_INT,main_node,MPI_COMM_WORLD)[0]
				dir_len = len(independent_run_dir)
				dir_len = mpi_bcast(dir_len,1,MPI_INT,main_node,MPI_COMM_WORLD)[0]
				independent_run_dir = mpi_bcast(independent_run_dir,dir_len,MPI_CHAR,main_node,MPI_COMM_WORLD)
				independent_run_dir = string.join(independent_run_dir,"")
			else:
				this_run_is_NOT_complete = mpi_bcast(this_run_is_NOT_complete,1,MPI_INT,main_node,MPI_COMM_WORLD)[0]
				dir_len = 0
				independent_run_dir = ""
				dir_len = mpi_bcast(dir_len,1,MPI_INT,main_node,MPI_COMM_WORLD)[0]
				independent_run_dir = mpi_bcast(independent_run_dir,dir_len,MPI_CHAR,main_node,MPI_COMM_WORLD)
				independent_run_dir = string.join(independent_run_dir,"")

			if this_run_is_NOT_complete:
				mpi_barrier(MPI_COMM_WORLD)

				if independent_run_dir[-1] != DIR_DELIM:
					independent_run_dir += DIR_DELIM

				log.prefix = independent_run_dir

				options.user_func = user_functions.factory[options.function]

				# for debugging purposes
				#if (myid == main_node):
					#cmd = "{} {}".format("cp ~/log.txt ", independent_run_dir)
					#cmdexecute(cmd)
					#cmd = "{} {}{}".format("cp ~/paramdir/params$(mycounter ccc).txt ", independent_run_dir, "param%03d.txt"%runs_iter)
					#cmd = "{} {}{}".format("cp ~/paramdir/params$(mycounter ccc).txt ", independent_run_dir, "params.txt")
					#cmdexecute(cmd)

				if (myid == main_node):
					store_value_of_simple_vars_in_json_file(masterdir + 'program_state_stack.json', locals(), exclude_list_of_vars=["usage"], 
						vars_that_will_show_only_size = ["subset"])
					store_value_of_simple_vars_in_json_file(masterdir + 'program_state_stack.json', options.__dict__, write_or_append='a')

				# mpi_barrier(mpi_comm)
				# from mpi import mpi_finalize
				# mpi_finalize()
				# print "mpi finalize"
				# from sys import exit
				# exit()

				out_params, out_vol, out_peaks = multi_shc(all_projs, subset, no_of_shc_runs_analyzed_together, options,
				mpi_comm=mpi_comm, log=log, ref_vol=ref_vol)

				# end of: if this_run_is_NOT_complete:

			if runs_iter >= (no_of_viper_runs_analyzed_together_from_user_options - 1):
				increment_for_current_iteration = identify_outliers(myid, main_node, rviper_iter,
				no_of_viper_runs_analyzed_together, no_of_viper_runs_analyzed_together_from_user_options, masterdir,
				bdb_stack_location, outlier_percentile, criterion_name, outlier_index_threshold_method, angle_threshold)

				if increment_for_current_iteration == MUST_END_PROGRAM_THIS_ITERATION:
					break

				no_of_viper_runs_analyzed_together += increment_for_current_iteration

		# end of independent viper loop

		calculate_volumes_after_rotation_and_save_them(options, rviper_iter, masterdir, bdb_stack_location, myid,
		mpi_size, no_of_viper_runs_analyzed_together, no_of_viper_runs_analyzed_together_from_user_options)

		if increment_for_current_iteration == MUST_END_PROGRAM_THIS_ITERATION:
			if (myid == main_node):
				print "RVIPER found a core set of stable projections for the current RVIPER iteration (%d), the maximum angle difference between corresponding projections from different VIPER volumes is less than %.2f. Finishing."%(rviper_iter, ANGLE_ERROR_THRESHOLD)
			break
	else:
		if (myid == main_node):
			print "After running the last iteration (%d), RVIPER did not find a set of projections with the maximum angle difference between corresponding projections from different VIPER volumes less than %.2f Finishing."%(rviper_iter, ANGLE_ERROR_THRESHOLD)
		
			
	# end of RVIPER loop

	#mpi_finalize()
	#sys.exit()

	mpi_barrier(MPI_COMM_WORLD)
	mpi_finalize()
Beispiel #30
0
def run():
    """Command-line entry point for local real-space FSC resolution estimation.

    Positional arguments (parsed with optparse):
        firstvolume secondvolume [maskfile] directory

    With three positional arguments there is no mask file. A spherical mask
    is built with ``sp_utilities.model_circle`` of radius ``--radius`` when it
    is > 0, and otherwise of radius (min(nx, ny, nz) - wn) / 2. With four
    positional arguments, the third is read and binarized at 0.5 to form the
    mask. The last positional argument is the output directory, which must
    not exist yet.

    With ``--MPI`` the main node reads the volumes and broadcasts their
    dimensions. The computation is delegated to ``sp_statistics.locres`` and
    the main node writes the result via ``output_volume``. Without ``--MPI``
    the volumes are assumed cubic (only the x size is read). A serial
    per-shell loop then fills a frequency volume, which is also written via
    ``output_volume``.

    Returns None. An invalid number of positional arguments, or an existing
    output directory, is reported through ``sp_global_def.ERROR``.
    """
    arglist = list(sys.argv)
    progname = optparse.os.path.basename(arglist[0])
    usage = (
        progname +
        """ firstvolume  secondvolume  maskfile  directory  --prefix  --wn  --step  --cutoff  --radius  --fsc  --res_overall  --out_ang_res  --apix  --MPI

	Compute local resolution in real space within area outlined by the maskfile and within regions wn x wn x wn
	""")
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option(
        "--prefix",
        type="str",
        default="localres",
        help="Prefix for the output files. (default localres)",
    )
    parser.add_option(
        "--wn",
        type="int",
        default=7,
        help=
        "Size of window within which local real-space FSC is computed. (default 7)",
    )
    parser.add_option(
        "--step",
        type="float",
        default=1.0,
        help="Shell step in Fourier size in pixels. (default 1.0)",
    )
    parser.add_option(
        "--cutoff",
        type="float",
        default=0.143,
        help="Resolution cut-off for FSC. (default 0.143)",
    )
    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "If there is no maskfile, sphere with r=radius will be used. By default, the radius is (min(nx,ny,nz)-wn)/2 (default -1)",
    )
    parser.add_option(
        "--fsc",
        type="string",
        default=None,
        help=
        "Save overall FSC curve (might be truncated). By default, the program does not save the FSC curve. (default none)",
    )
    parser.add_option(
        "--res_overall",
        type="float",
        default=-1.0,
        help=
        "Overall resolution at the cutoff level estimated by the user [abs units]. (default None)",
    )
    parser.add_option(
        "--out_ang_res",
        action="store_true",
        default=False,
        help=
        "Additionally creates a local resolution file in Angstroms. (default False)",
    )
    parser.add_option(
        "--apix",
        type="float",
        default=1.0,
        help=
        "Pixel size in Angstrom. Effective only with --out_ang_res options. (default 1.0)",
    )
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="Use MPI version.")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        sp_global_def.sxprint("Usage: " + usage)
        sp_global_def.ERROR(
            "Invalid number of parameters used. Please see usage information above."
        )
        return

    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    res_overall = options.res_overall

    if options.MPI:

        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        main_node = 0
        sp_global_def.MPI = True
        cutoff = options.cutoff

        nk = int(options.wn)

        if myid == main_node:
            # Use the parsed positional arguments rather than raw sys.argv:
            # optparse allows options to precede the volume file names.
            vi = sp_utilities.get_im(args[0])
            ui = sp_utilities.get_im(args[1])

            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            # Placeholder with the same length as the list sent by the main node.
            dis = [0, 0, 0]

        sp_global_def.BATCH = True

        dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])

            # Blank volumes of the right size on non-main ranks.
            vi = sp_utilities.model_blank(nx, ny, nz)
            ui = sp_utilities.model_blank(nx, ny, nz)

        if len(args) == 3:
            # No mask file: spherical mask, honoring --radius when given.
            if options.radius > 0:
                radius = options.radius
            else:
                radius = old_div((min(nx, ny, nz) - nk), 2)
            m = sp_utilities.model_circle(radius, nx, ny, nz)
            outdir = args[2]
        else:
            # len(args) == 4: only the main node reads the mask; it is broadcast below.
            if myid == main_node:
                m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            else:
                m = sp_utilities.model_blank(nx, ny, nz)
            outdir = args[3]

        if optparse.os.path.exists(outdir) and myid == 0:
            sp_global_def.ERROR("Output directory already exists!")
        elif myid == 0:
            optparse.os.makedirs(outdir)
        sp_global_def.write_command(outdir)
        sp_utilities.bcast_EMData_to_all(m, myid, main_node)

        freqvol, resolut = sp_statistics.locres(vi, ui, m, nk, cutoff,
                                                options.step, myid, main_node,
                                                number_of_proc)

        if myid == 0:
            # Remove outliers based on the Interquartile range
            output_volume(
                freqvol,
                resolut,
                options.apix,
                outdir,
                options.prefix,
                options.fsc,
                options.out_ang_res,
                nx,
                ny,
                nz,
                res_overall,
            )

    else:
        cutoff = options.cutoff
        vi = sp_utilities.get_im(args[0])
        ui = sp_utilities.get_im(args[1])

        # The serial path assumes cubic volumes: only the x size is used.
        nn = vi.get_xsize()
        nx = nn
        ny = nn
        nz = nn
        nk = int(options.wn)

        if len(args) == 3:
            # No mask file: spherical mask, honoring --radius when given.
            if options.radius > 0:
                radius = options.radius
            else:
                radius = old_div((nn - nk), 2)
            m = sp_utilities.model_circle(radius, nn, nn, nn)
            outdir = args[2]
        else:
            # len(args) == 4
            m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            outdir = args[3]

        if optparse.os.path.exists(outdir):
            sp_global_def.ERROR("Output directory already exists!")
        else:
            optparse.os.makedirs(outdir)
        sp_global_def.write_command(outdir)

        # Complement of the mask (1 - m); added to numerator and denominator so
        # that the division below does not operate on zeros outside the mask.
        mc = sp_utilities.model_blank(nn, nn, nn, 1.0) - m

        vf = sp_fundamentals.fft(vi)
        uf = sp_fundamentals.fft(ui)

        # Number of Fourier shells and their width in absolute frequency.
        lp = int(old_div(old_div(nn, 2), options.step) + 0.5)
        step = old_div(0.5, lp)

        freqvol = sp_utilities.model_blank(nn, nn, nn)
        resolut = []
        for i in range(1, lp):
            fl = step * i
            fh = fl + step
            # Band-pass both half-maps to shell [fl, fh].
            v = sp_fundamentals.fft(sp_filter.filt_tophatb(vf, fl, fh))
            u = sp_fundamentals.fft(sp_filter.filt_tophatb(uf, fl, fh))
            tmp1 = EMAN2_cppwrap.Util.muln_img(v, v)
            tmp2 = EMAN2_cppwrap.Util.muln_img(u, u)

            # Global correlation inside the mask for this shell.
            do = EMAN2_cppwrap.Util.infomask(
                sp_morphology.square_root(
                    sp_morphology.threshold(
                        EMAN2_cppwrap.Util.muln_img(tmp1, tmp2))),
                m,
                True,
            )[0]

            tmp3 = EMAN2_cppwrap.Util.muln_img(u, v)
            dp = EMAN2_cppwrap.Util.infomask(tmp3, m, True)[0]
            resolut.append([i, old_div((fl + fh), 2.0), old_div(dp, do)])

            # Local correlation: sum products over a wn x wn x wn box around each voxel.
            tmp1 = EMAN2_cppwrap.Util.box_convolution(tmp1, nk)
            tmp2 = EMAN2_cppwrap.Util.box_convolution(tmp2, nk)
            tmp3 = EMAN2_cppwrap.Util.box_convolution(tmp3, nk)

            EMAN2_cppwrap.Util.mul_img(tmp1, tmp2)

            tmp1 = sp_morphology.square_root(sp_morphology.threshold(tmp1))

            EMAN2_cppwrap.Util.mul_img(tmp1, m)
            EMAN2_cppwrap.Util.add_img(tmp1, mc)

            EMAN2_cppwrap.Util.mul_img(tmp3, m)
            EMAN2_cppwrap.Util.add_img(tmp3, mc)

            EMAN2_cppwrap.Util.div_img(tmp3, tmp1)

            EMAN2_cppwrap.Util.mul_img(tmp3, m)
            freq = old_div((fl + fh), 2.0)

            # Assign this shell's frequency to masked voxels whose local correlation
            # first drops below the cutoff. Stop once no masked voxel is still unassigned.
            bailout = True
            for x in range(nn):
                for y in range(nn):
                    for z in range(nn):
                        if (m.get_value_at(x, y, z) > 0.5
                                and freqvol.get_value_at(x, y, z) == 0.0):
                            bailout = False
                            if tmp3.get_value_at(x, y, z) < cutoff:
                                freqvol.set_value_at(x, y, z, freq)
            if bailout:
                break
        # remove outliers
        output_volume(
            freqvol,
            resolut,
            options.apix,
            outdir,
            options.prefix,
            options.fsc,
            options.out_ang_res,
            nx,
            ny,
            nz,
            res_overall,
        )
Beispiel #31
0
def main():
    """Command-line entry point: compute a local-resolution volume in real space.

    Positional arguments: ``firstvolume secondvolume [maskfile] outputfile``.
    With three positional arguments a spherical mask is generated (radius from
    ``--radius`` when positive, otherwise ``(size - wn) // 2``); with four, the
    mask file is read and binarized at 0.5.  Any other count prints the usage
    and exits.

    Serial mode: for every Fourier shell both volumes are band-pass filtered
    (``filt_tophatb``), products of the filtered maps are box-convolved over
    ``wn``-sized windows, and each still-unassigned masked voxel receives the
    shell's centre frequency once its local correlation drops below
    ``--cutoff``.  With ``--MPI`` that per-shell work is delegated to
    ``statistics.locres`` and only rank 0 writes output.

    Outputs: the frequency volume to ``outputfile``; with ``--out_ang_res``
    also ``<outputfile stem>_ang.hdf``; with ``--fsc`` the per-shell table as
    text.  When ``--res_overall`` is given, the volume is shifted by
    ``res_overall`` minus the first ``Util.infomask`` statistic over the mask
    (presumably the masked mean), and table entries from the first shell above
    ``res_overall`` onward get their FSC value set to 0.
    """
    import os
    import sys
    from optparse import OptionParser

    def _zero_fsc_beyond(resolut, res_overall, freq_col, fsc_col):
        """Set ``row[fsc_col] = 0.0`` for every row of ``resolut`` starting at
        the first row whose ``row[freq_col]`` exceeds ``res_overall``.

        If no row exceeds ``res_overall`` (or ``resolut`` is empty) nothing is
        changed; rows are modified in place.
        """
        first = next(
            (k for k, row in enumerate(resolut) if row[freq_col] > res_overall),
            len(resolut),
        )
        for row in resolut[first:]:
            row[fsc_col] = 0.0

    arglist = list(sys.argv)
    progname = os.path.basename(arglist[0])
    usage = progname + (
        " firstvolume  secondvolume  maskfile  outputfile  --wn  --step  --cutoff"
        "  --radius  --fsc  --res_overall  --out_ang_res  --apix  --MPI\n"
        "\n"
        "\tCompute local resolution in real space within area outlined by the"
        " maskfile and within regions wn x wn x wn\n"
        "\t"
    )
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option(
        "--wn",
        type="int",
        default=7,
        help=
        "Size of window within which local real-space FSC is computed. (default 7)"
    )
    parser.add_option(
        "--step",
        type="float",
        default=1.0,
        help="Shell step in Fourier size in pixels. (default 1.0)")
    parser.add_option("--cutoff",
                      type="float",
                      default=0.5,
                      help="Resolution cut-off for FSC. (default 0.5)")
    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "If there is no maskfile, sphere with r=radius will be used. By default, the radius is (nx-wn)//2 (default -1)"
    )
    parser.add_option(
        "--fsc",
        type="string",
        default=None,
        help=
        "Save overall FSC curve (might be truncated). By default, the program does not save the FSC curve. (default none)"
    )
    parser.add_option(
        "--res_overall",
        type="float",
        default=-1.0,
        help=
        "Overall resolution at the cutoff level estimated by the user [abs units]. (default -1.0, no adjustment)"
    )
    parser.add_option(
        "--out_ang_res",
        action="store_true",
        default=False,
        help=
        "Additionally creates a local resolution file in Angstroms. (default False)"
    )
    parser.add_option(
        "--apix",
        type="float",
        default=1.0,
        help=
        "Pixel size in Angstrom. Effective only with --out_ang_res options. (default 1.0)"
    )
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="Use MPI version.")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        print("See usage " + usage)
        sys.exit()

    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache
        disable_bdb_cache()

    # -1.0 means "no user-supplied overall resolution": skip the adjustment.
    res_overall = options.res_overall
    cutoff = options.cutoff
    nk = int(options.wn)

    if options.MPI:
        from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, mpi_finalize, MPI_COMM_WORLD
        sys.argv = mpi_init(len(sys.argv), sys.argv)

        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        main_node = 0
        global_def.MPI = True

        if myid == main_node:
            # Use the parsed positionals (as the serial branch and the mask /
            # output handling below do) rather than raw sys.argv, whose
            # indices shift when options precede the file names.
            vi = get_im(args[0])
            ui = get_im(args[1])

            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            dis = [0, 0, 0, 0]

        global_def.BATCH = True

        # Share the volume dimensions so the other ranks can allocate blanks.
        dis = bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])

            vi = model_blank(nx, ny, nz)
            ui = model_blank(nx, ny, nz)

        if len(args) == 3:
            if options.radius > 0:
                radius = options.radius
            else:
                radius = (min(nx, ny, nz) - nk) // 2
            m = model_circle(radius, nx, ny, nz)
            outvol = args[2]
        else:  # len(args) == 4 (validated above)
            if myid == main_node:
                m = binarize(get_im(args[2]), 0.5)
            else:
                m = model_blank(nx, ny, nz)
            outvol = args[3]
        bcast_EMData_to_all(m, myid, main_node)

        from statistics import locres
        freqvol, resolut = locres(vi, ui, m, nk, cutoff, options.step, myid,
                                  main_node, number_of_proc)
        if myid == main_node:
            if res_overall != -1.0:
                freqvol += (res_overall - Util.infomask(freqvol, m, True)[0])
                # NOTE(review): the serial branch stores rows as
                # [shell, freq, fsc] but this branch indexes [0]/[1];
                # presumably locres returns [freq, fsc] rows -- confirm
                # against statistics.locres.
                _zero_fsc_beyond(resolut, res_overall, 0, 1)
            freqvol.write_image(outvol)

            if options.out_ang_res:
                outAngResVolName = os.path.splitext(outvol)[0] + "_ang.hdf"
                outAngResVol = makeAngRes(freqvol, nx, ny, nz, options.apix)
                outAngResVol.write_image(outAngResVolName)

            if options.fsc is not None:
                write_text_row(resolut, options.fsc)
        mpi_finalize()

    else:
        vi = get_im(args[0])
        ui = get_im(args[1])

        # The serial path treats both volumes as nn x nn x nn cubes.
        nn = vi.get_xsize()

        if len(args) == 3:
            if options.radius > 0:
                radius = options.radius
            else:
                radius = (nn - nk) // 2
            m = model_circle(radius, nn, nn, nn)
            outvol = args[2]
        else:  # len(args) == 4 (validated above)
            m = binarize(get_im(args[2]), 0.5)
            outvol = args[3]

        # Mask complement: adding it outside the mask makes the denominator 1
        # there, so div_img below never divides by zero; the result is then
        # multiplied by m to zero the outside again.
        mc = model_blank(nn, nn, nn, 1.0) - m

        vf = fft(vi)
        uf = fft(ui)

        lp = int(nn / 2 / options.step + 0.5)
        step = 0.5 / lp

        freqvol = model_blank(nn, nn, nn)
        resolut = []
        for i in xrange(1, lp):
            fl = step * i
            fh = fl + step
            print(lp, i, step, fl, fh)
            v = fft(filt_tophatb(vf, fl, fh))
            u = fft(filt_tophatb(uf, fl, fh))
            tmp1 = Util.muln_img(v, v)
            tmp2 = Util.muln_img(u, u)

            # Mask-wide correlation for this shell -> one row of the table.
            do = Util.infomask(
                square_root(threshold(Util.muln_img(tmp1, tmp2))), m, True)[0]

            tmp3 = Util.muln_img(u, v)
            dp = Util.infomask(tmp3, m, True)[0]
            resolut.append([i, (fl + fh) / 2.0, dp / do])

            # Local correlation: box-convolve the products over wn windows.
            tmp1 = Util.box_convolution(tmp1, nk)
            tmp2 = Util.box_convolution(tmp2, nk)
            tmp3 = Util.box_convolution(tmp3, nk)

            Util.mul_img(tmp1, tmp2)

            tmp1 = square_root(threshold(tmp1))

            Util.mul_img(tmp1, m)
            Util.add_img(tmp1, mc)

            Util.mul_img(tmp3, m)
            Util.add_img(tmp3, mc)

            Util.div_img(tmp3, tmp1)

            Util.mul_img(tmp3, m)
            freq = (fl + fh) / 2.0
            # Give this shell's frequency to every still-unassigned masked
            # voxel whose local correlation fell below the cutoff; stop after
            # a pass that finds no unassigned masked voxel at all.
            bailout = True
            for x in xrange(nn):
                for y in xrange(nn):
                    for z in xrange(nn):
                        if (m.get_value_at(x, y, z) > 0.5
                                and freqvol.get_value_at(x, y, z) == 0.0):
                            bailout = False
                            if tmp3.get_value_at(x, y, z) < cutoff:
                                freqvol.set_value_at(x, y, z, freq)
            if bailout:
                break
        print(len(resolut))
        if res_overall != -1.0:
            freqvol += (res_overall - Util.infomask(freqvol, m, True)[0])
            # Rows here are [shell index, frequency, fsc].
            _zero_fsc_beyond(resolut, res_overall, 1, 2)
        freqvol.write_image(outvol)

        if options.out_ang_res:
            outAngResVolName = os.path.splitext(outvol)[0] + "_ang.hdf"
            outAngResVol = makeAngRes(freqvol, nn, nn, nn, options.apix)
            outAngResVol.write_image(outAngResVolName)

        if options.fsc is not None:
            write_text_row(resolut, options.fsc)
Beispiel #32
0
def main():
    def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
        # the final ali2d parameters already combine shifts operation first and rotation operation second for parameters converted from 3D
        if mirror:
            m = 1
            alpha, sx, sy, scalen = sp_utilities.compose_transform2(
                0, s2x, s2y, 1.0, 540.0 - psi, 0, 0, 1.0)
        else:
            m = 0
            alpha, sx, sy, scalen = sp_utilities.compose_transform2(
                0, s2x, s2y, 1.0, 360.0 - psi, 0, 0, 1.0)
        return alpha, sx, sy, m

    progname = optparse.os.path.basename(sys.argv[0])
    usage = (
        progname +
        " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=  --aa=   --sym=symmetry --CTF"
    )
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option("--output_dir",
                      type="string",
                      default="./",
                      help="Output directory")
    parser.add_option(
        "--ave2D",
        type="string",
        default=False,
        help="Write to the disk a stack of 2D averages",
    )
    parser.add_option(
        "--var2D",
        type="string",
        default=False,
        help="Write to the disk a stack of 2D variances",
    )
    parser.add_option(
        "--ave3D",
        type="string",
        default=False,
        help="Write to the disk reconstructed 3D average",
    )
    parser.add_option(
        "--var3D",
        type="string",
        default=False,
        help="Compute 3D variability (time consuming!)",
    )
    parser.add_option(
        "--img_per_grp",
        type="int",
        default=100,
        help="Number of neighbouring projections.(Default is 100)",
    )
    parser.add_option(
        "--no_norm",
        action="store_true",
        default=False,
        help="Do not use normalization.(Default is to apply normalization)",
    )
    # parser.add_option("--radius", 	    type="int"         ,	default=-1   ,				help="radius for 3D variability" )
    parser.add_option(
        "--npad",
        type="int",
        default=2,
        help=
        "Number of time to pad the original images.(Default is 2 times padding)",
    )
    parser.add_option("--sym",
                      type="string",
                      default="c1",
                      help="Symmetry. (Default is no symmetry)")
    parser.add_option(
        "--fl",
        type="float",
        default=0.0,
        help=
        "Low pass filter cutoff in absolute frequency (0.0 - 0.5) and is applied to decimated images. (Default - no filtration)",
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.02,
        help=
        "Fall off of the filter. Use default value if user has no clue about falloff (Default value is 0.02)",
    )
    parser.add_option(
        "--CTF",
        action="store_true",
        default=False,
        help="Use CFT correction.(Default is no CTF correction)",
    )
    # parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
    # parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
    # parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
    # parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
    # parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
    parser.add_option(
        "--VAR",
        action="store_true",
        default=False,
        help="Stack of input consists of 2D variances (Default False)",
    )
    parser.add_option(
        "--decimate",
        type="float",
        default=0.25,
        help="Image decimate rate, a number less than 1. (Default is 0.25)",
    )
    parser.add_option(
        "--window",
        type="int",
        default=0,
        help=
        "Target image size relative to original image size. (Default value is zero.)",
    )
    # parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
    # parser.add_option("--nvec",			type="int"         ,	default=0    ,				help="Number of eigenvectors, (Default = 0 meaning no PCA calculated)")
    parser.add_option(
        "--symmetrize",
        action="store_true",
        default=False,
        help="Prepare input stack for handling symmetry (Default False)",
    )
    parser.add_option("--overhead",
                      type="float",
                      default=0.5,
                      help="python overhead per CPU.")

    (options, args) = parser.parse_args()
    #####
    # from mpi import *

    #  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

    # Set up global variables related to bdb cache
    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    # Set up global variables related to ERROR function
    sp_global_def.BATCH = True

    # detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in optparse.os.environ
    if RUNNING_UNDER_MPI:
        sp_global_def.MPI = True
    if options.output_dir == "./":
        current_output_dir = optparse.os.path.abspath(options.output_dir)
    else:
        current_output_dir = options.output_dir
    if options.symmetrize:

        if mpi.mpi_comm_size(mpi.MPI_COMM_WORLD) > 1:
            sp_global_def.ERROR(
                "Cannot use more than one CPU for symmetry preparation")

        if not optparse.os.path.exists(current_output_dir):
            optparse.os.makedirs(current_output_dir)
            sp_global_def.write_command(current_output_dir)

        if optparse.os.path.exists(
                optparse.os.path.join(current_output_dir, "log.txt")):
            optparse.os.remove(
                optparse.os.path.join(current_output_dir, "log.txt"))
        log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
        log_main.prefix = optparse.os.path.join(current_output_dir, "./")

        instack = args[0]
        sym = options.sym.lower()
        if sym == "c1":
            sp_global_def.ERROR(
                "There is no need to symmetrize stack for C1 symmetry")

        line = ""
        for a in sys.argv:
            line += " " + a
        log_main.add(line)

        if instack[:4] != "bdb:":
            # if output_dir =="./": stack = "bdb:data"
            stack = "bdb:" + current_output_dir + "/data"
            sp_utilities.delete_bdb(stack)
            junk = sp_utilities.cmdexecute("sp_cpy.py  " + instack + "  " +
                                           stack)
        else:
            stack = instack

        qt = EMAN2_cppwrap.EMUtil.get_all_attributes(stack, "xform.projection")

        na = len(qt)
        ts = sp_utilities.get_symt(sym)
        ks = len(ts)
        angsa = [None] * na

        for k in range(ks):
            # Qfile = "Q%1d"%k
            # if options.output_dir!="./": Qfile = os.path.join(options.output_dir,"Q%1d"%k)
            Qfile = optparse.os.path.join(current_output_dir, "Q%1d" % k)
            # delete_bdb("bdb:Q%1d"%k)
            sp_utilities.delete_bdb("bdb:" + Qfile)
            # junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            junk = sp_utilities.cmdexecute("e2bdb.py  " + stack +
                                           "  --makevstack=bdb:" + Qfile)
            # DB = db_open_dict("bdb:Q%1d"%k)
            DB = EMAN2db.db_open_dict("bdb:" + Qfile)
            for i in range(na):
                ut = qt[i] * ts[k]
                DB.set_attr(i, "xform.projection", ut)
                # bt = ut.get_params("spider")
                # angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
            # write_text_row(angsa, 'ptsma%1d.txt'%k)
            # junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            # junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
            DB.close()
        # if options.output_dir =="./": delete_bdb("bdb:sdata")
        sp_utilities.delete_bdb("bdb:" + current_output_dir + "/" + "sdata")
        # junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
        sdata = "bdb:" + current_output_dir + "/" + "sdata"
        sp_global_def.sxprint(sdata)
        junk = sp_utilities.cmdexecute("e2bdb.py   " + current_output_dir +
                                       "  --makevstack=" + sdata + " --filt=Q")
        # junk = cmdexecute("ls  EMAN2DB/sdata*")
        # a = get_im("bdb:sdata")
        a = sp_utilities.get_im(sdata)
        a.set_attr("variabilitysymmetry", sym)
        # a.write_image("bdb:sdata")
        a.write_image(sdata)

    else:

        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        main_node = 0
        shared_comm = mpi.mpi_comm_split_type(mpi.MPI_COMM_WORLD,
                                              mpi.MPI_COMM_TYPE_SHARED, 0,
                                              mpi.MPI_INFO_NULL)
        myid_on_node = mpi.mpi_comm_rank(shared_comm)
        no_of_processes_per_group = mpi.mpi_comm_size(shared_comm)
        masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
            mpi.MPI_COMM_WORLD, main_node == myid_on_node, myid_on_node)
        color, no_of_groups, balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            main_node,
            mpi.MPI_COMM_WORLD,
            myid,
            shared_comm,
            myid_on_node,
            masters_from_groups_vs_everything_else_comm,
        )
        overhead_loading = options.overhead * number_of_proc
        # memory_per_node  = options.memory_per_node
        # if memory_per_node == -1.: memory_per_node = 2.*no_of_processes_per_group
        keepgoing = 1

        current_window = options.window
        current_decimate = options.decimate

        if len(args) == 1:
            stack = args[0]
        else:
            sp_global_def.sxprint("Usage: " + usage)
            sp_global_def.sxprint("Please run '" + progname +
                                  " -h' for detailed options")
            sp_global_def.ERROR(
                "Invalid number of parameters used. Please see usage information above."
            )
            return

        t0 = time.time()
        # obsolete flags
        options.MPI = True
        # options.nvec = 0
        options.radiuspca = -1
        options.iter = 40
        options.abs = 0.0
        options.squ = 0.0

        if options.fl > 0.0 and options.aa == 0.0:
            sp_global_def.ERROR(
                "Fall off has to be given for the low-pass filter", myid=myid)

        # if options.VAR and options.SND:
        # 	ERROR( "Only one of var and SND can be set!",myid=myid )

        if options.VAR and (options.ave2D or options.ave3D or options.var2D):
            sp_global_def.ERROR(
                "When VAR is set, the program cannot output ave2D, ave3D or var2D",
                myid=myid,
            )

        # if options.SND and (options.ave2D or options.ave3D):
        # 	ERROR( "When SND is set, the program cannot output ave2D or ave3D", myid=myid )

        # if options.nvec > 0 :
        # 	ERROR( "PCA option not implemented", myid=myid )

        # if options.nvec > 0 and options.ave3D == None:
        # 	ERROR( "When doing PCA analysis, one must set ave3D", myid=myid )

        if current_decimate > 1.0 or current_decimate < 0.0:
            sp_global_def.ERROR(
                "Decimate rate should be a value between 0.0 and 1.0",
                myid=myid)

        if current_window < 0.0:
            sp_global_def.ERROR(
                "Target window size should be always larger than zero",
                myid=myid)

        if myid == main_node:
            img = sp_utilities.get_image(stack, 0)
            nx = img.get_xsize()
            ny = img.get_ysize()
            if min(nx, ny) < current_window:
                keepgoing = 0
        keepgoing = sp_utilities.bcast_number_to_all(keepgoing, main_node,
                                                     mpi.MPI_COMM_WORLD)
        if keepgoing == 0:
            sp_global_def.ERROR(
                "The target window size cannot be larger than the size of decimated image",
                myid=myid,
            )

        options.sym = options.sym.lower()
        # if global_def.CACHE_DISABLE:
        # 	from utilities import disable_bdb_cache
        # 	disable_bdb_cache()
        # global_def.BATCH = True

        if myid == main_node:
            if not optparse.os.path.exists(current_output_dir):
                optparse.os.makedirs(
                    current_output_dir
                )  # Never delete output_dir in the program!

        img_per_grp = options.img_per_grp
        # nvec        = options.nvec
        radiuspca = options.radiuspca
        # if os.path.exists(os.path.join(options.output_dir, "log.txt")): os.remove(os.path.join(options.output_dir, "log.txt"))
        log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
        log_main.prefix = optparse.os.path.join(current_output_dir, "./")

        if myid == main_node:
            line = ""
            for a in sys.argv:
                line += " " + a
            log_main.add(line)
            log_main.add("-------->>>Settings given by all options<<<-------")
            log_main.add("Symmetry             : %s" % options.sym)
            log_main.add("Input stack          : %s" % stack)
            log_main.add("Output_dir           : %s" % current_output_dir)

            if options.ave3D:
                log_main.add("Ave3d                : %s" % options.ave3D)
            if options.var3D:
                log_main.add("Var3d                : %s" % options.var3D)
            if options.ave2D:
                log_main.add("Ave2D                : %s" % options.ave2D)
            if options.var2D:
                log_main.add("Var2D                : %s" % options.var2D)
            if options.VAR:
                log_main.add("VAR                  : True")
            else:
                log_main.add("VAR                  : False")
            if options.CTF:
                log_main.add("CTF correction       : True  ")
            else:
                log_main.add("CTF correction       : False ")

            log_main.add("Image per group      : %5d" % options.img_per_grp)
            log_main.add("Image decimate rate  : %4.3f" % current_decimate)
            log_main.add("Low pass filter      : %4.3f" % options.fl)
            current_fl = options.fl
            if current_fl == 0.0:
                current_fl = 0.5
            log_main.add(
                "Current low pass filter is equivalent to cutoff frequency %4.3f for original image size"
                % round((current_fl * current_decimate), 3))
            log_main.add("Window size          : %5d " % current_window)
            log_main.add("sx3dvariability begins")

        symbaselen = 0
        if myid == main_node:
            nima = EMAN2_cppwrap.EMUtil.get_image_count(stack)
            img = sp_utilities.get_image(stack)
            nx = img.get_xsize()
            ny = img.get_ysize()
            nnxo = nx
            nnyo = ny
            if options.sym != "c1":
                imgdata = sp_utilities.get_im(stack)
                try:
                    i = imgdata.get_attr("variabilitysymmetry").lower()
                    if i != options.sym:
                        sp_global_def.ERROR(
                            "The symmetry provided does not agree with the symmetry of the input stack",
                            myid=myid,
                        )
                except:
                    sp_global_def.ERROR(
                        "Input stack is not prepared for symmetry, please follow instructions",
                        myid=myid,
                    )
                i = len(sp_utilities.get_symt(options.sym))
                if (old_div(nima, i)) * i != nima:
                    sp_global_def.ERROR(
                        "The length of the input stack is incorrect for symmetry processing",
                        myid=myid,
                    )
                symbaselen = old_div(nima, i)
            else:
                symbaselen = nima
        else:
            nima = 0
            nx = 0
            ny = 0
            nnxo = 0
            nnyo = 0
        nima = sp_utilities.bcast_number_to_all(nima)
        nx = sp_utilities.bcast_number_to_all(nx)
        ny = sp_utilities.bcast_number_to_all(ny)
        nnxo = sp_utilities.bcast_number_to_all(nnxo)
        nnyo = sp_utilities.bcast_number_to_all(nnyo)
        if current_window > max(nx, ny):
            sp_global_def.ERROR(
                "Window size is larger than the original image size")

        if current_decimate == 1.0:
            if current_window != 0:
                nx = current_window
                ny = current_window
        else:
            if current_window == 0:
                nx = int(nx * current_decimate + 0.5)
                ny = int(ny * current_decimate + 0.5)
            else:
                nx = int(current_window * current_decimate + 0.5)
                ny = nx
        symbaselen = sp_utilities.bcast_number_to_all(symbaselen)

        # check FFT prime number
        is_fft_friendly = nx == sp_fundamentals.smallprime(nx)

        if not is_fft_friendly:
            if myid == main_node:
                log_main.add(
                    "The target image size is not a product of small prime numbers"
                )
                log_main.add("Program adjusts the input settings!")
            ### two cases
            if current_decimate == 1.0:
                nx = sp_fundamentals.smallprime(nx)
                ny = nx
                current_window = nx  # update
                if myid == main_node:
                    log_main.add("The window size is updated to %d." %
                                 current_window)
            else:
                if current_window == 0:
                    nx = sp_fundamentals.smallprime(
                        int(nx * current_decimate + 0.5))
                    current_decimate = old_div(float(nx), nnxo)
                    ny = nx
                    if myid == main_node:
                        log_main.add("The decimate rate is updated to %f." %
                                     current_decimate)
                else:
                    nx = sp_fundamentals.smallprime(
                        int(current_window * current_decimate + 0.5))
                    ny = nx
                    current_window = int(old_div(nx, current_decimate) + 0.5)
                    if myid == main_node:
                        log_main.add("The window size is updated to %d." %
                                     current_window)

        if myid == main_node:
            log_main.add("The target image size is %d" % nx)

        if radiuspca == -1:
            radiuspca = old_div(nx, 2) - 2
        if myid == main_node:
            log_main.add("%-70s:  %d\n" % ("Number of projection", nima))
        img_begin, img_end = sp_applications.MPI_start_end(
            nima, number_of_proc, myid)
        """Multiline Comment0"""
        """
        Comments from adnan, replace index_of_proj to index_of_particle, index_of_proj was not defined
        also varList is not defined not made an empty list there
        """

        if options.VAR:  # 2D variance images have no shifts
            varList = []
            # varList   = EMData.read_images(stack, range(img_begin, img_end))
            for index_of_particle in range(img_begin, img_end):
                image = sp_utilities.get_im(stack, index_of_particle)
                if current_window > 0:
                    varList.append(
                        sp_fundamentals.fdecimate(
                            sp_fundamentals.window2d(image, current_window,
                                                     current_window),
                            nx,
                            ny,
                        ))
                else:
                    varList.append(sp_fundamentals.fdecimate(image, nx, ny))

        else:
            if myid == main_node:
                t1 = time.time()
                proj_angles = []
                aveList = []
                tab = EMAN2_cppwrap.EMUtil.get_all_attributes(
                    stack, "xform.projection")
                for i in range(nima):
                    t = tab[i].get_params("spider")
                    phi = t["phi"]
                    theta = t["theta"]
                    psi = t["psi"]
                    x = theta
                    if x > 90.0:
                        x = 180.0 - x
                    x = x * 10000 + psi
                    proj_angles.append([x, t["phi"], t["theta"], t["psi"], i])
                t2 = time.time()
                log_main.add(
                    "%-70s:  %d\n" %
                    ("Number of neighboring projections", img_per_grp))
                log_main.add("...... Finding neighboring projections\n")
                log_main.add("Number of images per group: %d" % img_per_grp)
                log_main.add("Now grouping projections")
                proj_angles.sort()
                proj_angles_list = numpy.full((nima, 4),
                                              0.0,
                                              dtype=numpy.float32)
                for i in range(nima):
                    proj_angles_list[i][0] = proj_angles[i][1]
                    proj_angles_list[i][1] = proj_angles[i][2]
                    proj_angles_list[i][2] = proj_angles[i][3]
                    proj_angles_list[i][3] = proj_angles[i][4]
            else:
                proj_angles_list = 0
            proj_angles_list = sp_utilities.wrap_mpi_bcast(
                proj_angles_list, main_node, mpi.MPI_COMM_WORLD)
            proj_angles = []
            for i in range(nima):
                proj_angles.append([
                    proj_angles_list[i][0],
                    proj_angles_list[i][1],
                    proj_angles_list[i][2],
                    int(proj_angles_list[i][3]),
                ])
            del proj_angles_list
            proj_list, mirror_list = sp_utilities.nearest_proj(
                proj_angles, img_per_grp, range(img_begin, img_end))
            all_proj = []
            for im in proj_list:
                for jm in im:
                    all_proj.append(proj_angles[jm][3])
            all_proj = list(set(all_proj))
            index = {}
            for i in range(len(all_proj)):
                index[all_proj[i]] = i
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == main_node:
                log_main.add("%-70s:  %.2f\n" %
                             ("Finding neighboring projections lasted [s]",
                              time.time() - t2))
                log_main.add("%-70s:  %d\n" %
                             ("Number of groups processed on the main node",
                              len(proj_list)))
                log_main.add("Grouping projections took:  %12.1f [m]" %
                             (old_div((time.time() - t2), 60.0)))
                log_main.add("Number of groups on main node: ", len(proj_list))
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

            if myid == main_node:
                log_main.add("...... Calculating the stack of 2D variances \n")
            # Memory estimation. There are two memory consumption peaks
            # peak 1. Compute ave, var;
            # peak 2. Var volume reconstruction;
            # proj_params = [0.0]*(nima*5)
            aveList = []
            varList = []
            # if nvec > 0: eigList = [[] for i in range(nvec)]
            dnumber = len(
                all_proj)  # all neighborhood set for assigned to myid
            pnumber = len(proj_list) * 2.0 + img_per_grp  # aveList and varList
            tnumber = dnumber + pnumber
            vol_size2 = old_div(nx**3 * 4.0 * 8, 1.0e9)
            vol_size1 = old_div(2.0 * nnxo**3 * 4.0 * 8, 1.0e9)
            proj_size = old_div(nnxo * nnyo * len(proj_list) * 4.0 * 2.0,
                                1.0e9)  # both aveList and varList
            orig_data_size = old_div(nnxo * nnyo * 4.0 * tnumber, 1.0e9)
            reduced_data_size = old_div(nx * nx * 4.0 * tnumber, 1.0e9)
            full_data = numpy.full((number_of_proc, 2),
                                   -1.0,
                                   dtype=numpy.float16)
            full_data[myid] = orig_data_size, reduced_data_size
            if myid != main_node:
                sp_utilities.wrap_mpi_send(full_data, main_node,
                                           mpi.MPI_COMM_WORLD)
            if myid == main_node:
                for iproc in range(number_of_proc):
                    if iproc != main_node:
                        dummy = sp_utilities.wrap_mpi_recv(
                            iproc, mpi.MPI_COMM_WORLD)
                        full_data[numpy.where(dummy > -1)] = dummy[numpy.where(
                            dummy > -1)]
                del dummy
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            full_data = sp_utilities.wrap_mpi_bcast(full_data, main_node,
                                                    mpi.MPI_COMM_WORLD)
            # find the CPU with heaviest load
            minindx = numpy.argsort(full_data, 0)
            heavy_load_myid = minindx[-1][1]
            total_mem = sum(full_data)
            if myid == main_node:
                if current_window == 0:
                    log_main.add(
                        "Nx:   current image size = %d. Decimated by %f from %d"
                        % (nx, current_decimate, nnxo))
                else:
                    log_main.add(
                        "Nx:   current image size = %d. Windowed to %d, and decimated by %f from %d"
                        % (nx, current_window, current_decimate, nnxo))
                log_main.add("Nproj:       number of particle images.")
                log_main.add("Navg:        number of 2D average images.")
                log_main.add("Nvar:        number of 2D variance images.")
                log_main.add(
                    "Img_per_grp: user defined image per group for averaging = %d"
                    % img_per_grp)
                log_main.add(
                    "Overhead:    total python overhead memory consumption   = %f"
                    % overhead_loading)
                log_main.add(
                    "Total memory) = 4.0*nx^2*(nproj + navg +nvar+ img_per_grp)/1.0e9 + overhead: %12.3f [GB]"
                    % (total_mem[1] + overhead_loading))
            del full_data
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add(
                    "Begin reading and preprocessing images on processor. Wait... "
                )
                ttt = time.time()
            # imgdata = EMData.read_images(stack, all_proj)
            imgdata = [None for im in range(len(all_proj))]
            for index_of_proj in range(len(all_proj)):
                # image = get_im(stack, all_proj[index_of_proj])
                if current_window > 0:
                    imgdata[index_of_proj] = sp_fundamentals.fdecimate(
                        sp_fundamentals.window2d(
                            sp_utilities.get_im(stack,
                                                all_proj[index_of_proj]),
                            current_window,
                            current_window,
                        ),
                        nx,
                        ny,
                    )
                else:
                    imgdata[index_of_proj] = sp_fundamentals.fdecimate(
                        sp_utilities.get_im(stack, all_proj[index_of_proj]),
                        nx, ny)

                if current_decimate > 0.0 and options.CTF:
                    ctf = imgdata[index_of_proj].get_attr("ctf")
                    ctf.apix = old_div(ctf.apix, current_decimate)
                    imgdata[index_of_proj].set_attr("ctf", ctf)

                if myid == heavy_load_myid and index_of_proj % 100 == 0:
                    log_main.add(
                        " ...... %6.2f%% " %
                        (old_div(index_of_proj, float(len(all_proj))) * 100.0))
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add("All_proj preprocessing cost %7.2f m" % (old_div(
                    (time.time() - ttt), 60.0)))
                log_main.add("Wait untill reading on all CPUs done...")
            """Multiline Comment1"""
            if not options.no_norm:
                mask = sp_utilities.model_circle(old_div(nx, 2) - 2, nx, nx)
            if myid == heavy_load_myid:
                log_main.add("Start computing 2D aveList and varList. Wait...")
                ttt = time.time()
            inner = old_div(nx, 2) - 4
            outer = inner + 2
            xform_proj_for_2D = [None for i in range(len(proj_list))]
            for i in range(len(proj_list)):
                ki = proj_angles[proj_list[i][0]][3]
                if ki >= symbaselen:
                    continue
                mi = index[ki]
                dpar = EMAN2_cppwrap.Util.get_transform_params(
                    imgdata[mi], "xform.projection", "spider")
                phiM, thetaM, psiM, s2xM, s2yM = (
                    dpar["phi"],
                    dpar["theta"],
                    dpar["psi"],
                    -dpar["tx"] * current_decimate,
                    -dpar["ty"] * current_decimate,
                )
                grp_imgdata = []
                for j in range(img_per_grp):
                    mj = index[proj_angles[proj_list[i][j]][3]]
                    cpar = EMAN2_cppwrap.Util.get_transform_params(
                        imgdata[mj], "xform.projection", "spider")
                    alpha, sx, sy, mirror = params_3D_2D_NEW(
                        cpar["phi"],
                        cpar["theta"],
                        cpar["psi"],
                        -cpar["tx"] * current_decimate,
                        -cpar["ty"] * current_decimate,
                        mirror_list[i][j],
                    )
                    if thetaM <= 90:
                        if mirror == 0:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha, sx, sy, 1.0, phiM - cpar["phi"], 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha,
                                sx,
                                sy,
                                1.0,
                                180 - (phiM - cpar["phi"]),
                                0.0,
                                0.0,
                                1.0,
                            )
                    else:
                        if mirror == 0:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha, sx, sy, 1.0, -(phiM - cpar["phi"]), 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha,
                                sx,
                                sy,
                                1.0,
                                -(180 - (phiM - cpar["phi"])),
                                0.0,
                                0.0,
                                1.0,
                            )
                    imgdata[mj].set_attr(
                        "xform.align2d",
                        EMAN2_cppwrap.Transform({
                            "type": "2D",
                            "alpha": alpha,
                            "tx": sx,
                            "ty": sy,
                            "mirror": mirror,
                            "scale": 1.0,
                        }),
                    )
                    grp_imgdata.append(imgdata[mj])
                if not options.no_norm:
                    for k in range(img_per_grp):
                        ave, std, minn, maxx = EMAN2_cppwrap.Util.infomask(
                            grp_imgdata[k], mask, False)
                        grp_imgdata[k] -= ave
                        grp_imgdata[k] = old_div(grp_imgdata[k], std)
                if options.fl > 0.0:
                    for k in range(img_per_grp):
                        grp_imgdata[k] = sp_filter.filt_tanl(
                            grp_imgdata[k], options.fl, options.aa)

                #  Because of background issues, only linear option works.
                if options.CTF:
                    ave, var = sp_statistics.aves_wiener(
                        grp_imgdata, SNR=1.0e5, interpolation_method="linear")
                else:
                    ave, var = sp_statistics.ave_var(grp_imgdata)
                # Switch to std dev
                # threshold is not really needed,it is just in case due to numerical accuracy something turns out negative.
                var = sp_morphology.square_root(sp_morphology.threshold(var))

                sp_utilities.set_params_proj(ave,
                                             [phiM, thetaM, 0.0, 0.0, 0.0])
                sp_utilities.set_params_proj(var,
                                             [phiM, thetaM, 0.0, 0.0, 0.0])

                aveList.append(ave)
                varList.append(var)
                xform_proj_for_2D[i] = [phiM, thetaM, 0.0, 0.0, 0.0]
                """Multiline Comment2"""
                if (myid == heavy_load_myid) and (i % 100 == 0):
                    log_main.add(" ......%6.2f%%  " %
                                 (old_div(i, float(len(proj_list))) * 100.0))
            del imgdata, grp_imgdata, cpar, dpar, all_proj, proj_angles, index
            if not options.no_norm:
                del mask
            if myid == main_node:
                del tab
            #  At this point, all averages and variances are computed
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

            if myid == heavy_load_myid:
                log_main.add("Computing aveList and varList took %12.1f [m]" %
                             (old_div((time.time() - ttt), 60.0)))

            xform_proj_for_2D = sp_utilities.wrap_mpi_gatherv(
                xform_proj_for_2D, main_node, mpi.MPI_COMM_WORLD)
            if myid == main_node:
                sp_utilities.write_text_row(
                    [str(entry) for entry in xform_proj_for_2D],
                    optparse.os.path.join(current_output_dir, "params.txt"),
                )
            del xform_proj_for_2D
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if options.ave2D:
                if myid == main_node:
                    log_main.add("Compute ave2D ... ")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(aveList)):
                                aveList[im].write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.ave2D),
                                    km,
                                )
                                km += 1
                        else:
                            nl = mpi.mpi_recv(
                                1,
                                mpi.MPI_INT,
                                i,
                                sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                mpi.MPI_COMM_WORLD,
                            )
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = sp_utilities.recv_EMData(
                                    i, im + i + 70000)
                                """Multiline Comment3"""
                                tmpvol = sp_fundamentals.fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.ave2D),
                                    km,
                                )
                                km += 1
                else:
                    mpi.mpi_send(
                        len(aveList),
                        1,
                        mpi.MPI_INT,
                        main_node,
                        sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                        mpi.MPI_COMM_WORLD,
                    )
                    for im in range(len(aveList)):
                        sp_utilities.send_EMData(aveList[im], main_node,
                                                 im + myid + 70000)
                        """Multiline Comment4"""
                if myid == main_node:
                    sp_applications.header(
                        optparse.os.path.join(current_output_dir,
                                              options.ave2D),
                        params="xform.projection",
                        fimport=optparse.os.path.join(current_output_dir,
                                                      "params.txt"),
                    )
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if options.ave3D:
                t5 = time.time()
                if myid == main_node:
                    log_main.add("Reconstruct ave3D ... ")
                ave3D = sp_reconstruction.recons3d_4nn_MPI(
                    myid, aveList, symmetry=options.sym, npad=options.npad)
                sp_utilities.bcast_EMData_to_all(ave3D, myid)
                if myid == main_node:
                    if current_decimate != 1.0:
                        ave3D = sp_fundamentals.resample(
                            ave3D, old_div(1.0, current_decimate))
                    ave3D = sp_fundamentals.fpol(
                        ave3D, nnxo, nnxo,
                        nnxo)  # always to the orignal image size
                    sp_utilities.set_pixel_size(ave3D, 1.0)
                    ave3D.write_image(
                        optparse.os.path.join(current_output_dir,
                                              options.ave3D))
                    log_main.add("Ave3D reconstruction took %12.1f [m]" %
                                 (old_div((time.time() - t5), 60.0)))
                    log_main.add("%-70s:  %s\n" %
                                 ("The reconstructed ave3D is saved as ",
                                  options.ave3D))

            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            del ave, var, proj_list, stack, alpha, sx, sy, mirror, aveList
            """Multiline Comment5"""

            if options.ave3D:
                del ave3D
            if options.var2D:
                if myid == main_node:
                    log_main.add("Compute var2D...")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(varList)):
                                tmpvol = sp_fundamentals.fpol(
                                    varList[im], nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.var2D),
                                    km,
                                )
                                km += 1
                        else:
                            nl = mpi.mpi_recv(
                                1,
                                mpi.MPI_INT,
                                i,
                                sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                mpi.MPI_COMM_WORLD,
                            )
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = sp_utilities.recv_EMData(
                                    i, im + i + 70000)
                                tmpvol = sp_fundamentals.fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.var2D),
                                    km,
                                )
                                km += 1
                else:
                    mpi.mpi_send(
                        len(varList),
                        1,
                        mpi.MPI_INT,
                        main_node,
                        sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                        mpi.MPI_COMM_WORLD,
                    )
                    for im in range(len(varList)):
                        sp_utilities.send_EMData(
                            varList[im], main_node,
                            im + myid + 70000)  # What with the attributes??
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
                if myid == main_node:
                    sp_applications.header(
                        optparse.os.path.join(current_output_dir,
                                              options.var2D),
                        params="xform.projection",
                        fimport=optparse.os.path.join(current_output_dir,
                                                      "params.txt"),
                    )
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        if options.var3D:
            if myid == main_node:
                log_main.add("Reconstruct var3D ...")
            t6 = time.time()
            # radiusvar = options.radius
            # if( radiusvar < 0 ):  radiusvar = nx//2 -3
            res = sp_reconstruction.recons3d_4nn_MPI(myid,
                                                     varList,
                                                     symmetry=options.sym,
                                                     npad=options.npad)
            # res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
            if myid == main_node:
                if current_decimate != 1.0:
                    res = sp_fundamentals.resample(
                        res, old_div(1.0, current_decimate))
                res = sp_fundamentals.fpol(res, nnxo, nnxo, nnxo)
                sp_utilities.set_pixel_size(res, 1.0)
                res.write_image(os.path.join(current_output_dir,
                                             options.var3D))
                log_main.add(
                    "%-70s:  %s\n" %
                    ("The reconstructed var3D is saved as ", options.var3D))
                log_main.add("Var3D reconstruction took %f12.1 [m]" % (old_div(
                    (time.time() - t6), 60.0)))
                log_main.add("Total computation time %f12.1 [m]" % (old_div(
                    (time.time() - t0), 60.0)))
                log_main.add("sx3dvariability finishes")

        if RUNNING_UNDER_MPI:
            sp_global_def.MPI = False

        sp_global_def.BATCH = False
Beispiel #33
0
def main(args):
	"""
	Parse the sxviper command line and run multi_shc under MPI.

	args: command-line tokens (program name excluded), handed to OptionParser.parse_args.
	      Positional arguments: stack output_directory. A third positional argument is
	      accepted but currently ignored (reference-volume loading is commented out).
	      --radius is mandatory.

	Returns the OptionParser itself when --return_options is given, 1 on a usage error
	(missing mandatory option or wrong number of positional arguments), None otherwise.

	Rank 0 reads the stack and validates the run:
	  - the number of MPI processes must be a multiple of --nruns;
	  - the output directory must not exist yet.
	The validation result is broadcast so that every rank shuts down together on failure
	instead of the non-root ranks blocking in mpi_barrier.
	"""
	from utilities import wrap_mpi_bcast
	from logger import Logger, BaseLogger_Files
	from mpi import mpi_init, mpi_finalize, MPI_COMM_WORLD, mpi_comm_rank, mpi_comm_size, mpi_barrier
	import user_functions
	import sys
	import os
	from optparse import OptionParser, SUPPRESS_HELP
	from global_def import SPARXVERSION
	from EMAN2 import EMData
	from multi_shc import multi_shc

	progname = os.path.basename(sys.argv[0])
	usage = progname + " stack  [output_directory] --ir=inner_radius --rs=ring_step --xr=x_range --yr=y_range  --ts=translational_search_step  --delta=angular_step --center=center_type --maxit1=max_iter1 --maxit2=max_iter2 --L2threshold=0.1 --ref_a=S --sym=c1"
	usage += """

stack			2D images in a stack file: (default required string)
directory		output directory name: into which the results will be written (if it does not exist, it will be created, if it does exist, the results will be written possibly overwriting previous results) (default required string)
"""

	parser = OptionParser(usage,version=SPARXVERSION)
	parser.add_option("--radius",                type="int",           help="radius of the particle: has to be less than < int(nx/2)-1 (default required int)")

	parser.add_option("--xr",                    type="string",        default='0',        help="range for translation search in x direction: search is +/xr in pixels (default '0')")
	parser.add_option("--yr",                    type="string",        default='0',        help="range for translation search in y direction: if omitted will be set to xr, search is +/yr in pixels (default '0')")
	parser.add_option("--mask3D",                type="string",        default=None,       help="3D mask file: (default sphere)")
	parser.add_option("--moon_elimination",      type="string",        default='',         help="elimination of disconnected pieces: two arguments: mass in KDa and pixel size in px/A separated by comma, no space (default none)")
	parser.add_option("--ir",                    type="int",           default=1,          help="inner radius for rotational search: > 0 (default 1)")

	# 'radius' and 'ou' are the same as per Pawel's request; 'ou' is hidden from the user
	# the 'ou' variable is not changed to 'radius' in the 'sparx' program. This change is at interface level only for sxviper.
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--ou",                    type="int",           default=-1,         help=SUPPRESS_HELP)
	parser.add_option("--rs",                    type="int",           default=1,          help="step between rings in rotational search: >0 (default 1)")
	parser.add_option("--ts",                    type="string",        default='1.0',      help="step size of the translation search in x-y directions: search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional (default '1.0')")
	parser.add_option("--delta",                 type="string",        default='2.0',      help="angular step of reference projections: (default '2.0')")
	parser.add_option("--center",                type="float",         default=-1.0,       help="centering of 3D template: average shift method; 0: no centering; 1: center of gravity (default -1.0)")
	parser.add_option("--maxit1",                type="int",           default=400,        help="maximum number of iterations performed for the GA part: (default 400)")
	parser.add_option("--maxit2",                type="int",           default=50,         help="maximum number of iterations performed for the finishing up part: (default 50)")
	parser.add_option("--L2threshold",           type="float",         default=0.03,       help="stopping criterion of GA: given as a maximum relative dispersion of volumes' L2 norms: (default 0.03)")
	parser.add_option("--ref_a",                 type="string",        default='S',        help="method for generating the quasi-uniformly distributed projection directions: (default S)")
	parser.add_option("--sym",                   type="string",        default='c1',       help="point-group symmetry of the structure: (default c1)")

	# parser.add_option("--function", type="string", default="ref_ali3d",         help="name of the reference preparation function (ref_ali3d by default)")
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--function", type="string", default="ref_ali3d",         help= SUPPRESS_HELP)

	parser.add_option("--nruns",                 type="int",           default=6,          help="GA population: aka number of quasi-independent volumes (default 6)")
	parser.add_option("--doga",                  type="float",         default=0.1,        help="do GA when fraction of orientation changes less than 1.0 degrees is at least doga: (default 0.1)")
	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--npad",     type="int",    default= 2,                  help="padding size for 3D reconstruction (default=2)")
	parser.add_option("--fl",                    type="float",         default=0.25,       help="cut-off frequency applied to the template volume: using a hyperbolic tangent low-pass filter (default 0.25)")
	parser.add_option("--aa",                    type="float",         default=0.1,        help="fall-off of hyperbolic tangent low-pass filter: (default 0.1)")
	parser.add_option("--pwreference",           type="string",        default='',         help="text file with a reference power spectrum: (default none)")
	parser.add_option("--debug",                 action="store_true",  default=False,      help="debug info printout: (default False)")

	##### XXXXXXXXXXXXXXXXXXXXXX option does not exist in docs XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
	parser.add_option("--return_options", action="store_true", dest="return_options", default=False, help = SUPPRESS_HELP)

	#parser.add_option("--an",       type="string", default= "-1",               help="NOT USED angular neighborhood for local searches (phi and theta)")
	#parser.add_option("--CTF",      action="store_true", default=False,         help="NOT USED Consider CTF correction during the alignment ")
	#parser.add_option("--snr",      type="float",  default= 1.0,                help="NOT USED Signal-to-Noise Ratio of the data (default 1.0)")

	required_option_list = ['radius']
	(options, args) = parser.parse_args(args)

	if options.return_options:
		return parser

	if options.moon_elimination == "":
		options.moon_elimination = []
	else:
		# Build an explicit list so the value is a list on both Python 2 and Python 3.
		options.moon_elimination = [float(value) for value in options.moon_elimination.split(",")]

	# Making sure all required options appeared.
	for required_option in required_option_list:
		if not getattr(options, required_option):
			print("\n ==%s== mandatory option is missing.\n" % required_option)
			print("Please run '" + progname + " -h' for detailed options")
			return 1

	if len(args) < 2 or len(args) > 3:
		print("usage: " + usage)
		print("Please run '" + progname + " -h' for detailed options")
		return 1

	mpi_init(0, [])

	log = Logger(BaseLogger_Files())

	# 'radius' and 'ou' are the same as per Pawel's request; 'ou' is hidden from the user
	# the 'ou' variable is not changed to 'radius' in the 'sparx' program. This change is at interface level only for sxviper.
	options.ou = options.radius
	runs_count = options.nruns
	mpi_rank = mpi_comm_rank(MPI_COMM_WORLD)
	mpi_size = mpi_comm_size(MPI_COMM_WORLD)	# Total number of processes, passed by --np option.

	if mpi_rank == 0:
		all_projs = EMData.read_images(args[0])
		subset = range(len(all_projs))
	else:
		all_projs = None
		subset = None

	outdir = args[1]

	# Only rank 0 validates the run; its verdict is broadcast below so that all ranks
	# exit together. Returning on rank 0 alone would leave the other ranks blocked in
	# the collective mpi_barrier that follows.
	error_message = None
	if mpi_rank == 0:
		# nruns < 1 would otherwise raise ZeroDivisionError (0) or pass a bogus modulo test (< 0).
		if options.nruns < 1 or mpi_size % options.nruns != 0:
			error_message = 'Number of processes needs to be a multiple of total number of runs. Total runs by default are 6, you can change it by specifying --nruns option.'
		elif os.path.exists(outdir):
			error_message = 'Output directory exists, please change the name and restart the program'
		else:
			os.mkdir(outdir)
			import global_def
			global_def.LOGFILE =  os.path.join(outdir, global_def.LOGFILE)

	error_message = wrap_mpi_bcast(error_message, 0, MPI_COMM_WORLD)
	if error_message is not None:
		if mpi_rank == 0:
			ERROR(error_message, 'sxviper', 1)
		mpi_finalize()
		return

	mpi_barrier(MPI_COMM_WORLD)

	if outdir[-1] != "/":
		outdir += "/"
	log.prefix = outdir

	# Loading a reference volume from args[2] is currently disabled.
	ref_vol = None

	options.user_func = user_functions.factory[options.function]

	options.CTF = False
	options.snr = 1.0
	options.an  = -1.0
	out_params, out_vol, out_peaks = multi_shc(all_projs, subset, runs_count, options, mpi_comm=MPI_COMM_WORLD, log=log, ref_vol=ref_vol)

	mpi_finalize()
Beispiel #34
0
def main():
    """
    Locally filter a volume according to a local-resolution volume (sxlocres.py output).

    Positional arguments: inputvolume locresvolume [maskfile] outputfile.
    With three positional arguments there is no mask file: a sphere of radius --radius
    is used (when --radius is -1 the radius defaults to nx//2 - 1). With four, the mask
    file is read and binarized at 0.5.

    Serial branch: resolution values are rounded to two decimals. For every cutoff step
    of 0.01 present inside the mask, the input volume is tanl low-pass filtered
    (fall-off --falloff) at that cutoff, and copied into the output at the masked voxels
    whose rounded resolution equals the cutoff.
    MPI branch (--MPI): the main node reads the volumes and the work is delegated to
    filter.filterlocal (presumably the same procedure, parallelized; not verified here).

    Calls sys.exit() when the number of positional arguments is not 3 or 4.
    """
    import os
    import sys
    from optparse import OptionParser

    arglist = list(sys.argv)
    progname = os.path.basename(arglist[0])
    usage = (
        progname
        + """ inputvolume  locresvolume maskfile outputfile   --radius --falloff  --MPI

	    Locally filer a volume based on local resolution volume (sxlocres.py) within area outlined by the maskfile
	"""
    )
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help="if there is no maskfile, sphere with r=radius will be used, by default the radius is nx/2-1",
    )
    parser.add_option("--falloff", type="float", default=0.1, help="falloff of tanl filter (default 0.1)")
    parser.add_option("--MPI", action="store_true", default=False, help="use MPI version")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        print("See usage " + usage)
        sys.exit()

    if global_def.CACHE_DISABLE:
        from utilities import disable_bdb_cache

        disable_bdb_cache()

    if options.MPI:
        from mpi import mpi_init, mpi_comm_size, mpi_comm_rank, mpi_finalize, MPI_COMM_WORLD
        from filter import filterlocal

        sys.argv = mpi_init(len(sys.argv), sys.argv)

        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        main_node = 0

        if myid == main_node:
            # Take the file names from the parsed positional arguments, exactly like the
            # mask/output below and the serial branch. Indexing sys.argv directly picks up
            # option tokens (e.g. "--MPI") whenever an option precedes the file names.
            vi = get_im(args[0])
            ui = get_im(args[1])
            radius = options.radius
            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            # Only the main node holds the volumes; the others receive the sizes below.
            radius = 0
            dis = [0, 0, 0]
            vi = None
            ui = None
        dis = bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])
        radius = bcast_number_to_all(radius, main_node)
        if len(args) == 3:
            if radius == -1:
                radius = min(nx, ny, nz) // 2 - 1
            m = model_circle(radius, nx, ny, nz)
            outvol = args[2]

        elif len(args) == 4:
            if myid == main_node:
                m = binarize(get_im(args[2]), 0.5)
            else:
                m = model_blank(nx, ny, nz)
            outvol = args[3]
            bcast_EMData_to_all(m, myid, main_node)

        filteredvol = filterlocal(ui, vi, m, options.falloff, myid, main_node, number_of_proc)

        if myid == main_node:
            filteredvol.write_image(outvol)

        mpi_finalize()

    else:
        vi = get_im(args[0])
        ui = get_im(args[1])  # resolution volume, values are assumed to be from 0 to 0.5

        # The serial branch assumes a cubic volume of edge nn.
        nn = vi.get_xsize()

        falloff = options.falloff

        if len(args) == 3:
            radius = options.radius
            if radius == -1:
                radius = nn // 2 - 1
            m = model_circle(radius, nn, nn, nn)
            outvol = args[2]

        elif len(args) == 4:
            m = binarize(get_im(args[2]), 0.5)
            outvol = args[3]

        fftip(vi)  # this is the volume to be filtered

        #  Round all resolution numbers to two digits
        for x in xrange(nn):
            for y in xrange(nn):
                for z in xrange(nn):
                    ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
        st = Util.infomask(ui, m, True)

        filteredvol = model_blank(nn, nn, nn)
        # Step through the masked resolution range (st[2]..st[3]) in increments of 0.01.
        cutoff = max(st[2] - 0.01, 0.0)
        while cutoff < st[3]:
            cutoff = round(cutoff + 0.01, 2)
            # Skip cutoffs that no masked voxel uses.
            pt = Util.infomask(threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005), m, True)
            if pt[0] != 0.0:
                vovo = fft(filt_tanl(vi, cutoff, falloff))
                for x in xrange(nn):
                    for y in xrange(nn):
                        for z in xrange(nn):
                            if m.get_value_at(x, y, z) > 0.5:
                                if round(ui.get_value_at(x, y, z), 2) == cutoff:
                                    filteredvol.set_value_at_fast(x, y, z, vovo.get_value_at(x, y, z))

        filteredvol.write_image(outvol)
Beispiel #35
0
#!/usr/bin/env /usr/bin/python
import numpy
from numpy import *
import mpi
import sys
from time import gmtime, time, sleep


def stamp():
    """Return the current UTC time as a compact ``MMDDhhmmss`` string."""
    month, day, hour, minute, second = gmtime(time())[1:6]
    return "%02d%02d%02d%02d%02d" % (month, day, hour, minute, second)


sys.argv = mpi.mpi_init(len(sys.argv), sys.argv)
myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
numprocs = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
parent = mpi.mpi_comm_get_parent()
parentSize = mpi.mpi_comm_size(parent)
print "parentSize", parentSize

tod = stamp()
s = sys.argv[1] + "%2.2d" % myid
print "hello from python worker", myid, " writing to ", s

x = array([5, 3, 4, 2], 'i')
print "starting bcast"
buffer = mpi.mpi_bcast(x, 4, mpi.MPI_INT, 0, parent)
out = open(s, "w")
out.write(str(buffer))
out.write(tod + "\n")
out.close()
Beispiel #36
0
def main():
    """Locally filter a volume according to a local-resolution volume.

    Positional arguments: inputvolume locresvolume [maskfile] outputfile.
    Without a maskfile a sphere of radius ``--radius`` is used (default:
    ``min(nx, ny, nz)//2 - 1`` with --MPI, ``nx//2 - 1`` otherwise); with a
    maskfile the mask is the file binarized at 0.5.

    Serial mode: local-resolution values are rounded to two digits; for every
    occurring value the input volume is tanl low-pass filtered at that value
    (falloff ``--falloff``) and copied into the output at the masked voxels
    having that value.  Voxels outside the mask stay zero.

    MPI mode: only the main node reads the input volumes; the filtering is
    delegated to ``sp_filter.filterlocal`` and the main node writes the result.
    """
    arglist = list(sys.argv)
    progname = optparse.os.path.basename(arglist[0])
    usage = progname + """ inputvolume  locresvolume maskfile outputfile   --radius --falloff  --MPI

	    Locally filer a volume based on local resolution volume (sxlocres.py) within area outlined by the maskfile
	"""
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "if there is no maskfile, sphere with r=radius will be used, by default the radius is nx/2-1"
    )
    parser.add_option("--falloff",
                      type="float",
                      default=0.1,
                      help="falloff of tanl filter (default 0.1)")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        sp_global_def.sxprint("See usage " + usage)
        sp_global_def.ERROR(
            "Wrong number of parameters. Please see usage information above.")
        return

    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    if options.MPI:
        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        main_node = 0

        if myid == main_node:
            # Use the parsed positional arguments: sys.argv[1]/[2] would be
            # option strings (e.g. "--MPI") whenever options come first.
            vi = sp_utilities.get_im(args[0])
            ui = sp_utilities.get_im(args[1])
            radius = options.radius
            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            # Placeholders; real values are broadcast from the main node.
            radius = 0
            dis = [0, 0, 0]
            vi = None
            ui = None
        dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])
        radius = sp_utilities.bcast_number_to_all(radius, main_node)
        if len(args) == 3:
            if radius == -1:
                radius = min(nx, ny, nz) // 2 - 1
            m = sp_utilities.model_circle(radius, nx, ny, nz)
            outvol = args[2]

        elif len(args) == 4:
            if myid == main_node:
                m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            else:
                m = sp_utilities.model_blank(nx, ny, nz)
            outvol = args[3]
            sp_utilities.bcast_EMData_to_all(m, myid, main_node)

        filteredvol = sp_filter.filterlocal(ui, vi, m, options.falloff, myid,
                                            main_node, number_of_proc)

        if myid == main_node:
            filteredvol.write_image(outvol)

    else:
        vi = sp_utilities.get_im(args[0])
        # Resolution volume, values are assumed to be from 0 to 0.5.
        ui = sp_utilities.get_im(args[1])

        nn = vi.get_xsize()

        falloff = options.falloff

        if len(args) == 3:
            radius = options.radius
            if radius == -1:
                radius = nn // 2 - 1
            m = sp_utilities.model_circle(radius, nn, nn, nn)
            outvol = args[2]

        elif len(args) == 4:
            m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            outvol = args[3]

        sp_fundamentals.fftip(vi)  # this is the volume to be filtered

        # Round all resolution numbers to two digits so that voxels can be
        # grouped by exact equality with the current cutoff below.
        for x in range(nn):
            for y in range(nn):
                for z in range(nn):
                    ui.set_value_at_fast(x, y, z,
                                         round(ui.get_value_at(x, y, z), 2))
        st = EMAN2_cppwrap.Util.infomask(ui, m, True)

        filteredvol = sp_utilities.model_blank(nn, nn, nn)
        # Step through the masked resolution range (st[2]..st[3]) in 0.01 steps.
        cutoff = max(st[2] - 0.01, 0.0)
        while cutoff < st[3]:
            cutoff = round(cutoff + 0.01, 2)
            # pt[0] is the masked mean of ui thresholded to a narrow window
            # around cutoff; it is zero when no voxel has this resolution, in
            # which case the (expensive) filtering is skipped.
            pt = EMAN2_cppwrap.Util.infomask(
                sp_morphology.threshold_outside(ui, cutoff - 0.00501,
                                                cutoff + 0.005), m, True)
            if pt[0] != 0.0:
                vovo = sp_fundamentals.fft(
                    sp_filter.filt_tanl(vi, cutoff, falloff))
                for x in range(nn):
                    for y in range(nn):
                        for z in range(nn):
                            if m.get_value_at(x, y, z) > 0.5 and round(
                                    ui.get_value_at(x, y, z), 2) == cutoff:
                                filteredvol.set_value_at_fast(
                                    x, y, z, vovo.get_value_at(x, y, z))

        sp_global_def.write_command(optparse.os.path.dirname(outvol))
        filteredvol.write_image(outvol)
Beispiel #37
0
def main():
    """Estimate a local-resolution volume from two volumes (local real-space FSC).

    Positional arguments: firstvolume secondvolume [maskfile] outputfile.
    The two volumes are presumably independent half-set reconstructions.

    For each Fourier band [fl, fl + step) both volumes are band-pass filtered
    with ``filter.filt_tophatb``.  A local FSC is then computed for every
    voxel as box_sum(u*v) / sqrt(box_sum(v*v) * box_sum(u*u)) over a cube of
    ``--wn`` voxels.  Each masked voxel receives the centre frequency of the
    first band at which its local FSC drops below ``--cutoff``.  The masked
    global FSC per band is collected in ``resolut``.  Both results are passed
    to ``output_volume`` for writing.

    Without a maskfile the mask is a sphere of radius ``(nx - wn) // 2``
    (``--radius`` is parsed but not used here).  With ``--MPI`` the
    computation is delegated to ``statistics.locres``.
    """
    arglist = list(sys.argv)
    progname = os.path.basename(arglist[0])
    usage = progname + """ firstvolume  secondvolume  maskfile  outputfile  --wn  --step  --cutoff  --radius  --fsc  --res_overall  --out_ang_res  --apix  --MPI

	Compute local resolution in real space within area outlined by the maskfile and within regions wn x wn x wn
	"""
    parser = optparse.OptionParser(usage, version=global_def.SPARXVERSION)

    parser.add_option(
        "--wn",
        type="int",
        default=7,
        help=
        "Size of window within which local real-space FSC is computed. (default 7)"
    )
    parser.add_option(
        "--step",
        type="float",
        default=1.0,
        help="Shell step in Fourier size in pixels. (default 1.0)")
    parser.add_option("--cutoff",
                      type="float",
                      default=0.5,
                      help="Resolution cut-off for FSC. (default 0.5)")
    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "If there is no maskfile, sphere with r=radius will be used. By default, the radius is nx/2-wn (default -1)"
    )
    parser.add_option(
        "--fsc",
        type="string",
        default=None,
        help=
        "Save overall FSC curve (might be truncated). By default, the program does not save the FSC curve. (default none)"
    )
    parser.add_option(
        "--res_overall",
        type="float",
        default=-1.0,
        help=
        "Overall resolution at the cutoff level estimated by the user [abs units]. (default None)"
    )
    parser.add_option(
        "--out_ang_res",
        action="store_true",
        default=False,
        help=
        "Additionally creates a local resolution file in Angstroms. (default False)"
    )
    parser.add_option(
        "--apix",
        type="float",
        default=1.0,
        help=
        "Pixel size in Angstrom. Effective only with --out_ang_res options. (default 1.0)"
    )
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="Use MPI version.")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        print("See usage " + usage)
        sys.exit()

    if global_def.CACHE_DISABLE:
        utilities.disable_bdb_cache()

    res_overall = options.res_overall
    cutoff = options.cutoff
    nk = int(options.wn)

    if options.MPI:
        sys.argv = mpi.mpi_init(len(sys.argv), sys.argv)

        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        main_node = 0
        global_def.MPI = True

        if myid == main_node:
            # Use the parsed positional arguments: sys.argv[1]/[2] would be
            # option strings (e.g. "--MPI") whenever options come first.
            vi = utilities.get_im(args[0])
            ui = utilities.get_im(args[1])

            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            # Same length as on the main node; values are broadcast below.
            dis = [0, 0, 0]

        global_def.BATCH = True

        dis = utilities.bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])

            # Non-main nodes only hold blank volumes of the right size;
            # statistics.locres presumably distributes the data itself.
            vi = utilities.model_blank(nx, ny, nz)
            ui = utilities.model_blank(nx, ny, nz)

        if len(args) == 3:
            m = utilities.model_circle((min(nx, ny, nz) - nk) // 2, nx, ny, nz)
            outvol = args[2]

        elif len(args) == 4:
            if myid == main_node:
                m = morphology.binarize(utilities.get_im(args[2]), 0.5)
            else:
                m = utilities.model_blank(nx, ny, nz)
            outvol = args[3]
        utilities.bcast_EMData_to_all(m, myid, main_node)

        freqvol, resolut = statistics.locres(vi, ui, m, nk, cutoff,
                                             options.step, myid, main_node,
                                             number_of_proc)

        if myid == main_node:
            # Remove outliers based on the Interquartile range
            output_volume(freqvol, resolut, options.apix, outvol, options.fsc,
                          options.out_ang_res, nx, ny, nz, res_overall)
        mpi.mpi_finalize()

    else:
        vi = utilities.get_im(args[0])
        ui = utilities.get_im(args[1])

        # The serial path assumes cubic volumes of edge nn.
        nn = vi.get_xsize()

        if len(args) == 3:
            m = utilities.model_circle((nn - nk) // 2, nn, nn, nn)
            outvol = args[2]

        elif len(args) == 4:
            m = morphology.binarize(utilities.get_im(args[2]), 0.5)
            outvol = args[3]

        # Complement of the mask; added below so that divisions outside the
        # mask never see a zero denominator.
        mc = utilities.model_blank(nn, nn, nn, 1.0) - m

        vf = fundamentals.fft(vi)
        uf = fundamentals.fft(ui)

        lp = int(nn / 2 / options.step + 0.5)
        step = 0.5 / lp

        freqvol = utilities.model_blank(nn, nn, nn)
        resolut = []
        for i in range(1, lp):
            fl = step * i
            fh = fl + step
            v = fundamentals.fft(filter.filt_tophatb(vf, fl, fh))
            u = fundamentals.fft(filter.filt_tophatb(uf, fl, fh))
            tmp1 = EMAN2_cppwrap.Util.muln_img(v, v)
            tmp2 = EMAN2_cppwrap.Util.muln_img(u, u)

            # Global (masked) FSC of this band: mean(u*v) / mean(sqrt(u^2 v^2)).
            denom = EMAN2_cppwrap.Util.infomask(
                morphology.square_root(
                    morphology.threshold(
                        EMAN2_cppwrap.Util.muln_img(tmp1, tmp2))), m, True)[0]

            tmp3 = EMAN2_cppwrap.Util.muln_img(u, v)
            numer = EMAN2_cppwrap.Util.infomask(tmp3, m, True)[0]
            resolut.append([i, (fl + fh) / 2.0, numer / denom])

            # Local FSC: box sums over nk^3 windows.
            tmp1 = EMAN2_cppwrap.Util.box_convolution(tmp1, nk)
            tmp2 = EMAN2_cppwrap.Util.box_convolution(tmp2, nk)
            tmp3 = EMAN2_cppwrap.Util.box_convolution(tmp3, nk)

            EMAN2_cppwrap.Util.mul_img(tmp1, tmp2)

            tmp1 = morphology.square_root(morphology.threshold(tmp1))

            EMAN2_cppwrap.Util.mul_img(tmp1, m)
            EMAN2_cppwrap.Util.add_img(tmp1, mc)

            EMAN2_cppwrap.Util.mul_img(tmp3, m)
            EMAN2_cppwrap.Util.add_img(tmp3, mc)

            EMAN2_cppwrap.Util.div_img(tmp3, tmp1)

            EMAN2_cppwrap.Util.mul_img(tmp3, m)
            freq = (fl + fh) / 2.0
            # Stop once every masked voxel has been assigned a frequency.
            bailout = True
            for x in range(nn):
                for y in range(nn):
                    for z in range(nn):
                        if (m.get_value_at(x, y, z) > 0.5
                                and freqvol.get_value_at(x, y, z) == 0.0):
                            bailout = False
                            if tmp3.get_value_at(x, y, z) < cutoff:
                                freqvol.set_value_at(x, y, z, freq)
            if bailout:
                break
        # remove outliers
        # nx/ny/nz are not defined in the serial path; the volume is nn^3.
        output_volume(freqvol, resolut, options.apix, outvol, options.fsc,
                      options.out_ang_res, nn, nn, nn, res_overall)
Beispiel #38
0
def shiftali_MPI(stack,
                 maskfile=None,
                 maxit=100,
                 CTF=False,
                 snr=1.0,
                 Fourvar=False,
                 search_rng=-1,
                 oneDx=False,
                 search_rng_y=-1):
    """Refine only the in-plane shifts of a 2D image stack, using MPI.

    The stack is split over all processes of MPI_COMM_WORLD (MPI_start_end).
    Each image is first transformed by the ``xform.align2d`` stored in its
    header.  Then, for at most ``maxit`` iterations:

    * the average of all images is formed on the main node, optionally
      CTF/Fourier-variance corrected, normalized and masked in real space,
      and broadcast;
    * every image's integer shift is taken from the peak of its
      cross-correlation with that average;
    * the (rounded) mean shift over the whole stack is subtracted, and each
      image is shifted in Fourier space by the remainder.

    Iteration stops early when no image reports a non-zero peak offset.
    Finally the accumulated shifts are combined with the initial
    ``xform.align2d`` and written back to the stack headers via the main node.

    stack        -- image stack name; when CACHE_DISABLE is False and the
                    stack is "bdb", processes read one after another
    maskfile     -- mask image file; if None, a circle of radius
                    min(nx, ny)//2 - 2 is used
    maxit        -- maximum number of iterations
    CTF          -- if True, images are filtered by their CTF (filt_ctf) and
                    the average is divided by the sum of CTF^2 images plus
                    snr; images must not be CTF-applied (see ctf_applied)
    snr          -- constant added to the CTF^2 sum when CTF is True
    Fourvar      -- if True, the average is also divided by the Fourier
                    variance returned by varf2d_MPI
    search_rng   -- if > 0, the x search window is 2*search_rng+1 pixels,
                    otherwise the full image width nx
    oneDx        -- if True, search for x shifts only
    search_rng_y -- like search_rng, for y
    """

    number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    # Only the main node counts the images; the count is broadcast.
    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    # This process handles images [image_start, image_end).
    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # read nx and ctf_app (if CTF) and broadcast to all nodes
    if myid == main_node:
        ima = EMData()
        # Third argument True presumably reads the header only - confirm.
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF: ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF: ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", myid=myid)

    if maskfile == None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        # Fourier-space accumulators for sum |CTF|^2 and sum |CTF|.
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from sp_global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # Read in turn, one process at a time; the barrier serializes
        # access only for bdb stacks.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb": mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    # Tag each image with its global index, subtract its mean under the
    # mask, and accumulate the CTF sums.
    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    # Sum the CTF accumulators over all processes onto the main node.
    # NOTE: ctf_abs_sum is not used after this reduction.
    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
    else:
        ctf_2_sum = None
    if CTF:
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add snr at every even x index (presumably the real parts of
            # the complex-format image) to regularize the later division.
            temp = EMData(nx, ny, 1, False)
            for i in range(0, nx, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im],
                               p['alpha'],
                               sx=p['tx'],
                               sy=p['ty'],
                               mirror=p['mirror'],
                               scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    # Accumulated per-image shifts (shift_*) and shifts found in the current
    # iteration (ishift_*), for the local images only.
    sx_sum = 0
    sy_sum = 0
    sx_sum_total = 0
    sy_sum_total = 0
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Local sum of the (Fourier) images, reduced to the main node.
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                # NOTE: vav_r is computed but not used.
                vav_r = Util.pack_complex_to_real(vav)

            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)
            tavg = fft(tavg)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        # Size of the cross-correlation window searched for the peak.
        sx_sum = 0
        sy_sum = 0
        if search_rng > 0: nwx = 2 * search_rng + 1
        else: nwx = nx

        if search_rng_y > 0: nwy = 2 * search_rng_y + 1
        else: nwy = ny

        # not_zero becomes 1 if any local image has a non-zero peak offset.
        not_zero = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if not_zero == 0:
                if (not (ishift_x[im] == 0.0)) or (not (ishift_y[im] == 0.0)):
                    not_zero = 1

        # Total shifts over all processes -> main node -> broadcast.
        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node,
                                mpi.MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi.mpi_reduce(sy_sum, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                    main_node, mpi.MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total,
                                               source_node=main_node)

        # Remove the rounded mean shift so the average does not drift, then
        # apply the remaining shift to each image in Fourier space.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y
        # stop if all shifts are zero
        not_zero = mpi.mpi_reduce(not_zero, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                  main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))
            start_time = time()

        if not_zero_all == 0: break

    #for im in xrange(len(data)): data[im] = fft(data[im])  This should not be required as only header information is used
    # combine shifts found with the original parameters
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        # combine t0 and t1
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    # The main node receives the header attributes from all other nodes and
    # writes them; the other nodes send theirs.
    if myid == main_node:
        from sp_utilities import file_type
        if (file_type(stack) == "bdb"):
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)

    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node: print_end_msg("shiftali_MPI")
Beispiel #39
0
def do_volume_mrk02(ref_data):
	"""
		Reconstruct (if needed), normalize, filter and mask a 3D volume and
		broadcast it so that every process returns the same volume.

		ref_data - sequence [data, Tracker, iter, mpi_comm]:
		  data     - list of projections (scattered between cpus) or a volume.
		             If a volume is given, only the volume processing is done.
		  Tracker  - mapping with the run parameters (the same for all cpus).
		  iter     - iteration number (only referenced in commented-out debug output).
		  mpi_comm - communicator to use; None means MPI_COMM_WORLD.
		return - volume, the same for all cpus
	"""
	from EMAN2          import Util
	from mpi            import mpi_comm_rank, mpi_comm_size, MPI_COMM_WORLD
	from filter         import filt_table
	from reconstruction import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
	from utilities      import bcast_EMData_to_all, bcast_number_to_all, model_blank
	from fundamentals import rops_table, fftip, fft
	import types

	# Retrieve the function specific input arguments from ref_data
	data     = ref_data[0]
	Tracker  = ref_data[1]
	iter     = ref_data[2]
	mpi_comm = ref_data[3]

	if mpi_comm is None:  mpi_comm = MPI_COMM_WORLD
	myid  = mpi_comm_rank(mpi_comm)
	nproc = mpi_comm_size(mpi_comm)

	# Optional key: name of the local-resolution volume used for local filtering.
	try:
		local_filter = Tracker["local_filter"]
	except KeyError:
		local_filter = False
	#=========================================================================
	# volume reconstruction
	if( type(data) == types.ListType ):
		if Tracker["constants"]["CTF"]:
			vol = recons3d_4nn_ctf_MPI(myid, data, Tracker["constants"]["snr"], \
					symmetry=Tracker["constants"]["sym"], npad=Tracker["constants"]["npad"], mpi_comm=mpi_comm, smearstep = Tracker["smearstep"])
		else:
			vol = recons3d_4nn_MPI    (myid, data,\
					symmetry=Tracker["constants"]["sym"], npad=Tracker["constants"]["npad"], mpi_comm=mpi_comm)
	else:
		vol = data

	if myid == 0:
		from morphology import threshold
		from filter     import filt_tanl, filt_btwl
		from utilities  import model_circle, get_im
		nx = vol.get_xsize()
		# Build the 3D mask: sphere scaled from the original size, adaptive
		# mask, or a user mask (file name or image) rescaled to nx if needed.
		if(Tracker["constants"]["mask3D"] == None):
			mask3D = model_circle(int(Tracker["constants"]["radius"]*float(nx)/float(Tracker["constants"]["nnxo"])+0.5), nx, nx, nx)
		elif(Tracker["constants"]["mask3D"] == "auto"):
			from utilities import adaptive_mask
			mask3D = adaptive_mask(vol)
		else:
			if( type(Tracker["constants"]["mask3D"]) == types.StringType ):  mask3D = get_im(Tracker["constants"]["mask3D"])
			else:  mask3D = (Tracker["constants"]["mask3D"]).copy()
			nxm = mask3D.get_xsize()
			if( nx != nxm):
				from fundamentals import rot_shift3D
				mask3D = Util.window(rot_shift3D(mask3D,scale=float(nx)/float(nxm)),nx,nx,nx)
				nxm = mask3D.get_xsize()
				assert(nx == nxm)

		# Normalize under the mask, clip with threshold, apply the mask.
		stat = Util.infomask(vol, mask3D, False)
		vol -= stat[0]
		Util.mul_scalar(vol, 1.0/stat[1])
		vol = threshold(vol)
		Util.mul_img(vol, mask3D)
		if( Tracker["PWadjustment"] ):
			from utilities    import read_text_file, write_text_file
			rt = read_text_file( Tracker["PWadjustment"] )
			fftip(vol)
			ro = rops_table(vol)
			#  Here unless I am mistaken it is enough to take the beginning of the reference pw.
			for i in xrange(1,len(ro)):  ro[i] = (rt[i]/ro[i])**Tracker["upscale"]
			if Tracker["constants"]["sausage"]:
				ny = vol.get_ysize()
				y = float(ny)
				from math import exp
				for i in xrange(len(ro)):  ro[i] *= \
				  (1.0+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.10)/0.025)**2)+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.215)/0.025)**2))

			if local_filter:
				# skip low-pass filtration
				vol = fft( filt_table( vol, ro) )
			else:
				if( type(Tracker["lowpass"]) == types.ListType ):
					vol = fft( filt_table( filt_table(vol, Tracker["lowpass"]), ro) )
				else:
					vol = fft( filt_table( filt_tanl(vol, Tracker["lowpass"], Tracker["falloff"]), ro) )
			del ro
		else:
			if Tracker["constants"]["sausage"]:
				ny = vol.get_ysize()
				y = float(ny)
				ro = [0.0]*(ny//2+2)
				from math import exp
				for i in xrange(len(ro)):  ro[i] = \
				  (1.0+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.10)/0.025)**2)+1.0*exp(-(((i/y/Tracker["constants"]["pixel_size"])-0.215)/0.025)**2))
				fftip(vol)
				filt_table(vol, ro)
				del ro
			if not local_filter:
				if( type(Tracker["lowpass"]) == types.ListType ):
					vol = filt_table(vol, Tracker["lowpass"])
				else:
					vol = filt_tanl(vol, Tracker["lowpass"], Tracker["falloff"])
			if Tracker["constants"]["sausage"]: vol = fft(vol)

	if local_filter:
		from morphology import binarize
		if(myid == 0): nx = mask3D.get_xsize()
		else:  nx = 0
		# Broadcast over the same communicator as the rest of this function
		# (myid/nproc and bcast_EMData_to_all all use mpi_comm).
		nx = bcast_number_to_all(nx, source_node = 0, mpi_comm = mpi_comm)
		#  only main processor needs the two input volumes
		if(myid == 0):
			mask = binarize(mask3D, 0.5)
			locres = get_im(Tracker["local_filter"])
			lx = locres.get_xsize()
			if(lx != nx):
				if(lx < nx):
					# Bring mask and volume down to the local-resolution map size.
					from fundamentals import fdecimate, rot_shift3D
					mask = Util.window(rot_shift3D(mask,scale=float(lx)/float(nx)),lx,lx,lx)
					vol = fdecimate(vol, lx,lx,lx)
				else:  ERROR("local filter cannot be larger than input volume","user function",1)
			stat = Util.infomask(vol, mask, False)
			vol -= stat[0]
			Util.mul_scalar(vol, 1.0/stat[1])
		else:
			lx = 0
			locres = model_blank(1,1,1)
			vol = model_blank(1,1,1)
		lx = bcast_number_to_all(lx, source_node = 0, mpi_comm = mpi_comm)
		if( myid != 0 ):  mask = model_blank(lx,lx,lx)
		bcast_EMData_to_all(mask, myid, 0, comm=mpi_comm)
		from filter import filterlocal
		# NOTE(review): filterlocal takes no communicator argument; if it
		# communicates over MPI_COMM_WORLD this path is only correct when
		# mpi_comm is MPI_COMM_WORLD - confirm.
		vol = filterlocal( locres, vol, mask, Tracker["falloff"], myid, 0, nproc)

		if myid == 0:
			if(lx < nx):
				# Restore the original size.
				from fundamentals import fpol
				vol = fpol(vol, nx,nx,nx)
			vol = threshold(vol)
			vol = filt_btwl(vol, 0.38, 0.5)#  This will have to be corrected.
			Util.mul_img(vol, mask3D)
			del mask3D
		else:
			vol = model_blank(nx,nx,nx)
	else:
		if myid == 0:
			stat = Util.infomask(vol, mask3D, False)
			vol -= stat[0]
			Util.mul_scalar(vol, 1.0/stat[1])
			vol = threshold(vol)
			vol = filt_btwl(vol, 0.38, 0.5)#  This will have to be corrected.
			Util.mul_img(vol, mask3D)
			del mask3D
	# broadcast volume
	bcast_EMData_to_all(vol, myid, 0, comm=mpi_comm)
	#=========================================================================
	return vol
Beispiel #40
0
def main():
	"""Command-line driver for sx3dvariability.

	Two modes, selected by --symmetrize:

	* --symmetrize: for every symmetry operator k of --sym, create a virtual
	  bdb stack bdb:Q<k> whose 'xform.projection' headers are the input
	  orientations multiplied by that operator, merge them all into
	  bdb:sdata, and store the symmetry name in the 'variabilitysymmetry'
	  attribute of its first image.  An error is reported when launched on
	  more than one CPU.
	* default (MPI): for every projection assigned to this node, collect
	  --img_per_grp angular neighbours, give each neighbour 2D alignment
	  parameters relative to the group centre, and compute the group
	  average and standard-deviation images.  These can be written as 2D
	  stacks (--ave2D, --var2D) and/or reconstructed into 3D volumes
	  (--ave3D, --var3D).  With --VAR the input stack is taken to already
	  contain 2D variances; it is only decimated/windowed before --var3D.

	Returns 1 when the positional arguments are wrong (MPI mode), else None.
	"""

	def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
		# Convert 3D projection parameters into 2D in-plane parameters
		# (alpha, sx, sy, mirror flag).  Mirrored neighbours use 540-psi
		# instead of 360-psi; phi and theta are not used.
		if mirror:
			m = 1
			alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0, 540.0-psi, 0, 0, 1.0)
		else:
			m = 0
			alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0, 360.0-psi, 0, 0, 1.0)
		return  alpha, sx, sy, m

	progname = os.path.basename(sys.argv[0])
	usage = progname + " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=0.2 --aa=0.1  --sym=symmetry --CTF"
	parser = OptionParser(usage, version=SPARXVERSION)

	parser.add_option("--ave2D",       type="string",       default=False, help="write to the disk a stack of 2D averages")
	parser.add_option("--var2D",       type="string",       default=False, help="write to the disk a stack of 2D variances")
	parser.add_option("--ave3D",       type="string",       default=False, help="write to the disk reconstructed 3D average")
	parser.add_option("--var3D",       type="string",       default=False, help="compute 3D variability (time consuming!)")
	parser.add_option("--img_per_grp", type="int",          default=10,    help="number of neighbouring projections")
	parser.add_option("--no_norm",     action="store_true", default=False, help="do not use normalization")
	parser.add_option("--radiusvar",   type="int",          default=-1,    help="radius for 3D var")
	parser.add_option("--npad",        type="int",          default=2,     help="number of time to pad the original images")
	parser.add_option("--sym",         type="string",       default="c1",  help="symmetry")
	parser.add_option("--fl",          type="float",        default=0.0,   help="stop-band frequency (Default - no filtration)")
	parser.add_option("--aa",          type="float",        default=0.0,   help="fall off of the filter (Default - no filtration)")
	parser.add_option("--CTF",         action="store_true", default=False, help="use CFT correction")
	parser.add_option("--VERBOSE",     action="store_true", default=False, help="Long output for debugging")
	parser.add_option("--VAR",         action="store_true", default=False, help="stack on input consists of 2D variances (Default False)")
	parser.add_option("--decimate",    type="float",        default=1.0,   help="image decimate rate, a number large than 1. default is 1")
	parser.add_option("--window",      type="int",          default=0,     help="reduce images to a small image size without changing pixel_size. Default value is zero.")
	parser.add_option("--nvec",        type="int",          default=0,     help="number of eigenvectors, default = 0 meaning no PCA calculated")
	parser.add_option("--symmetrize",  action="store_true", default=False, help="Prepare input stack for handling symmetry (Default False)")
	# NOTE: --SND is not registered, so `options` has no SND attribute.

	(options, args) = parser.parse_args()

	from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, mpi_recv, MPI_COMM_WORLD, MPI_TAG_UB
	from mpi import mpi_barrier, mpi_reduce, mpi_send, mpi_finalize, MPI_SUM, MPI_INT, MPI_MAX
	from applications import MPI_start_end
	from reconstruction import recons3d_4nn_MPI
	from utilities import print_begin_msg, print_end_msg, print_msg
	from utilities import get_image, get_im
	from utilities import bcast_EMData_to_all, bcast_number_to_all
	from utilities import get_symt
	from EMAN2db import db_open_dict

	#  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015
	if options.symmetrize:
		# MPI is started only to reject multi-CPU launches.  Failures of the
		# probe itself are ignored on purpose, but `except Exception` lets a
		# SystemExit raised by ERROR propagate instead of being swallowed.
		try:
			sys.argv = mpi_init(len(sys.argv), sys.argv)
			try:
				number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
				if number_of_proc > 1:
					ERROR("Cannot use more than one CPU for symmetry prepration","sx3dvariability",1)
			except Exception:
				pass
		except Exception:
			pass

		instack = args[0]
		sym = options.sym
		if sym == "c1":
			ERROR("Thre is no need to symmetrize stack for C1 symmetry","sx3dvariability",1)

		# Non-bdb input is first copied into bdb:data.
		if not instack.startswith("bdb:"):
			stack = "bdb:data"
			delete_bdb(stack)
			cmdexecute("sxcpy.py  "+instack+"  "+stack)
		else:
			stack = instack

		qt = EMUtil.get_all_attributes(stack,'xform.projection')
		n_images = len(qt)
		ts = get_symt(sym)
		for k in xrange(len(ts)):
			# One virtual stack per symmetry operator, with rotated orientations.
			delete_bdb("bdb:Q%1d"%k)
			cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
			DB = db_open_dict("bdb:Q%1d"%k)
			try:
				for i in xrange(n_images):
					DB.set_attr(i, "xform.projection", qt[i]*ts[k])
			finally:
				DB.close()
		# Merge all bdb:Q* stacks into bdb:sdata and tag it with the symmetry.
		delete_bdb("bdb:sdata")
		cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
		a = get_im("bdb:sdata")
		a.set_attr("variabilitysymmetry",sym)
		a.write_image("bdb:sdata")

	else:
		sys.argv = mpi_init(len(sys.argv), sys.argv)
		myid     = mpi_comm_rank(MPI_COMM_WORLD)
		number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
		main_node = 0

		if len(args) == 1:
			stack = args[0]
		else:
			print( "usage: " + usage)
			print( "Please run '" + progname + " -h' for detailed options")
			mpi_finalize()
			return 1

		t0 = time()

		# obsolete flags
		options.MPI = True
		options.nvec = 0
		options.radiuspca = -1
		options.iter = 40
		options.abs = 0.0
		options.squ = 0.0

		if options.fl > 0.0 and options.aa == 0.0:
			ERROR("Fall off has to be given for the low-pass filter", "sx3dvariability", 1, myid)
		# --SND is not registered, so getattr avoids an AttributeError.
		if options.VAR and getattr(options, "SND", False):
			ERROR("Only one of var and SND can be set!", "sx3dvariability", 1, myid)
			exit()
		if options.VAR and (options.ave2D or options.ave3D or options.var2D):
			ERROR("When VAR is set, the program cannot output ave2D, ave3D or var2D", "sx3dvariability", 1, myid)
			exit()
		if options.nvec > 0:
			ERROR("PCA option not implemented", "sx3dvariability", 1, myid)
			exit()
		if options.nvec > 0 and not options.ave3D:
			ERROR("When doing PCA analysis, one must set ave3D", "sx3dvariability", myid=myid)
			exit()
		options.sym = options.sym.lower()

		if global_def.CACHE_DISABLE:
			from utilities import disable_bdb_cache
			disable_bdb_cache()
		global_def.BATCH = True

		if myid == main_node:
			print_begin_msg("sx3dvariability")
			print_msg("%-70s:  %s\n"%("Input stack", stack))

		img_per_grp = options.img_per_grp
		nvec = options.nvec
		radiuspca = options.radiuspca

		symbaselen = 0
		if myid == main_node:
			nima = EMUtil.get_image_count(stack)
			img  = get_image(stack)
			nx   = img.get_xsize()
			ny   = img.get_ysize()
			if options.sym != "c1":
				imgdata = get_im(stack)
				# Only the header lookup is guarded, so a SystemExit raised by
				# ERROR is never turned into a second, misleading error.
				try:
					stack_sym = imgdata.get_attr("variabilitysymmetry")
				except Exception:
					stack_sym = None
				if stack_sym is None:
					ERROR("Input stack is not prepared for symmetry, please follow instructions", "sx3dvariability", myid=myid)
				elif stack_sym != options.sym:
					ERROR("The symmetry provided does not agree with the symmetry of the input stack", "sx3dvariability", myid=myid)
				n_sym = len(get_symt(options.sym))
				if nima % n_sym != 0:
					ERROR("The length of the input stack is incorrect for symmetry processing", "sx3dvariability", myid=myid)
				symbaselen = nima // n_sym
			else:
				symbaselen = nima
		else:
			nima = 0
			nx = 0
			ny = 0
		nima = bcast_number_to_all(nima)
		nx   = bcast_number_to_all(nx)
		ny   = bcast_number_to_all(ny)
		# Original (undecimated) sizes; outputs are Fourier-interpolated back to them.
		Tracker = {}
		Tracker["nx"] = nx
		Tracker["ny"] = ny
		Tracker["total_stack"] = nima
		if options.decimate == 1.:
			if options.window != 0:
				nx = options.window
				ny = options.window
		else:
			if options.window == 0:
				nx = int(nx/options.decimate)
				ny = int(ny/options.decimate)
			else:
				nx = int(options.window/options.decimate)
				ny = nx
		symbaselen = bcast_number_to_all(symbaselen)
		if radiuspca == -1: radiuspca = nx//2 - 2

		if myid == main_node:
			print_msg("%-70s:  %d\n"%("Number of projection", nima))

		img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
		if options.VAR:
			# The input already holds 2D variances; only decimate/window the local slice.
			varList = []
			this_image = EMData()
			for index_of_particle in xrange(img_begin, img_end):
				this_image.read_image(stack, index_of_particle)
				varList.append(image_decimate_window_xform_ctf(this_image, options.decimate, options.window, options.CTF))
		else:
			from utilities    import bcast_list_to_all, send_EMData, recv_EMData, pad
			from utilities    import set_params_proj, get_params_proj, set_params2D, compose_transform2
			from utilities    import model_blank, nearest_proj, model_circle
			from applications import pca, prepare_2d_forPCA
			from statistics   import ccc
			from filter       import filt_tanl, filt_ctf
			from fundamentals import fft, window2d
			from morphology   import threshold, square_root
			from projection   import prep_vol, prgs

			if myid == main_node:
				proj_angles = []
				tab = EMUtil.get_all_attributes(stack, 'xform.projection')
				for i in xrange(nima):
					t = tab[i].get_params('spider')
					# Ordering key: theta folded into [0, 90] dominates, psi breaks ties.
					key = t['theta']
					if key > 90.0: key = 180.0 - key
					key = key*10000 + t['psi']
					proj_angles.append([key, t['phi'], t['theta'], t['psi'], i])
				t2 = time()
				print_msg("%-70s:  %d\n"%("Number of neighboring projections", img_per_grp))
				print_msg("...... Finding neighboring projections\n")
				if options.VERBOSE:
					print("Number of images per group:  %s" % img_per_grp)
					print("Now grouping projections")
				proj_angles.sort()

			# Broadcast the sorted [phi, theta, psi, image index] table as a flat list.
			proj_angles_list = [0.0]*(nima*4)
			if myid == main_node:
				for i in xrange(nima):
					proj_angles_list[i*4:i*4+4] = proj_angles[i][1:5]
			proj_angles_list = bcast_list_to_all(proj_angles_list, myid, main_node)
			proj_angles = [[proj_angles_list[i*4], proj_angles_list[i*4+1], proj_angles_list[i*4+2], int(proj_angles_list[i*4+3])] for i in xrange(nima)]
			del proj_angles_list

			proj_list, mirror_list = nearest_proj(proj_angles, img_per_grp, range(img_begin, img_end))

			# Stack indices of every image needed by this node's groups (read once, in order).
			all_proj = sorted(set(proj_angles[jm][3] for im in proj_list for jm in im))
			if options.VERBOSE:
				print("On node %2d, number of images needed to be read = %5d"%(myid, len(all_proj)))

			# stack index -> position in imgdata
			index = dict((proj_id, pos) for pos, proj_id in enumerate(all_proj))
			mpi_barrier(MPI_COMM_WORLD)

			if myid == main_node:
				print_msg("%-70s:  %.2f\n"%("Finding neighboring projections lasted [s]", time()-t2))
				print_msg("%-70s:  %d\n"%("Number of groups processed on the main node", len(proj_list)))
				if options.VERBOSE:
					print("Grouping projections took:  %s [min]" % ((time()-t2)/60))
					print("Number of groups on main node:  %s" % len(proj_list))
			mpi_barrier(MPI_COMM_WORLD)

			if myid == main_node:
				print_msg("...... calculating the stack of 2D variances \n")
				if options.VERBOSE:
					print("Now calculating the stack of 2D variances")

			aveList = []
			varList = []
			if nvec > 0:
				eigList = [[] for i in xrange(nvec)]

			if options.VERBOSE: print("Begin to read images on processor %d"%(myid))
			ttt = time()
			img     = EMData()
			imgdata = []
			for proj_id in all_proj:
				img.read_image(stack, proj_id)
				imgdata.append(image_decimate_window_xform_ctf(img, options.decimate, options.window, options.CTF))
			if options.VERBOSE:
				print("Reading images on processor %d done, time = %.2f"%(myid, time()-ttt))
				print("On processor %d, we got %d images"%(myid, len(imgdata)))
			mpi_barrier(MPI_COMM_WORLD)

			# Loop invariants: normalisation mask and padded size for CTF correction.
			if not options.no_norm:
				norm_mask = model_circle(nx//2-2, nx, nx)
			nx2 = 2*nx
			ny2 = 2*ny
			for i in xrange(len(proj_list)):
				ki = proj_angles[proj_list[i][0]][3]
				# Groups centred past the first symbaselen images (the symmetry copies) are skipped.
				if ki >= symbaselen:  continue
				phiM, thetaM = get_params_proj(imgdata[index[ki]])[:2]

				grp_imgdata = []
				for j in xrange(img_per_grp):
					mj = index[proj_angles[proj_list[i][j]][3]]
					phi, theta, psi, s2x, s2y = get_params_proj(imgdata[mj])
					alpha, sx, sy, mirror = params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror_list[i][j])
					# In-plane rotation relative to the group centre: mirrored images use
					# 180-dphi, and the sign flips when the centre has theta > 90.
					dphi = phiM - phi
					if mirror:      dphi = 180 - dphi
					if thetaM > 90: dphi = -dphi
					alpha, sx, sy, _ = compose_transform2(alpha, sx, sy, 1.0, dphi, 0.0, 0.0, 1.0)
					set_params2D(imgdata[mj], [alpha, sx, sy, mirror, 1.0])
					grp_imgdata.append(imgdata[mj])

				if not options.no_norm:
					for k in xrange(img_per_grp):
						ave, std, minn, maxx = Util.infomask(grp_imgdata[k], norm_mask, False)
						grp_imgdata[k] -= ave
						grp_imgdata[k] /= std

				if options.CTF:
					# CTF-correct on 2x padded images (optionally low-pass), then crop back.
					for k in xrange(img_per_grp):
						fimg = filt_ctf(fft(pad(grp_imgdata[k], nx2, ny2, 1, 0.0)), grp_imgdata[k].get_attr("ctf"), binary=1)
						if options.fl > 0.0:
							fimg = filt_tanl(fimg, options.fl, options.aa)
						grp_imgdata[k] = window2d(fft(fimg), nx, ny)
				elif options.fl > 0.0:
					for k in xrange(img_per_grp):
						grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)

				ave, grp_imgdata = prepare_2d_forPCA(grp_imgdata)

				# NOTE(review): assumes prepare_2d_forPCA returns mean-subtracted images and
				# Util.add_img2 accumulates squared pixels, making this the unbiased variance.
				var = model_blank(nx, ny)
				for q in grp_imgdata:  Util.add_img2(var, q)
				Util.mul_scalar(var, 1.0/(len(grp_imgdata)-1))
				# Switch to std dev
				var = square_root(threshold(var))

				set_params_proj(ave, [phiM, thetaM, 0.0, 0.0, 0.0])
				set_params_proj(var, [phiM, thetaM, 0.0, 0.0, 0.0])

				aveList.append(ave)
				varList.append(var)

				if options.VERBOSE:
					print("%5.2f%% done on processor %d"%(i*100.0/len(proj_list), myid))
				if nvec > 0:
					eig = pca(input_stacks=grp_imgdata, subavg="", mask_radius=radiuspca, nvec=nvec, incore=True, shuffle=False, genbuf=True)
					for k in xrange(nvec):
						set_params_proj(eig[k], [phiM, thetaM, 0.0, 0.0, 0.0])
						eigList[k].append(eig[k])

			del imgdata
			#  To this point, all averages, variances, and eigenvectors are computed

			if options.ave2D:
				# Gather on the main node and write at the original size, exactly like var2D.
				from fundamentals import fpol
				if myid == main_node:
					km = 0
					for i in xrange(number_of_proc):
						if i == main_node:
							for im in xrange(len(aveList)):
								tmpvol = fpol(aveList[im], Tracker["nx"], Tracker["nx"], 1)
								tmpvol.write_image(options.ave2D, km)
								km += 1
						else:
							nl = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
							nl = int(nl[0])
							for im in xrange(nl):
								ave = recv_EMData(i, im+i+70000)
								tmpvol = fpol(ave, Tracker["nx"], Tracker["nx"], 1)
								tmpvol.write_image(options.ave2D, km)
								km += 1
				else:
					mpi_send(len(aveList), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
					for im in xrange(len(aveList)):
						send_EMData(aveList[im], main_node, im+myid+70000)

			if options.ave3D:
				from fundamentals import fpol
				if options.VERBOSE:
					print("Reconstructing 3D average volume")
				ave3D = recons3d_4nn_MPI(myid, aveList, symmetry=options.sym, npad=options.npad)
				bcast_EMData_to_all(ave3D, myid)
				if myid == main_node:
					ave3D = fpol(ave3D, Tracker["nx"], Tracker["nx"], Tracker["nx"])
					ave3D.write_image(options.ave3D)
					print_msg("%-70s:  %s\n"%("Writing to the disk volume reconstructed from averages as", options.ave3D))
			# Free memory; only names that are bound even on a node with no groups.
			del proj_list, aveList

			if nvec > 0:
				for k in xrange(nvec):
					if options.VERBOSE:
						print("Reconstruction eigenvolumes %s" % k)
					cont = True
					ITER = 0
					mask2d = model_circle(radiuspca, nx, nx)
					while cont:
						eig3D = recons3d_4nn_MPI(myid, eigList[k], symmetry=options.sym, npad=options.npad)
						bcast_EMData_to_all(eig3D, myid, main_node)
						if options.fl > 0.0:
							eig3D = filt_tanl(eig3D, options.fl, options.aa)
						if myid == main_node:
							eig3D.write_image("eig3d_%03d.hdf"%k, ITER)
						Util.mul_img( eig3D, model_circle(radiuspca, nx, nx, nx) )
						eig3Df, kb = prep_vol(eig3D)
						del eig3D
						# Flip eigen-images that anti-correlate with their reprojection
						# and iterate until no node flips any.
						cont = False
						icont = 0
						for l in xrange(len(eigList[k])):
							phi, theta, psi, s2x, s2y = get_params_proj(eigList[k][l])
							proj = prgs(eig3Df, kb, [phi, theta, psi, s2x, s2y])
							cl = ccc(proj, eigList[k][l], mask2d)
							if cl < 0.0:
								icont += 1
								cont = True
								eigList[k][l] *= -1.0
						u = int(cont)
						u = mpi_reduce([u], 1, MPI_INT, MPI_MAX, main_node, MPI_COMM_WORLD)
						icont = mpi_reduce([icont], 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)

						if myid == main_node:
							u = int(u[0])
							print(" Eigenvector:  %s  number changed  %d" % (k, int(icont[0])))
						else: u = 0
						u = bcast_number_to_all(u, main_node)
						cont = bool(u)
						ITER += 1

					del eig3Df, kb
					mpi_barrier(MPI_COMM_WORLD)
				del eigList, mask2d

			if options.ave3D: del ave3D
			if options.var2D:
				from fundamentals import fpol
				if myid == main_node:
					km = 0
					for i in xrange(number_of_proc):
						if i == main_node:
							for im in xrange(len(varList)):
								tmpvol = fpol(varList[im], Tracker["nx"], Tracker["nx"], 1)
								tmpvol.write_image(options.var2D, km)
								km += 1
						else:
							nl = mpi_recv(1, MPI_INT, i, MPI_TAG_UB, MPI_COMM_WORLD)
							nl = int(nl[0])
							for im in xrange(nl):
								var_img = recv_EMData(i, im+i+70000)
								tmpvol = fpol(var_img, Tracker["nx"], Tracker["nx"], 1)
								tmpvol.write_image(options.var2D, km)
								km += 1
				else:
					mpi_send(len(varList), 1, MPI_INT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
					for im in xrange(len(varList)):
						send_EMData(varList[im], main_node, im+myid+70000)#  What with the attributes??

			mpi_barrier(MPI_COMM_WORLD)

		if options.var3D:
			if myid == main_node and options.VERBOSE:
				print("Reconstructing 3D variability volume")

			t6 = time()
			res = recons3d_4nn_MPI(myid, varList, symmetry=options.sym, npad=options.npad)
			if myid == main_node:
				from fundamentals import fpol
				res = fpol(res, Tracker["nx"], Tracker["nx"], Tracker["nx"])
				res.write_image(options.var3D)
				print_msg("%-70s:  %.2f\n"%("Reconstructing 3D variability took [s]", time()-t6))
				if options.VERBOSE:
					print("Reconstruction took: %.2f [min]"%((time()-t6)/60))

		# Always close the log opened by print_begin_msg, not only after --var3D.
		if myid == main_node:
			print_msg("%-70s:  %.2f\n"%("Total time for these computations [s]", time()-t0))
			if options.VERBOSE:
				print("Total time for these computations: %.2f [min]"%((time()-t0)/60))
			print_end_msg("sx3dvariability")

		global_def.BATCH = False
		mpi_finalize()
Beispiel #41
0
#!/usr/bin/env python
#
# This program shows how to use mpi_comm_split
#
import numpy
from numpy import *
import mpi
import sys
import math

#print "before",len(sys.argv),sys.argv
sys.argv = mpi.mpi_init(len(sys.argv), sys.argv)
myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
numnodes = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
print("hello from  %d  of  %d" % (myid, numnodes))

# Split MPI_COMM_WORLD by rank parity: color 0 = even ranks, color 1 = odd
# ranks.  Using myid as the key keeps the original relative rank order.
color = myid % 2
new_comm = mpi.mpi_comm_split(mpi.MPI_COMM_WORLD, color, myid)
new_id = mpi.mpi_comm_rank(new_comm)
new_nodes = mpi.mpi_comm_size(new_comm)

# Rank 0 of each sub-communicator broadcasts its color to the other members.
zero_one = -1
if new_id == 0:
    zero_one = color
# mpi_bcast hands back a one-element buffer rather than a scalar, so take
# element 0 before comparing.
zero_one = int(mpi.mpi_bcast(zero_one, 1, mpi.MPI_INT, 0, new_comm)[0])
if zero_one == 0:
    print("%d  part of even processor communicator  %d" % (myid, new_id))
elif zero_one == 1:
    print("%d  part of odd processor communicator  %d" % (myid, new_id))

# Every MPI program must finalize before exiting.
mpi.mpi_finalize()
Beispiel #42
0
def main():
	"""Command-line driver for 2D alignment.

	Positional arguments: ``stack outdir [maskfile]`` (outdir may be the literal
	string 'None').  With --rotational, ali2d_rotationaltop is run.  Otherwise
	only the MPI version (ali2d_base) is supported: the main node creates the
	output directory, and the run is aborted via ERROR if that fails.
	"""
	progname = os.path.basename(sys.argv[0])
	usage = progname + " stack outdir <maskfile> --ir=inner_radius --ou=outer_radius --rs=ring_step --xr=x_range --yr=y_range --ts=translation_step --dst=delta --center=center --maxit=max_iteration --CTF --snr=SNR --Fourvar=Fourier_variance --Ng=group_number --Function=user_function_name --CUDA --GPUID --MPI"
	parser = OptionParser(usage,version=SPARXVERSION)
	parser.add_option("--ir",            type="float",        default=1,              help="inner radius for rotational correlation > 0 (set to 1)")
	parser.add_option("--ou",            type="float",        default=-1,             help="outer radius for rotational correlation < nx/2-1 (set to the radius of the particle)")
	parser.add_option("--rs",            type="float",        default=1,              help="step between rings in rotational correlation > 0 (set to 1)")
	parser.add_option("--xr",            type="string",       default="4 2 1 1",      help="range for translation search in x direction, search is +/xr ")
	parser.add_option("--yr",            type="string",       default="-1",           help="range for translation search in y direction, search is +/yr ")
	parser.add_option("--ts",            type="string",       default="2 1 0.5 0.25", help="step of translation search in both directions")
	parser.add_option("--nomirror",      action="store_true", default=False,          help="Disable checking mirror orientations of images (default False)")
	parser.add_option("--dst",           type="float",        default=0.0,            help="delta")
	parser.add_option("--center",        type="float",        default=-1,             help="-1.average center method; 0.not centered; 1.phase approximation; 2.cc with Gaussian function; 3.cc with donut-shaped image 4.cc with user-defined reference 5.cc with self-rotated average")
	parser.add_option("--maxit",         type="float",        default=0,              help="maximum number of iterations (0 means the maximum iterations is 10, but it will automatically stop should the criterion falls")
	parser.add_option("--CTF",           action="store_true", default=False,          help="use CTF correction during alignment")
	parser.add_option("--snr",           type="float",        default=1.0,            help="signal-to-noise ratio of the data (set to 1.0)")
	parser.add_option("--Fourvar",       action="store_true", default=False,          help="compute Fourier variance")
	parser.add_option("--function",      type="string",       default="ref_ali2d",    help="name of the reference preparation function (default ref_ali2d)")
	parser.add_option("--MPI",           action="store_true", default=False,          help="use MPI version ")
	parser.add_option("--rotational",    action="store_true", default=False,          help="rotational alignment with optional limited in-plane angle, the parameters are: ir, ou, rs, psi_max, mode(F or H), maxit, orient, randomize")
	parser.add_option("--psi_max",       type="float",        default=180.0,          help="psi_max")
	parser.add_option("--mode",          type="string",       default="F",            help="Full or Half rings, default F")
	parser.add_option("--randomize",     action="store_true", default=False,          help="randomize initial rotations (suboption of friedel, default False)")
	parser.add_option("--orient",        action="store_true", default=False,          help="orient images such that the average is symmetric about x-axis, for layer lines (suboption of friedel, default False)")
	parser.add_option("--template",      type="string",       default=None,           help="2D alignment will be initialized using the template provided (only non-MPI version, default None)")
	parser.add_option("--random_method", type="string",       default="",             help="use SHC or SCF (default standard method)")

	(options, args) = parser.parse_args()

	if len(args) < 2 or len(args) > 3:
		sxprint( "Usage: " + usage )
		sxprint( "Please run \'" + progname + " -h\' for detailed options" )
		sp_global_def.ERROR( "Invalid number of parameters used. Please see usage information above." )
		return

	if options.rotational:
		from sp_applications import ali2d_rotationaltop
		sp_global_def.BATCH = True
		ali2d_rotationaltop(args[1], args[0], options.randomize, options.orient, options.ir, options.ou, options.rs, options.psi_max, options.mode, options.maxit)
		return

	outdir = None if args[1] == 'None' else args[1]
	mask = args[2] if len(args) == 3 else None

	if sp_global_def.CACHE_DISABLE:
		from sp_utilities import disable_bdb_cache
		disable_bdb_cache()

	sp_global_def.BATCH = True
	if not options.MPI:
		sxprint( " Non-MPI is no more in use, try MPI option, please." )
		sp_global_def.BATCH = False
		return

	from sp_applications import ali2d_base
	from sp_utilities import bcast_number_to_all
	from mpi import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD

	number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
	myid = mpi_comm_rank(MPI_COMM_WORLD)
	main_node = 0

	# The MPI path creates and logs into outdir, so 'None' cannot be honoured;
	# fail on every rank instead of crashing only on the main node.
	if outdir is None:
		sp_global_def.ERROR( "An output directory is required for the MPI version", myid=myid )

	if myid == main_node:
		from sp_logger import Logger, BaseLogger_Files
		#  Create output directory (no shell involved, so any path is safe).
		log = Logger(BaseLogger_Files())
		log.prefix = outdir
		try:
			os.mkdir(log.prefix)
			outcome = 0
		except OSError:
			outcome = 1
		log.prefix += "/"
	else:
		outcome = 0
		log = None
	outcome = bcast_number_to_all(outcome, source_node = main_node)
	if outcome == 1:
		sp_global_def.ERROR( "Output directory exists, please change the name and restart the program", myid=myid )

	ali2d_base(args[0], outdir, mask, options.ir, options.ou, options.rs, options.xr, options.yr, \
		options.ts, options.nomirror, options.dst, \
		options.center, options.maxit, options.CTF, options.snr, options.Fourvar, \
		options.function, random_method = options.random_method, log = log, \
		number_of_proc = number_of_proc, myid = myid, main_node = main_node, mpi_comm = MPI_COMM_WORLD,\
		write_headers = True)
	sp_global_def.BATCH = False
Beispiel #43
0
def findHsym_MPI(vol,dp,dphi,apix,rmax,rmin,myid,main_node):
	"""
	Search a grid of (dp, dphi) values around the given ones in parallel and
	return the pair with the highest helios7 score.

	The candidate grid is built on ``main_node``, split across all ranks of
	MPI_COMM_WORLD with MPI_start_end, scored locally with
	``helios7(vol, apix, dp, dphi, 0.67, rmax, rmin)``, and every rank's best
	candidate is gathered back on ``main_node``.
	(presumably dp/dphi are the helical rise and twist - confirm against helios7)

	Parameters
	----------
	vol, apix, rmax, rmin : passed unchanged to helios7.
	dp, dphi : centre of the search grid; the grid step is 0.05 for both.
	myid : rank of the calling process.
	main_node : rank that builds the grid and gathers the results.

	Returns
	-------
	On ``main_node``: (dp, dphi) as floats for the best-scoring candidate
	(ties go to the candidate reported by the highest rank).
	On every other rank: (None, None).
	"""
	from alignment import helios7
	from mpi import mpi_comm_size, mpi_recv, mpi_send, MPI_TAG_UB, MPI_COMM_WORLD, MPI_FLOAT

	nproc = mpi_comm_size(MPI_COMM_WORLD)

	# half-widths of the search grid (in steps) and the step sizes
	ndp = 12
	ndphi = 12
	dp_step = 0.05
	dphi_step = 0.05

	nlprms = (2*ndp+1)*(2*ndphi+1)
	# make sure the number of helical candidates is at least the number of processors
	if nlprms < nproc:
		mindp = nproc//4 + 1
		ndp, ndphi = mindp, mindp
		# The grid was widened, so its size must be recomputed; otherwise only the
		# first 625 entries (the lowest-dp rows of the enlarged grid) would be
		# distributed and the candidates around (dp, dphi) never scored.
		nlprms = (2*ndp+1)*(2*ndphi+1)

	if myid == main_node:
		# flat list of interleaved candidates: [dp0, dphi0, dp1, dphi1, ...]
		lprms = []
		for i in xrange(-ndp, ndp+1):
			for j in xrange(-ndphi, ndphi+1):
				lprms.append(dp   + i*dp_step)
				lprms.append(dphi + j*dphi_step)

	para_start, para_end = MPI_start_end(nlprms, nproc, myid)
	ncand = para_end - para_start

	if myid == main_node:
		# hand every rank its [start, end) share of the candidate list
		for n in xrange(nproc):
			ib, ie = MPI_start_end(nlprms, nproc, n)
			if n != main_node:
				mpi_send(lprms[2*ib:2*ie], 2*(ie-ib), MPI_FLOAT, n, MPI_TAG_UB, MPI_COMM_WORLD)
			else:
				# main_node keeps ITS OWN share (previously rank 0's share was taken)
				list_dps = lprms[2*ib:2*ie]
	else:
		list_dps = mpi_recv(2*ncand, MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)

	list_dps = [float(v) for v in list_dps]

	# best [dp, dphi, score] found on this rank
	local_pos = [0.0, 0.0, -1.0e20]
	fract = 0.67
	for i in xrange(ncand):
		fvalue = helios7(vol, apix, list_dps[2*i], list_dps[2*i+1], fract, rmax, rmin)
		if fvalue >= local_pos[2]:
			local_pos = [list_dps[2*i], list_dps[2*i+1], fvalue]

	if myid != main_node:
		mpi_send(local_pos, 3, MPI_FLOAT, main_node, MPI_TAG_UB, MPI_COMM_WORLD)
		return None, None

	# gather every rank's best candidate as consecutive [dp, dphi, score] triples
	list_return = [0.0]*(3*nproc)
	for n in xrange(nproc):
		if n != main_node:
			list_return[3*n:3*n+3] = mpi_recv(3, MPI_FLOAT, n, MPI_TAG_UB, MPI_COMM_WORLD)
		else:
			list_return[3*n:3*n+3] = local_pos[:]

	maxvalue = list_return[2]
	for i in xrange(nproc):
		if list_return[3*i+2] >= maxvalue:
			maxvalue = list_return[3*i+2]
			dp       = list_return[3*i+0]
			dphi     = list_return[3*i+1]

	return float(dp), float(dphi)
def rec3D_MPI(data, snr, symmetry, mask3D, fsc_curve, myid, main_node = 0, rstep = 1.0, odd_start=0, eve_start=1, finfo=None, index=-1, npad = 4, hparams=None):
	'''
	Reconstruct a 3-D volume from in-memory images inside an MPI program, together
	with separate odd/even half reconstructions and their FSC (resolution estimate).

	Each half is accumulated by prepare_recons_ctf, called with ``odd_start`` /
	``eve_start`` and a step of 2 (presumably selecting alternating images), and
	stored in temporary files that the owning rank removes before returning.

	Work distribution:
	  * 1 process   : everything is done locally.
	  * 2 processes : the even rank reconstructs the even half and ships that
	                  volume, its Fourier data and weights to the odd rank, which
	                  builds the full volume.
	  * 3 or more   : the odd, even and full reconstructions run on three
	                  different ranks.

	Selected parameters:
	  data      : images; when ``index != -1`` only those whose 'group'
	              attribute equals ``index`` are used.
	  mask3D, rstep, fsc_curve : passed to fsc_mask.
	  main_node : rank that returns the results.
	  hparams   : if not None, hsymVols(volodd, voleve, volall, hparams) is
	              applied before the FSC is computed.

	Returns:
	  On main_node: (volall, fscdat, volodd, voleve).
	  On every other rank: (blank volume, None, blank volume, blank volume).
	'''
	import os
	from statistics import fsc_mask
	from utilities  import model_blank, get_image, send_EMData, recv_EMData
	from mpi        import mpi_comm_size, MPI_COMM_WORLD

	def _remove_files(*paths):
		# Delete temporary files without spawning a shell. Like "rm -f", missing
		# files and errors are ignored; None placeholders are skipped.
		for path in paths:
			if path is None:
				continue
			try:
				os.remove(path)
			except OSError:
				pass

	nproc = mpi_comm_size(MPI_COMM_WORLD)

	if nproc == 1:
		assert main_node == 0
		main_node_odd = main_node
		main_node_eve = main_node
		main_node_all = main_node
	elif nproc == 2:
		main_node_odd = main_node
		main_node_eve = (main_node+1)%2
		main_node_all = main_node

		tag_voleve     = 1000
		tag_fftvol_eve = 1001
		tag_weight_eve = 1002
	else:
		# spread CPUs between different nodes to save memory
		main_node_odd = main_node
		main_node_eve = (int(main_node)+nproc-1)%int(nproc)
		main_node_all = (int(main_node)+nproc//2)%int(nproc)

		tag_voleve     = 1000
		tag_fftvol_eve = 1001
		tag_weight_eve = 1002

		tag_fftvol_odd = 1003
		tag_weight_odd = 1004
		tag_volall     = 1005

	if index != -1:
		# restrict the reconstruction to the images of the requested group
		imgdata = [img for img in data if img.get_attr( 'group' ) == index]
	else:
		imgdata = data
	nx = get_image_size( imgdata, myid )
	if nx == 0:
		ERROR("Warning: no images were given for reconstruction, this usually means there is an empty group, returning empty volume","rec3D",0)
		# NOTE(review): nx is 0 here, so the last two volumes are 0-sized while the
		# first is 2x2x2 - confirm callers expect this.
		return model_blank( 2, 2, 2 ), None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)

	fftvol_odd_file,weight_odd_file = prepare_recons_ctf(nx, imgdata, snr, symmetry, myid, main_node_odd, odd_start, 2, finfo, npad)
	fftvol_eve_file,weight_eve_file = prepare_recons_ctf(nx, imgdata, snr, symmetry, myid, main_node_eve, eve_start, 2, finfo, npad)
	del imgdata

	if nproc == 1:
		fftvol = get_image( fftvol_odd_file )
		weight = get_image( weight_odd_file )
		volodd = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)

		fftvol = get_image( fftvol_eve_file )
		weight = get_image( weight_eve_file )
		voleve = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)

		# Full volume: sum of both halves' Fourier data and weights. The data are
		# re-read from disk (presumably because recons_ctf_from_fftvol may modify
		# its inputs - confirm).
		fftvol = get_image( fftvol_odd_file )
		fftvol += get_image( fftvol_eve_file )
		weight = get_image( weight_odd_file )
		weight += get_image( weight_eve_file )
		volall = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)

		# if helical, find & apply symmetry to volume; FSC is computed afterwards
		if hparams is not None:
			volodd,voleve,volall = hsymVols(volodd,voleve,volall,hparams)
		fscdat = fsc_mask( volodd, voleve, mask3D, rstep, fsc_curve)

		_remove_files(fftvol_odd_file, weight_odd_file, fftvol_eve_file, weight_eve_file)
		return volall,fscdat,volodd,voleve

	if nproc == 2:
		if myid == main_node_odd:
			fftvol = get_image( fftvol_odd_file )
			weight = get_image( weight_odd_file )
			volodd = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)
			voleve = recv_EMData(main_node_eve, tag_voleve)

			# full volume: odd data from disk + even data received from the even rank
			fftvol = get_image( fftvol_odd_file )
			fftvol += recv_EMData( main_node_eve, tag_fftvol_eve )
			weight = get_image( weight_odd_file )
			weight += recv_EMData( main_node_eve, tag_weight_eve )
			volall = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)

			# if helical, find & apply symmetry to volume; FSC is computed afterwards
			if hparams is not None:
				volodd,voleve,volall = hsymVols(volodd,voleve,volall,hparams)
			fscdat = fsc_mask( volodd, voleve, mask3D, rstep, fsc_curve)

			# Only the odd half's files belong to this rank; the even rank removes
			# its own files below.
			_remove_files(fftvol_odd_file, weight_odd_file)
			return volall,fscdat,volodd,voleve
		else:
			assert myid == main_node_eve
			fftvol = get_image( fftvol_eve_file )
			weight = get_image( weight_eve_file )
			voleve = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)
			send_EMData(voleve, main_node_odd, tag_voleve)

			fftvol = get_image( fftvol_eve_file )
			send_EMData(fftvol, main_node_odd, tag_fftvol_eve )

			weight = get_image( weight_eve_file )
			send_EMData(weight, main_node_odd, tag_weight_eve )

			_remove_files(fftvol_eve_file, weight_eve_file)
			return model_blank(nx,nx,nx), None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)

	# cases from all other number of processors situations
	if myid == main_node_odd:
		fftvol = get_image( fftvol_odd_file )
		send_EMData(fftvol, main_node_eve, tag_fftvol_odd )

		if finfo is not None:
			finfo.write("fftvol odd sent\n")
			finfo.flush()

		weight = get_image( weight_odd_file )
		send_EMData(weight, main_node_all, tag_weight_odd )

		if finfo is not None:
			finfo.write("weight odd sent\n")
			finfo.flush()

		volodd = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)
		del fftvol, weight
		voleve = recv_EMData(main_node_eve, tag_voleve)
		volall = recv_EMData(main_node_all, tag_volall)

		# if helical, find & apply symmetry to volume; FSC is computed afterwards
		if hparams is not None:
			volodd,voleve,volall = hsymVols(volodd,voleve,volall,hparams)
		fscdat = fsc_mask( volodd, voleve, mask3D, rstep, fsc_curve)

		_remove_files(fftvol_odd_file, weight_odd_file)
		return volall,fscdat,volodd,voleve

	if myid == main_node_eve:
		# add the even Fourier data into the odd copy and forward the sum
		ftmp = recv_EMData(main_node_odd, tag_fftvol_odd)
		fftvol = get_image( fftvol_eve_file )
		Util.add_img( ftmp, fftvol )
		send_EMData(ftmp, main_node_all, tag_fftvol_eve )
		del ftmp

		weight = get_image( weight_eve_file )
		send_EMData(weight, main_node_all, tag_weight_eve )

		voleve = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)
		send_EMData(voleve, main_node_odd, tag_voleve)
		_remove_files(fftvol_eve_file, weight_eve_file)

		return model_blank(nx,nx,nx), None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)

	if myid == main_node_all:
		# this is the odd+even Fourier sum forwarded by the even rank
		fftvol = recv_EMData(main_node_eve, tag_fftvol_eve)
		if finfo is not None:
			finfo.write( "fftvol odd received\n" )
			finfo.flush()

		weight = recv_EMData(main_node_odd, tag_weight_odd)
		weight_tmp = recv_EMData(main_node_eve, tag_weight_eve)
		Util.add_img( weight, weight_tmp )
		weight_tmp = None

		volall = recons_ctf_from_fftvol(nx, fftvol, weight, snr, symmetry, npad)
		send_EMData(volall, main_node_odd, tag_volall)

		return model_blank(nx,nx,nx),None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)

	return model_blank(nx,nx,nx),None, model_blank(nx,nx,nx), model_blank(nx,nx,nx)
Beispiel #45
0
#!/usr/bin/env /usr/bin/python
# MPI port "server" test script: opens an MPI port, accepts one client
# connection, receives one integer from the client, increments it and sends
# it back, then tears the connection down.
# NOTE(review): uses Python 2 print statements; it will not run under Python 3.
# NOTE(review): numpy is imported but not used by the visible statements.
import numpy
from numpy import *
import mpi
import sys
from time import sleep

# Initialise MPI; mpi_init returns the argument list, which replaces sys.argv.
sys.argv = mpi.mpi_init(len(sys.argv), sys.argv)
myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
numprocs = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)

print "hello from python main1   myid= ", myid

# Open a port and print its name (presumably so it can be handed to the client
# process that will connect to it).
port_name = mpi.mpi_open_port(mpi.MPI_INFO_NULL)
print "port=", port_name
# Block until a client connects; rank 0 of MPI_COMM_WORLD is the root, and the
# returned communicator links this job to the client.
client = mpi.mpi_comm_accept(port_name, mpi.MPI_INFO_NULL, 0,
                             mpi.MPI_COMM_WORLD)

# Receive one integer (tag 5678) from client rank 0, add 1, and send it back
# with tag 1234.
back = mpi.mpi_recv(1, mpi.MPI_INT, 0, 5678, client)
print "back=", back
back[0] = back[0] + 1
mpi.mpi_send(back, 1, mpi.MPI_INT, 0, 1234, client)

# Wait before tearing down (presumably to let the client receive the reply).
sleep(10)
mpi.mpi_close_port(port_name)
mpi.mpi_comm_disconnect(client)
mpi.mpi_finalize()
Beispiel #46
0
def main():

	def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
		# the final ali2d parameters already combine shifts operation first and rotation operation second for parameters converted from 3D
		if mirror:
			m = 1
			alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0, 540.0-psi, 0, 0, 1.0)
		else:
			m = 0
			alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0, 360.0-psi, 0, 0, 1.0)
		return  alpha, sx, sy, m
	
	progname = os.path.basename(sys.argv[0])
	usage = progname + " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=  --aa=   --sym=symmetry --CTF"
	parser = OptionParser(usage, version=SPARXVERSION)
	
	parser.add_option("--output_dir",   type="string"	   ,	default="./",				    help="Output directory")
	parser.add_option("--ave2D",		type="string"	   ,	default=False,				help="Write to the disk a stack of 2D averages")
	parser.add_option("--var2D",		type="string"	   ,	default=False,				help="Write to the disk a stack of 2D variances")
	parser.add_option("--ave3D",		type="string"	   ,	default=False,				help="Write to the disk reconstructed 3D average")
	parser.add_option("--var3D",		type="string"	   ,	default=False,				help="Compute 3D variability (time consuming!)")
	parser.add_option("--img_per_grp",	type="int"         ,	default=100,	     	    help="Number of neighbouring projections.(Default is 100)")
	parser.add_option("--no_norm",		action="store_true",	default=False,				help="Do not use normalization.(Default is to apply normalization)")
	#parser.add_option("--radius", 	    type="int"         ,	default=-1   ,				help="radius for 3D variability" )
	parser.add_option("--npad",			type="int"         ,	default=2    ,				help="Number of time to pad the original images.(Default is 2 times padding)")
	parser.add_option("--sym" , 		type="string"      ,	default="c1",				help="Symmetry. (Default is no symmetry)")
	parser.add_option("--fl",			type="float"       ,	default=0.0,				help="Low pass filter cutoff in absolute frequency (0.0 - 0.5) and is applied to decimated images. (Default - no filtration)")
	parser.add_option("--aa",			type="float"       ,	default=0.02 ,				help="Fall off of the filter. Use default value if user has no clue about falloff (Default value is 0.02)")
	parser.add_option("--CTF",			action="store_true",	default=False,				help="Use CFT correction.(Default is no CTF correction)")
	#parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
	#parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
	#parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
	#parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
	#parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
	parser.add_option("--VAR" , 		action="store_true",	default=False,				help="Stack of input consists of 2D variances (Default False)")
	parser.add_option("--decimate",     type  ="float",         default=0.25,               help="Image decimate rate, a number less than 1. (Default is 0.25)")
	parser.add_option("--window",       type  ="int",           default=0,                  help="Target image size relative to original image size. (Default value is zero.)")
	#parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
	#parser.add_option("--nvec",			type="int"         ,	default=0    ,				help="Number of eigenvectors, (Default = 0 meaning no PCA calculated)")
	parser.add_option("--symmetrize",	action="store_true",	default=False,				help="Prepare input stack for handling symmetry (Default False)")
	parser.add_option("--overhead",     type  ="float",         default=0.5,                help="python overhead per CPU.")

	(options,args) = parser.parse_args()
	#####
	from mpi import mpi_init, mpi_comm_rank, mpi_comm_size, mpi_recv, MPI_COMM_WORLD
	from mpi import mpi_barrier, mpi_reduce, mpi_bcast, mpi_send, MPI_FLOAT, MPI_SUM, MPI_INT, MPI_MAX
	#from mpi import *
	from applications   import MPI_start_end
	from reconstruction import recons3d_em, recons3d_em_MPI
	from reconstruction	import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
	from utilities      import print_begin_msg, print_end_msg, print_msg
	from utilities      import read_text_row, get_image, get_im, wrap_mpi_send, wrap_mpi_recv
	from utilities      import bcast_EMData_to_all, bcast_number_to_all
	from utilities      import get_symt

	#  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

	from EMAN2db import db_open_dict

	# Set up global variables related to bdb cache 
	if global_def.CACHE_DISABLE:
		from utilities import disable_bdb_cache
		disable_bdb_cache()
	
	# Set up global variables related to ERROR function
	global_def.BATCH = True
	
	# detect if program is running under MPI
	RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ
	if RUNNING_UNDER_MPI: global_def.MPI = True
	if options.output_dir =="./": current_output_dir = os.path.abspath(options.output_dir)
	else: current_output_dir = options.output_dir
	if options.symmetrize :
		if RUNNING_UNDER_MPI:
			try:
				sys.argv = mpi_init(len(sys.argv), sys.argv)
				try:	
					number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
					if( number_of_proc > 1 ):
						ERROR("Cannot use more than one CPU for symmetry preparation","sx3dvariability",1)
				except:
					pass
			except:
				pass
		if not os.path.exists(current_output_dir): os.mkdir(current_output_dir)
		
		#  Input
		#instack = "Clean_NORM_CTF_start_wparams.hdf"
		#instack = "bdb:data"
		
		
		from logger import Logger,BaseLogger_Files
		if os.path.exists(os.path.join(current_output_dir, "log.txt")): os.remove(os.path.join(current_output_dir, "log.txt"))
		log_main=Logger(BaseLogger_Files())
		log_main.prefix = os.path.join(current_output_dir, "./")
		
		instack = args[0]
		sym = options.sym.lower()
		if( sym == "c1" ):
			ERROR("There is no need to symmetrize stack for C1 symmetry","sx3dvariability",1)
		
		line =""
		for a in sys.argv:
			line +=" "+a
		log_main.add(line)
	
		if(instack[:4] !="bdb:"):
			#if output_dir =="./": stack = "bdb:data"
			stack = "bdb:"+current_output_dir+"/data"
			delete_bdb(stack)
			junk = cmdexecute("sxcpy.py  "+instack+"  "+stack)
		else: stack = instack
		
		qt = EMUtil.get_all_attributes(stack,'xform.projection')

		na = len(qt)
		ts = get_symt(sym)
		ks = len(ts)
		angsa = [None]*na
		
		for k in range(ks):
			#Qfile = "Q%1d"%k
			#if options.output_dir!="./": Qfile = os.path.join(options.output_dir,"Q%1d"%k)
			Qfile = os.path.join(current_output_dir, "Q%1d"%k)
			#delete_bdb("bdb:Q%1d"%k)
			delete_bdb("bdb:"+Qfile)
			#junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
			junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:"+Qfile)
			#DB = db_open_dict("bdb:Q%1d"%k)
			DB = db_open_dict("bdb:"+Qfile)
			for i in range(na):
				ut = qt[i]*ts[k]
				DB.set_attr(i, "xform.projection", ut)
				#bt = ut.get_params("spider")
				#angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
			#write_text_row(angsa, 'ptsma%1d.txt'%k)
			#junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
			#junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
			DB.close()
		#if options.output_dir =="./": delete_bdb("bdb:sdata")
		delete_bdb("bdb:" + current_output_dir + "/"+"sdata")
		#junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
		sdata = "bdb:"+current_output_dir+"/"+"sdata"
		print(sdata)
		junk = cmdexecute("e2bdb.py   " + current_output_dir +"  --makevstack="+sdata +" --filt=Q")
		#junk = cmdexecute("ls  EMAN2DB/sdata*")
		#a = get_im("bdb:sdata")
		a = get_im(sdata)
		a.set_attr("variabilitysymmetry",sym)
		#a.write_image("bdb:sdata")
		a.write_image(sdata)

	else:

		from fundamentals import window2d
		sys.argv       = mpi_init(len(sys.argv), sys.argv)
		myid           = mpi_comm_rank(MPI_COMM_WORLD)
		number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
		main_node      = 0
		shared_comm  = mpi_comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED,  0, MPI_INFO_NULL)
		myid_on_node = mpi_comm_rank(shared_comm)
		no_of_processes_per_group = mpi_comm_size(shared_comm)
		masters_from_groups_vs_everything_else_comm = mpi_comm_split(MPI_COMM_WORLD, main_node == myid_on_node, myid_on_node)
		color, no_of_groups, balanced_processor_load_on_nodes = get_colors_and_subsets(main_node, MPI_COMM_WORLD, myid, \
		    shared_comm, myid_on_node, masters_from_groups_vs_everything_else_comm)
		overhead_loading = options.overhead*number_of_proc
		#memory_per_node  = options.memory_per_node
		#if memory_per_node == -1.: memory_per_node = 2.*no_of_processes_per_group
		keepgoing = 1
		
		current_window   = options.window
		current_decimate = options.decimate
		
		if len(args) == 1: stack = args[0]
		else:
			print(( "usage: " + usage))
			print(( "Please run '" + progname + " -h' for detailed options"))
			return 1

		t0 = time()	
		# obsolete flags
		options.MPI  = True
		#options.nvec = 0
		options.radiuspca = -1
		options.iter = 40
		options.abs  = 0.0
		options.squ  = 0.0

		if options.fl > 0.0 and options.aa == 0.0:
			ERROR("Fall off has to be given for the low-pass filter", "sx3dvariability", 1, myid)
			
		#if options.VAR and options.SND:
		#	ERROR("Only one of var and SND can be set!", "sx3dvariability", myid)
			
		if options.VAR and (options.ave2D or options.ave3D or options.var2D): 
			ERROR("When VAR is set, the program cannot output ave2D, ave3D or var2D", "sx3dvariability", 1, myid)
			
		#if options.SND and (options.ave2D or options.ave3D):
		#	ERROR("When SND is set, the program cannot output ave2D or ave3D", "sx3dvariability", 1, myid)
		
		#if options.nvec > 0 :
		#	ERROR("PCA option not implemented", "sx3dvariability", 1, myid)
			
		#if options.nvec > 0 and options.ave3D == None:
		#	ERROR("When doing PCA analysis, one must set ave3D", "sx3dvariability", 1, myid)
		
		if current_decimate>1.0 or current_decimate<0.0:
			ERROR("Decimate rate should be a value between 0.0 and 1.0", "sx3dvariability", 1, myid)
		
		if current_window < 0.0:
			ERROR("Target window size should be always larger than zero", "sx3dvariability", 1, myid)
			
		if myid == main_node:
			img  = get_image(stack, 0)
			nx   = img.get_xsize()
			ny   = img.get_ysize()
			if(min(nx, ny) < current_window):   keepgoing = 0
		keepgoing = bcast_number_to_all(keepgoing, main_node, MPI_COMM_WORLD)
		if keepgoing == 0: ERROR("The target window size cannot be larger than the size of decimated image", "sx3dvariability", 1, myid)

		import string
		options.sym = options.sym.lower()
		# if global_def.CACHE_DISABLE:
		# 	from utilities import disable_bdb_cache
		# 	disable_bdb_cache()
		# global_def.BATCH = True
		
		if myid == main_node:
			if not os.path.exists(current_output_dir): os.mkdir(current_output_dir)# Never delete output_dir in the program!
	
		img_per_grp = options.img_per_grp
		#nvec        = options.nvec
		radiuspca   = options.radiuspca
		from logger import Logger,BaseLogger_Files
		#if os.path.exists(os.path.join(options.output_dir, "log.txt")): os.remove(os.path.join(options.output_dir, "log.txt"))
		log_main=Logger(BaseLogger_Files())
		log_main.prefix = os.path.join(current_output_dir, "./")

		if myid == main_node:
			line = ""
			for a in sys.argv: line +=" "+a
			log_main.add(line)
			log_main.add("-------->>>Settings given by all options<<<-------")
			log_main.add("Symmetry             : %s"%options.sym)
			log_main.add("Input stack          : %s"%stack)
			log_main.add("Output_dir           : %s"%current_output_dir)
			
			if options.ave3D: log_main.add("Ave3d                : %s"%options.ave3D)
			if options.var3D: log_main.add("Var3d                : %s"%options.var3D)
			if options.ave2D: log_main.add("Ave2D                : %s"%options.ave2D)
			if options.var2D: log_main.add("Var2D                : %s"%options.var2D)
			if options.VAR:   log_main.add("VAR                  : True")
			else:             log_main.add("VAR                  : False")
			if options.CTF:   log_main.add("CTF correction       : True  ")
			else:             log_main.add("CTF correction       : False ")
			
			log_main.add("Image per group      : %5d"%options.img_per_grp)
			log_main.add("Image decimate rate  : %4.3f"%current_decimate)
			log_main.add("Low pass filter      : %4.3f"%options.fl)
			current_fl = options.fl
			if current_fl == 0.0: current_fl = 0.5
			log_main.add("Current low pass filter is equivalent to cutoff frequency %4.3f for original image size"%round((current_fl*current_decimate),3))
			log_main.add("Window size          : %5d "%current_window)
			log_main.add("sx3dvariability begins")
	
		symbaselen = 0
		if myid == main_node:
			nima = EMUtil.get_image_count(stack)
			img  = get_image(stack)
			nx   = img.get_xsize()
			ny   = img.get_ysize()
			nnxo = nx
			nnyo = ny
			if options.sym != "c1" :
				imgdata = get_im(stack)
				try:
					i = imgdata.get_attr("variabilitysymmetry").lower()
					if(i != options.sym):
						ERROR("The symmetry provided does not agree with the symmetry of the input stack", "sx3dvariability", 1, myid)
				except:
					ERROR("Input stack is not prepared for symmetry, please follow instructions", "sx3dvariability", 1, myid)
				from utilities import get_symt
				i = len(get_symt(options.sym))
				if((nima/i)*i != nima):
					ERROR("The length of the input stack is incorrect for symmetry processing", "sx3dvariability", 1, myid)
				symbaselen = nima/i
			else:  symbaselen = nima
		else:
			nima = 0
			nx = 0
			ny = 0
			nnxo = 0
			nnyo = 0
		nima    = bcast_number_to_all(nima)
		nx      = bcast_number_to_all(nx)
		ny      = bcast_number_to_all(ny)
		nnxo    = bcast_number_to_all(nnxo)
		nnyo    = bcast_number_to_all(nnyo)
		if current_window > max(nx, ny):
			ERROR("Window size is larger than the original image size", "sx3dvariability", 1)
		
		if current_decimate == 1.:
			if current_window !=0:
				nx = current_window
				ny = current_window
		else:
			if current_window == 0:
				nx = int(nx*current_decimate+0.5)
				ny = int(ny*current_decimate+0.5)
			else:
				nx = int(current_window*current_decimate+0.5)
				ny = nx
		symbaselen = bcast_number_to_all(symbaselen)
		
		# check FFT prime number
		from fundamentals import smallprime
		is_fft_friendly = (nx == smallprime(nx))
		
		if not is_fft_friendly:
			if myid == main_node:
				log_main.add("The target image size is not a product of small prime numbers")
				log_main.add("Program adjusts the input settings!")
			### two cases
			if current_decimate == 1.:
				nx = smallprime(nx)
				ny = nx
				current_window = nx # update
				if myid == main_node:
					log_main.add("The window size is updated to %d."%current_window)
			else:
				if current_window == 0:
					nx = smallprime(int(nx*current_decimate+0.5))
					current_decimate = float(nx)/nnxo
					ny = nx
					if (myid == main_node):
						log_main.add("The decimate rate is updated to %f."%current_decimate)
				else:
					nx = smallprime(int(current_window*current_decimate+0.5))
					ny = nx
					current_window = int(nx/current_decimate+0.5)
					if (myid == main_node):
						log_main.add("The window size is updated to %d."%current_window)
						
		if myid == main_node:
			log_main.add("The target image size is %d"%nx)
						
		if radiuspca == -1: radiuspca = nx/2-2
		if myid == main_node: log_main.add("%-70s:  %d\n"%("Number of projection", nima))
		img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
		
		"""
		if options.SND:
			from projection		import prep_vol, prgs
			from statistics		import im_diff
			from utilities		import get_im, model_circle, get_params_proj, set_params_proj
			from utilities		import get_ctf, generate_ctf
			from filter			import filt_ctf
		
			imgdata = EMData.read_images(stack, range(img_begin, img_end))

			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)

			bcast_EMData_to_all(vol, myid)
			volft, kb = prep_vol(vol)

			mask = model_circle(nx/2-2, nx, ny)
			varList = []
			for i in xrange(img_begin, img_end):
				phi, theta, psi, s2x, s2y = get_params_proj(imgdata[i-img_begin])
				ref_prj = prgs(volft, kb, [phi, theta, psi, -s2x, -s2y])
				if options.CTF:
					ctf_params = get_ctf(imgdata[i-img_begin])
					ref_prj = filt_ctf(ref_prj, generate_ctf(ctf_params))
				diff, A, B = im_diff(ref_prj, imgdata[i-img_begin], mask)
				diff2 = diff*diff
				set_params_proj(diff2, [phi, theta, psi, s2x, s2y])
				varList.append(diff2)
			mpi_barrier(MPI_COMM_WORLD)
		"""
		
		if options.VAR: # 2D variance images have no shifts
			#varList   = EMData.read_images(stack, range(img_begin, img_end))
			from EMAN2 import Region
			for index_of_particle in range(img_begin,img_end):
				image = get_im(stack, index_of_proj)
				if current_window > 0: varList.append(fdecimate(window2d(image,current_window,current_window), nx,ny))
				else:   varList.append(fdecimate(image, nx,ny))
				
		else:
			from utilities		import bcast_number_to_all, bcast_list_to_all, send_EMData, recv_EMData
			from utilities		import set_params_proj, get_params_proj, params_3D_2D, get_params2D, set_params2D, compose_transform2
			from utilities		import model_blank, nearest_proj, model_circle, write_text_row, wrap_mpi_gatherv
			from applications	import pca
			from statistics		import avgvar, avgvar_ctf, ccc
			from filter		    import filt_tanl
			from morphology		import threshold, square_root
			from projection 	import project, prep_vol, prgs
			from sets		    import Set
			from utilities      import wrap_mpi_recv, wrap_mpi_bcast, wrap_mpi_send
			import numpy as np
			if myid == main_node:
				t1          = time()
				proj_angles = []
				aveList     = []
				tab = EMUtil.get_all_attributes(stack, 'xform.projection')	
				for i in range(nima):
					t     = tab[i].get_params('spider')
					phi   = t['phi']
					theta = t['theta']
					psi   = t['psi']
					x     = theta
					if x > 90.0: x = 180.0 - x
					x = x*10000+psi
					proj_angles.append([x, t['phi'], t['theta'], t['psi'], i])
				t2 = time()
				log_main.add( "%-70s:  %d\n"%("Number of neighboring projections", img_per_grp))
				log_main.add("...... Finding neighboring projections\n")
				log_main.add( "Number of images per group: %d"%img_per_grp)
				log_main.add( "Now grouping projections")
				proj_angles.sort()
				proj_angles_list = np.full((nima, 4), 0.0, dtype=np.float32)	
				for i in range(nima):
					proj_angles_list[i][0] = proj_angles[i][1]
					proj_angles_list[i][1] = proj_angles[i][2]
					proj_angles_list[i][2] = proj_angles[i][3]
					proj_angles_list[i][3] = proj_angles[i][4]
			else: proj_angles_list = 0
			proj_angles_list = wrap_mpi_bcast(proj_angles_list, main_node, MPI_COMM_WORLD)
			proj_angles      = []
			for i in range(nima):
				proj_angles.append([proj_angles_list[i][0], proj_angles_list[i][1], proj_angles_list[i][2], int(proj_angles_list[i][3])])
			del proj_angles_list
			proj_list, mirror_list = nearest_proj(proj_angles, img_per_grp, range(img_begin, img_end))
			all_proj = Set()
			for im in proj_list:
				for jm in im:
					all_proj.add(proj_angles[jm][3])
			all_proj = list(all_proj)
			index = {}
			for i in range(len(all_proj)): index[all_proj[i]] = i
			mpi_barrier(MPI_COMM_WORLD)
			if myid == main_node:
				log_main.add("%-70s:  %.2f\n"%("Finding neighboring projections lasted [s]", time()-t2))
				log_main.add("%-70s:  %d\n"%("Number of groups processed on the main node", len(proj_list)))
				log_main.add("Grouping projections took:  %12.1f [m]"%((time()-t2)/60.))
				log_main.add("Number of groups on main node: ", len(proj_list))
			mpi_barrier(MPI_COMM_WORLD)

			if myid == main_node:
				log_main.add("...... Calculating the stack of 2D variances \n")
			# Memory estimation. There are two memory consumption peaks
			# peak 1. Compute ave, var; 
			# peak 2. Var volume reconstruction;
			# proj_params = [0.0]*(nima*5)
			aveList = []
			varList = []				
			#if nvec > 0: eigList = [[] for i in range(nvec)]
			dnumber   = len(all_proj)# all neighborhood set for assigned to myid
			pnumber   = len(proj_list)*2. + img_per_grp # aveList and varList 
			tnumber   = dnumber+pnumber
			vol_size2 = nx**3*4.*8/1.e9
			vol_size1 = 2.*nnxo**3*4.*8/1.e9
			proj_size         = nnxo*nnyo*len(proj_list)*4.*2./1.e9 # both aveList and varList
			orig_data_size    = nnxo*nnyo*4.*tnumber/1.e9
			reduced_data_size = nx*nx*4.*tnumber/1.e9
			full_data         = np.full((number_of_proc, 2), -1., dtype=np.float16)
			full_data[myid]   = orig_data_size, reduced_data_size
			if myid != main_node: wrap_mpi_send(full_data, main_node, MPI_COMM_WORLD)
			if myid == main_node:
				for iproc in range(number_of_proc):
					if iproc != main_node:
						dummy = wrap_mpi_recv(iproc, MPI_COMM_WORLD)
						full_data[np.where(dummy>-1)] = dummy[np.where(dummy>-1)]
				del dummy
			mpi_barrier(MPI_COMM_WORLD)
			full_data = wrap_mpi_bcast(full_data, main_node, MPI_COMM_WORLD)
			# find the CPU with heaviest load
			minindx         = np.argsort(full_data, 0)
			heavy_load_myid = minindx[-1][1]
			total_mem       = sum(full_data)
			if myid == main_node:
				if current_window == 0:
					log_main.add("Nx:   current image size = %d. Decimated by %f from %d"%(nx, current_decimate, nnxo))
				else:
					log_main.add("Nx:   current image size = %d. Windowed to %d, and decimated by %f from %d"%(nx, current_window, current_decimate, nnxo))
				log_main.add("Nproj:       number of particle images.")
				log_main.add("Navg:        number of 2D average images.")
				log_main.add("Nvar:        number of 2D variance images.")
				log_main.add("Img_per_grp: user defined image per group for averaging = %d"%img_per_grp)
				log_main.add("Overhead:    total python overhead memory consumption   = %f"%overhead_loading)
				log_main.add("Total memory) = 4.0*nx^2*(nproj + navg +nvar+ img_per_grp)/1.0e9 + overhead: %12.3f [GB]"%\
				   (total_mem[1] + overhead_loading))
			del full_data
			mpi_barrier(MPI_COMM_WORLD)
			if myid == heavy_load_myid:
				log_main.add("Begin reading and preprocessing images on processor. Wait... ")
				ttt = time()
			#imgdata = EMData.read_images(stack, all_proj)			
			imgdata = [ None for im in range(len(all_proj))]
			for index_of_proj in range(len(all_proj)):
				#image = get_im(stack, all_proj[index_of_proj])
				if( current_window > 0): imgdata[index_of_proj] = fdecimate(window2d(get_im(stack, all_proj[index_of_proj]),current_window,current_window), nx, ny)
				else:                    imgdata[index_of_proj] = fdecimate(get_im(stack, all_proj[index_of_proj]), nx, ny)
				
				if (current_decimate> 0.0 and options.CTF):
					ctf = imgdata[index_of_proj].get_attr("ctf")
					ctf.apix = ctf.apix/current_decimate
					imgdata[index_of_proj].set_attr("ctf", ctf)
					
				if myid == heavy_load_myid and index_of_proj%100 == 0:
					log_main.add(" ...... %6.2f%% "%(index_of_proj/float(len(all_proj))*100.))
			mpi_barrier(MPI_COMM_WORLD)
			if myid == heavy_load_myid:
				log_main.add("All_proj preprocessing cost %7.2f m"%((time()-ttt)/60.))
				log_main.add("Wait untill reading on all CPUs done...")
			'''	
			imgdata2 = EMData.read_images(stack, range(img_begin, img_end))
			if options.fl > 0.0:
				for k in xrange(len(imgdata2)):
					imgdata2[k] = filt_tanl(imgdata2[k], options.fl, options.aa)
			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata2, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata2, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			if myid == main_node:
				vol.write_image("vol_ctf.hdf")
				print_msg("Writing to the disk volume reconstructed from averages as		:  %s\n"%("vol_ctf.hdf"))
			del vol, imgdata2
			mpi_barrier(MPI_COMM_WORLD)
			'''
			from applications import prepare_2d_forPCA
			from utilities    import model_blank
			from EMAN2        import Transform
			if not options.no_norm: 
				mask = model_circle(nx/2-2, nx, nx)
			if options.CTF: 
				from utilities import pad
				from filter import filt_ctf
			from filter import filt_tanl
			if myid == heavy_load_myid:
				log_main.add("Start computing 2D aveList and varList. Wait...")
				ttt = time()
			inner=nx//2-4
			outer=inner+2
			xform_proj_for_2D = [ None for i in range(len(proj_list))]
			for i in range(len(proj_list)):
				ki = proj_angles[proj_list[i][0]][3]
				if ki >= symbaselen:  continue
				mi = index[ki]
				dpar = Util.get_transform_params(imgdata[mi], "xform.projection", "spider")
				phiM, thetaM, psiM, s2xM, s2yM  = dpar["phi"],dpar["theta"],dpar["psi"],-dpar["tx"]*current_decimate,-dpar["ty"]*current_decimate
				grp_imgdata = []
				for j in range(img_per_grp):
					mj = index[proj_angles[proj_list[i][j]][3]]
					cpar = Util.get_transform_params(imgdata[mj], "xform.projection", "spider")
					alpha, sx, sy, mirror = params_3D_2D_NEW(cpar["phi"], cpar["theta"],cpar["psi"], -cpar["tx"]*current_decimate, -cpar["ty"]*current_decimate, mirror_list[i][j])
					if thetaM <= 90:
						if mirror == 0:  alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, phiM - cpar["phi"], 0.0, 0.0, 1.0)
						else:            alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, 180-(phiM - cpar["phi"]), 0.0, 0.0, 1.0)
					else:
						if mirror == 0:  alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, -(phiM- cpar["phi"]), 0.0, 0.0, 1.0)
						else:            alpha, sx, sy, scale = compose_transform2(alpha, sx, sy, 1.0, -(180-(phiM - cpar["phi"])), 0.0, 0.0, 1.0)
					imgdata[mj].set_attr("xform.align2d", Transform({"type":"2D","alpha":alpha,"tx":sx,"ty":sy,"mirror":mirror,"scale":1.0}))
					grp_imgdata.append(imgdata[mj])
				if not options.no_norm:
					for k in range(img_per_grp):
						ave, std, minn, maxx = Util.infomask(grp_imgdata[k], mask, False)
						grp_imgdata[k] -= ave
						grp_imgdata[k] /= std
				if options.fl > 0.0:
					for k in range(img_per_grp):
						grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl, options.aa)

				#  Because of background issues, only linear option works.
				if options.CTF:  ave, var = aves_wiener(grp_imgdata, SNR = 1.0e5, interpolation_method = "linear")
				else:  ave, var = ave_var(grp_imgdata)
				# Switch to std dev
				# threshold is not really needed,it is just in case due to numerical accuracy something turns out negative.
				var = square_root(threshold(var))

				set_params_proj(ave, [phiM, thetaM, 0.0, 0.0, 0.0])
				set_params_proj(var, [phiM, thetaM, 0.0, 0.0, 0.0])

				aveList.append(ave)
				varList.append(var)
				xform_proj_for_2D[i] = [phiM, thetaM, 0.0, 0.0, 0.0]

				'''
				if nvec > 0:
					eig = pca(input_stacks=grp_imgdata, subavg="", mask_radius=radiuspca, nvec=nvec, incore=True, shuffle=False, genbuf=True)
					for k in range(nvec):
						set_params_proj(eig[k], [phiM, thetaM, 0.0, 0.0, 0.0])
						eigList[k].append(eig[k])
					"""
					if myid == 0 and i == 0:
						for k in xrange(nvec):
							eig[k].write_image("eig.hdf", k)
					"""
				'''
				if (myid == heavy_load_myid) and (i%100 == 0):
					log_main.add(" ......%6.2f%%  "%(i/float(len(proj_list))*100.))		
			del imgdata, grp_imgdata, cpar, dpar, all_proj, proj_angles, index
			if not options.no_norm: del mask
			if myid == main_node: del tab
			#  At this point, all averages and variances are computed
			mpi_barrier(MPI_COMM_WORLD)
			
			if (myid == heavy_load_myid):
				log_main.add("Computing aveList and varList took %12.1f [m]"%((time()-ttt)/60.))
			
			xform_proj_for_2D = wrap_mpi_gatherv(xform_proj_for_2D, main_node, MPI_COMM_WORLD)
			if (myid == main_node):
				write_text_row(xform_proj_for_2D, os.path.join(current_output_dir, "params.txt"))
			del xform_proj_for_2D
			mpi_barrier(MPI_COMM_WORLD)
			if options.ave2D:
				from fundamentals import fpol
				from applications import header
				if myid == main_node:
					log_main.add("Compute ave2D ... ")
					km = 0
					for i in range(number_of_proc):
						if i == main_node :
							for im in range(len(aveList)):
								aveList[im].write_image(os.path.join(current_output_dir, options.ave2D), km)
								km += 1
						else:
							nl = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
							nl = int(nl[0])
							for im in range(nl):
								ave = recv_EMData(i, im+i+70000)
								"""
								nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								nm = int(nm[0])
								members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('members', map(int, members))
								members = mpi_recv(nm, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('pix_err', map(float, members))
								members = mpi_recv(3, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('refprojdir', map(float, members))
								"""
								tmpvol=fpol(ave, nx, nx,1)								
								tmpvol.write_image(os.path.join(current_output_dir, options.ave2D), km)
								km += 1
				else:
					mpi_send(len(aveList), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
					for im in range(len(aveList)):
						send_EMData(aveList[im], main_node,im+myid+70000)
						"""
						members = aveList[im].get_attr('members')
						mpi_send(len(members), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						mpi_send(members, len(members), MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						members = aveList[im].get_attr('pix_err')
						mpi_send(members, len(members), MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						try:
							members = aveList[im].get_attr('refprojdir')
							mpi_send(members, 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						except:
							mpi_send([-999.0,-999.0,-999.0], 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						"""
				if myid == main_node:
					header(os.path.join(current_output_dir, options.ave2D), params='xform.projection', fimport = os.path.join(current_output_dir, "params.txt"))
				mpi_barrier(MPI_COMM_WORLD)	
			if options.ave3D:
				from fundamentals import fpol
				t5 = time()
				if myid == main_node: log_main.add("Reconstruct ave3D ... ")
				ave3D = recons3d_4nn_MPI(myid, aveList, symmetry=options.sym, npad=options.npad)
				bcast_EMData_to_all(ave3D, myid)
				if myid == main_node:
					if current_decimate != 1.0: ave3D = resample(ave3D, 1./current_decimate)
					ave3D = fpol(ave3D, nnxo, nnxo, nnxo) # always to the orignal image size
					set_pixel_size(ave3D, 1.0)
					ave3D.write_image(os.path.join(current_output_dir, options.ave3D))
					log_main.add("Ave3D reconstruction took %12.1f [m]"%((time()-t5)/60.0))
					log_main.add("%-70s:  %s\n"%("The reconstructed ave3D is saved as ", options.ave3D))
					
			mpi_barrier(MPI_COMM_WORLD)		
			del ave, var, proj_list, stack, alpha, sx, sy, mirror, aveList
			'''
			if nvec > 0:
				for k in range(nvec):
					if myid == main_node:log_main.add("Reconstruction eigenvolumes", k)
					cont = True
					ITER = 0
					mask2d = model_circle(radiuspca, nx, nx)
					while cont:
						#print "On node %d, iteration %d"%(myid, ITER)
						eig3D = recons3d_4nn_MPI(myid, eigList[k], symmetry=options.sym, npad=options.npad)
						bcast_EMData_to_all(eig3D, myid, main_node)
						if options.fl > 0.0:
							eig3D = filt_tanl(eig3D, options.fl, options.aa)
						if myid == main_node:
							eig3D.write_image(os.path.join(options.outpout_dir, "eig3d_%03d.hdf"%(k, ITER)))
						Util.mul_img( eig3D, model_circle(radiuspca, nx, nx, nx) )
						eig3Df, kb = prep_vol(eig3D)
						del eig3D
						cont = False
						icont = 0
						for l in range(len(eigList[k])):
							phi, theta, psi, s2x, s2y = get_params_proj(eigList[k][l])
							proj = prgs(eig3Df, kb, [phi, theta, psi, s2x, s2y])
							cl = ccc(proj, eigList[k][l], mask2d)
							if cl < 0.0:
								icont += 1
								cont = True
								eigList[k][l] *= -1.0
						u = int(cont)
						u = mpi_reduce([u], 1, MPI_INT, MPI_MAX, main_node, MPI_COMM_WORLD)
						icont = mpi_reduce([icont], 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)

						if myid == main_node:
							u = int(u[0])
							log_main.add(" Eigenvector: ",k," number changed ",int(icont[0]))
						else: u = 0
						u = bcast_number_to_all(u, main_node)
						cont = bool(u)
						ITER += 1

					del eig3Df, kb
					mpi_barrier(MPI_COMM_WORLD)
				del eigList, mask2d
			'''
			if options.ave3D: del ave3D
			if options.var2D:
				from fundamentals import fpol 
				from applications import header
				if myid == main_node:
					log_main.add("Compute var2D...")
					km = 0
					for i in range(number_of_proc):
						if i == main_node :
							for im in range(len(varList)):
								tmpvol=fpol(varList[im], nx, nx,1)
								tmpvol.write_image(os.path.join(current_output_dir, options.var2D), km)
								km += 1
						else:
							nl = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
							nl = int(nl[0])
							for im in range(nl):
								ave = recv_EMData(i, im+i+70000)
								tmpvol=fpol(ave, nx, nx,1)
								tmpvol.write_image(os.path.join(current_output_dir, options.var2D), km)
								km += 1
				else:
					mpi_send(len(varList), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
					for im in range(len(varList)):
						send_EMData(varList[im], main_node, im+myid+70000)#  What with the attributes??
				mpi_barrier(MPI_COMM_WORLD)
				if myid == main_node:
					from applications import header
					header(os.path.join(current_output_dir, options.var2D), params = 'xform.projection',fimport = os.path.join(current_output_dir, "params.txt"))
				mpi_barrier(MPI_COMM_WORLD)
		if options.var3D:
			if myid == main_node: log_main.add("Reconstruct var3D ...")
			t6 = time()
			# radiusvar = options.radius
			# if( radiusvar < 0 ):  radiusvar = nx//2 -3
			res = recons3d_4nn_MPI(myid, varList, symmetry = options.sym, npad=options.npad)
			#res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
			if myid == main_node:
				from fundamentals import fpol
				if current_decimate != 1.0: res	= resample(res, 1./current_decimate)
				res = fpol(res, nnxo, nnxo, nnxo)
				set_pixel_size(res, 1.0)
				res.write_image(os.path.join(current_output_dir, options.var3D))
				log_main.add("%-70s:  %s\n"%("The reconstructed var3D is saved as ", options.var3D))
				log_main.add("Var3D reconstruction took %f12.1 [m]"%((time()-t6)/60.0))
				log_main.add("Total computation time %f12.1 [m]"%((time()-t0)/60.0))
				log_main.add("sx3dvariability finishes")
		from mpi import mpi_finalize
		mpi_finalize()
		
		if RUNNING_UNDER_MPI: global_def.MPI = False

		global_def.BATCH = False
Beispiel #47
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """Refine per-segment x-shifts of helical filament segments (MPI version).

    Segments are grouped into filaments (header attributes "filament" and
    "ptcl_source_coord", ordered by ``ordersegments``), filaments are
    load-balanced over the MPI ranks, and for ``maxit`` iterations every
    segment is cross-correlated with the global average.  For each filament
    ``Util.helixshiftali`` returns a shift ``sib`` of the central segment and
    an incline ``bang``; each segment then receives the shift
    ``sib +/- dst * bang``, where ``dst`` is its distance (in
    "ptcl_source_coord" units) from the central segment.  The shifts are
    composed with the initial "xform.align2d" of each image and written back
    to the headers of ``stack``.

    Parameters
    ----------
    stack : str
        Input image stack ("bdb" stacks are written back via
        ``recv_attr_dict_bdb``, all others via ``recv_attr_dict``).
    maskfile : str or None
        Mask image.  When None, a band of ones of width
        ``2*(min(nx, ny)//2 - 2) + 1`` spanning the full image height,
        zero-padded to nx, is used.
    maxit : int
        Number of iterations.
    CTF : bool
        If True, images with "ctf_applied" == 0 are filtered with
        ``filt_ctf``.  The average is then divided (``Util.divn_filter``) by
        the accumulated CTF sum (``Util.add_img2``) plus ``1/snr``.
    snr : float
        Signal-to-noise ratio; only used when ``CTF`` is True.
    Fourvar : bool
        If True, the average is divided by the Fourier variance from
        ``varf2d_MPI``.
    search_rng : int
        Shift search range, passed to ``Util.helixshiftali`` and used to
        bound the maximum incline.

    Side effects: on the main node writes "tavg.hdf" (one image per
    iteration).  Updates "xform.align2d" and "ID" in the headers of
    ``stack``.
    """
    # Imported here, before its first use: an import anywhere in a function
    # makes the name local to the *whole* function, so importing it only at
    # the end would make the call below fail with UnboundLocalError.
    from sp_utilities import file_type

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)

    # --- main node groups segments into filaments, then broadcasts ---------
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [len(fil) for fil in filaments]  # segments per filament
        linidl = sum(inidl)
        nima = linidl  # total number of images; only used on the main node
        tfilaments = []  # flattened list of image indices
        for fil in filaments:
            tfilaments += fil
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)

    # Rebuild the per-filament index lists from the flattened broadcast.
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    # --- balanced load: keep only the filaments assigned to this rank -------
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # indcs[f] = [first, last+1) positions of filament f inside local `data`.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # Subtract the masked mean and move every image to Fourier space
    # (CTF-filtered first when requested), accumulating CTF sums.
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf != 1:
                ERROR('Incorrectly set qctf flag', myid=myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to every complex coefficient; values are presumably
            # stored as interleaved (re, im) pairs along x, hence the step 2.
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1

        # --- global average of all currently shifted segments ---------------
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF: tavg = Util.divn_filter(avg, ctf_2_sum)
            else: tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, _ = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            i0, i1 = indcs[ifil]
            nsegms = i1 - i0
            # 1D ccf (first row, window nx x 1) of each segment vs. average,
            # plus the segment's source coordinates.
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(i0, i1):
                ctx[im - i0] = Util.window(ccf(tavg, data[im]), nx, 1)
                pcoords[im - i0] = data[im].get_attr('ptcl_source_coord')

            # search for best x-shift around the central segment
            cents = nsegms // 2

            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # C implementation of the search.  Presumably (per an older pure
            # Python version) it scans shifts six in [-search_rng, search_rng]
            # of the central segment and inclines tang = tan(maxincline/kang*k),
            # k = 0..kang, summing linearly interpolated ccf values at
            # x = six + nxc +/- dst*tang for the other segments, and returns
            # [best six, best signed tang, best score].  TODO confirm.
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]

            # Linear shift profile along the filament: segments before the
            # centre get -dst*bang, segments from the centre on get +dst*bang.
            for im in range(i0, i1):
                kim = im - i0
                d = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                         (pcoords[cents][1] - pcoords[kim][1])**2)
                if kim < cents: xl = -d * bang + sib
                else: xl = d * bang + sib
                shift_x[im] = xl

            # Accumulate the central-segment shift for the average shift
            sx_sum += shift_x[i0 + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM,
                                main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        # NOTE(review): the average shift is only reported; it is reset to
        # 0.0 on every rank before the broadcast, so the centering correction
        # below is currently a no-op.  Kept unchanged to preserve results --
        # confirm whether this is intended.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        data[im].set_attr("xform.align2d", t1 * init_params[im])

    # write out headers; under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node: print_end_msg("helical-shiftali_MPI")
Beispiel #48
0
def prepare_refrings(
    volft,
    kb,
    nz=-1,
    delta=2.0,
    ref_a="P",
    sym="c1",
    numr=None,
    MPI=False,
    phiEqpsi="Zero",
    kbx=None,
    kby=None,
    initial_theta=None,
    delta_theta=None,
    initial_phi=None,
):
    """
    Project a volume onto a set of reference directions and convert each
    projection to Fourier-transformed, weighted polar rings.

    Reference directions:
      * ``ref_a`` given as a list -> used verbatim as the direction list;
      * otherwise ``sp_utilities.even_angles`` (or, for symmetries not
        starting with "c"/"d" and no ``initial_theta``,
        ``sp_fundamentals.symclass(sym).even_angles``) generates them from
        ``delta``, ``sym``, ``ref_a`` (method), ``phiEqpsi`` and the optional
        ``initial_theta`` / ``initial_phi`` / ``delta_theta`` (default 1.0,
        passed as ``theta2``).

    Parameters
    ----------
    volft, kb : volume and kernel passed straight to ``sp_projection.prgs``.
    nz : int
        Image size; rings are centred at ``nz // 2 + 1``.  Values < 1 are
        reported through ``sp_global_def.ERROR``.
    numr : ring description consumed by ``ringwe`` / ``Util.Polar2Dm``;
        its last two entries fix the size of the placeholder images.
    MPI : bool or communicator
        False -> serial; True -> MPI_COMM_WORLD; anything else is used as
        the communicator.  Under MPI each rank projects its own slice of the
        directions and the results are exchanged with
        ``bcast_compacted_EMData_all_to_all``.
    kbx, kby : optional extra kernels, forwarded to ``prgs`` only when
        ``kbx`` is not None.

    Returns
    -------
    list of EMData, one per direction, each carrying the attributes
    "phi", "theta", "psi" and the ``getfvec`` components "n1", "n2", "n3".
    """
    # The MPI argument doubles as a communicator: any non-bool value is
    # taken to be the communicator itself and switches MPI mode on.
    if not isinstance(MPI, bool):
        mpi_comm = MPI
        MPI = True
    elif MPI:
        mpi_comm = mpi.MPI_COMM_WORLD

    mode = "F"

    if type(ref_a) is list:
        # caller supplied the projection directions explicitly
        ref_angles = ref_a
    elif initial_theta and initial_phi:
        ref_angles = sp_utilities.even_angles(
            delta,
            theta1=initial_theta,
            phi1=initial_phi,
            symmetry=sym,
            method=ref_a,
            phiEqpsi=phiEqpsi,
        )
    elif initial_theta is not None:
        ref_angles = sp_utilities.even_angles(
            delta,
            theta1=initial_theta,
            theta2=1.0 if delta_theta is None else delta_theta,
            symmetry=sym,
            method=ref_a,
            phiEqpsi=phiEqpsi,
        )
    elif sym[:1] in ("c", "d"):
        ref_angles = sp_utilities.even_angles(
            delta, symmetry=sym, method=ref_a, phiEqpsi=phiEqpsi
        )
    else:
        ref_angles = sp_fundamentals.symclass(sym).even_angles(delta)

    wr_four = ringwe(numr, mode)
    cnx = cny = old_div(nz, 2) + 1
    num_ref = len(ref_angles)

    if MPI:
        myid = mpi.mpi_comm_rank(mpi_comm)
        ncpu = mpi.mpi_comm_size(mpi_comm)
    else:
        myid, ncpu = 0, 1

    if nz < 1:
        sp_global_def.ERROR(
            "Data size has to be given (nz)", "prepare_refrings", 1, myid
        )

    ref_start, ref_end = sp_applications.MPI_start_end(num_ref, ncpu, myid)

    # One placeholder per direction so every rank owns a full-length list
    # that the MPI exchange below can fill in.
    nring = len(numr)
    sizex = numr[nring - 2] + numr[nring - 1] - 1
    refrings = []
    for _ in range(num_ref):
        placeholder = EMAN2_cppwrap.EMData()
        placeholder.set_size(sizex, 1, 1)
        refrings.append(placeholder)

    # kbx/kby are appended to the prgs call only when provided.
    extra_kernels = () if kbx is None else (kbx, kby)
    for idx in range(ref_start, ref_end):
        direction = ref_angles[idx]
        projection = sp_projection.prgs(
            volft,
            kb,
            [direction[0], direction[1], direction[2], 0.0, 0.0],
            *extra_kernels
        )
        rings = EMAN2_cppwrap.Util.Polar2Dm(
            projection, cnx, cny, numr, mode
        )  # currently set to quadratic....
        EMAN2_cppwrap.Util.Normalize_ring(rings, numr, 0)
        EMAN2_cppwrap.Util.Frngs(rings, numr)
        EMAN2_cppwrap.Util.Applyws(rings, numr, wr_four)
        refrings[idx] = rings

    if MPI:
        sp_utilities.bcast_compacted_EMData_all_to_all(refrings, myid, comm=mpi_comm)

    for idx, direction in enumerate(ref_angles):
        phi, theta, psi = direction[0], direction[1], direction[2]
        n1, n2, n3 = sp_utilities.getfvec(phi, theta)
        refrings[idx].set_attr_dict(
            {
                "phi": phi,
                "theta": theta,
                "psi": psi,
                "n1": n1,
                "n2": n2,
                "n3": n3,
            }
        )

    return refrings