Ejemplo n.º 1
0
def wrap_mpi_split(mpi_comm, number_of_subcomm):
    """Split ``mpi_comm`` into ``number_of_subcomm`` sub-communicators.

    Ranks are dealt out round-robin: a process with rank ``r`` in
    ``mpi_comm`` gets color ``r % number_of_subcomm`` and ordering key
    ``r // number_of_subcomm``, and both are handed to ``mpi_comm_split``.

    Args:
        mpi_comm: communicator to split.
        number_of_subcomm: number of sub-communicators to create; must not
            exceed the size of ``mpi_comm``.

    Returns:
        An ``mpi_env_type`` instance with these attributes filled in:
        ``main_comm``, ``main_rank``, ``subcomm_id``, ``sub_rank``,
        ``sub_comm``, ``subcomms_count`` and ``subcomms_roots``.

    Raises:
        RuntimeError: if ``number_of_subcomm`` is larger than the size of
            ``mpi_comm``.
    """
    from mpi import mpi_comm_rank, mpi_comm_size, mpi_comm_split
    from air import mpi_env_type

    main_size = mpi_comm_size(mpi_comm)
    if number_of_subcomm > main_size:
        raise RuntimeError("number_of_subcomm > main_size")

    me = mpi_env_type()
    me.main_comm = mpi_comm
    me.main_rank = mpi_comm_rank(mpi_comm)
    me.subcomm_id = me.main_rank % number_of_subcomm
    # Floor division keeps the key an integer on Python 3 as well; plain "/"
    # would produce a float there.
    me.sub_rank = me.main_rank // number_of_subcomm
    me.sub_comm = mpi_comm_split(mpi_comm, me.subcomm_id, me.sub_rank)
    me.subcomms_count = number_of_subcomm
    # Materialise as a list so the attribute has the same type on Python 2
    # and Python 3.
    # NOTE(review): presumably these are the main-communicator ranks of the
    # first member of each sub-communicator -- confirm against callers.
    me.subcomms_roots = list(range(number_of_subcomm))

    return me
Ejemplo n.º 2
0
#
import numpy
from numpy import *
import mpi
import sys
import math

# print("before", len(sys.argv), sys.argv)
# Initialise MPI; the value returned by mpi_init replaces sys.argv.
sys.argv = mpi.mpi_init(len(sys.argv), sys.argv)
# print("after ", len(sys.argv), sys.argv)
myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
numnodes = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
# Each print below passes one pre-formatted string, so the line parses and
# emits the same text under both Python 2 and Python 3.
print("hello from  %s  of  %s" % (myid, numnodes))

# Split the world communicator by rank parity: color 0 for even ranks,
# color 1 for odd ranks; the world rank is passed as the ordering key.
color = myid % 2
new_comm = mpi.mpi_comm_split(mpi.MPI_COMM_WORLD, color, myid)
new_id = mpi.mpi_comm_rank(new_comm)
new_nodes = mpi.mpi_comm_size(new_comm)

# Rank 0 of each new communicator broadcasts its color to the other members
# of that communicator; everyone else starts from the placeholder -1.
zero_one = -1
if new_id == 0:
    zero_one = color

zero_one = mpi.mpi_bcast(zero_one, 1, mpi.MPI_INT, 0, new_comm)
if zero_one == 0:
    print("%s  part of even processor communicator  %s" % (myid, new_id))

if zero_one == 1:
    print("%s  part of odd processor communicator  %s" % (myid, new_id))

print("old_id= %s new_id= %s" % (myid, new_id))
Ejemplo n.º 3
0
def calculate_volumes_after_rotation_and_save_them(ali3d_options, rviper_iter, masterdir, bdb_stack_location, mpi_rank, mpi_size,
												   no_of_viper_runs_analyzed_together, no_of_viper_runs_analyzed_together_from_user_options, mpi_comm = -1):
	"""Reconstruct one volume per retained VIPER run, then average them on rank 0.

	The run indices are read from
	"list_of_viper_runs_included_in_outlier_elimination.json" in the main
	output directory of iteration ``rviper_iter``.  For every index a volume
	is computed with ``do_volume`` and written by rank 0 as
	"rotated_volume.hdf" in that run's directory.  Rank 0 then repeatedly
	aligns the volumes to their average with ``ali_vol`` until the statistic
	taken from ``Util.infomask`` stops increasing, and writes
	"average_volume.hdf" and "variance_volume.hdf".

	The function also handles the case in which there are more processes
	than images: the communicator is split and only the first ``n_projs``
	ranks take part in the reconstruction.

	If every "rotated_volume.hdf" already exists the function returns
	immediately.  If the JSON list is empty it prints an error, calls
	``mpi_finalize()`` and exits the interpreter.

	Args:
		ali3d_options: options object forwarded to ``do_volume``.
		rviper_iter: iteration number; selects the main output directory and,
			as ``rviper_iter - 1``, the suffix of the input stack name.
		masterdir: root directory under which the main output directory lives.
		bdb_stack_location: stack name prefix; "_%03d" % (rviper_iter - 1) is
			appended before it is passed to ``getindexdata``.
		mpi_rank: rank of this process (presumably within ``mpi_comm`` --
			supplied by the caller, not queried here).
		mpi_size: number of processes (presumably the size of ``mpi_comm``).
		no_of_viper_runs_analyzed_together: not used by the current code.
		no_of_viper_runs_analyzed_together_from_user_options: not used by the
			current code.
		mpi_comm: communicator to use; -1 selects ``MPI_COMM_WORLD``.

	Returns:
		None.
	"""
	import json

	if mpi_comm == -1:
		mpi_comm = MPI_COMM_WORLD

	# some arguments are for debugging purposes

	mainoutputdir = masterdir + DIR_DELIM + NAME_OF_MAIN_DIR + ("%03d" + DIR_DELIM) %(rviper_iter)

	# The context manager closes the file even when json.load raises.
	with open(mainoutputdir + "list_of_viper_runs_included_in_outlier_elimination.json", 'r') as f:
		list_of_independent_viper_run_indices_used_for_outlier_elimination = json.load(f)

	if len(list_of_independent_viper_run_indices_used_for_outlier_elimination)==0:
		print("Error: len(list_of_independent_viper_run_indices_used_for_outlier_elimination)==0")
		mpi_finalize()
		sys.exit()

	# If this data analysis step was already performed in the past (every
	# rotated volume is present on disk) then return.
	# For future changes make sure that the file checked is the last one to be processed !!!
	if all(os.path.exists(mainoutputdir + DIR_DELIM + NAME_OF_RUN_DIR + "%03d"%(check_run) + DIR_DELIM + "rotated_volume.hdf")
			for check_run in list_of_independent_viper_run_indices_used_for_outlier_elimination):
		return

	# One parameter file per retained run, in the same order as the run list.
	partstack = [mainoutputdir + NAME_OF_RUN_DIR + "%03d"%(i1) + DIR_DELIM + "rotated_reduced_params.txt"
				 for i1 in list_of_independent_viper_run_indices_used_for_outlier_elimination]
	partids_file_name = mainoutputdir + "this_iteration_index_keep_images.txt"

	# A real list (not a lazy map object) so that len() also works on Python 3.
	lpartids = [int(partid) for partid in read_text_file(partids_file_name)]
	n_projs = len(lpartids)

	def _reconstruct_runs(nproc, comm, log_suffix):
		# Reconstruct and save the volume of every retained run using the
		# given process count and communicator; log_suffix is appended to the
		# progress message printed by rank 0.
		for idx, i in enumerate(list_of_independent_viper_run_indices_used_for_outlier_elimination):
			projdata = getindexdata(bdb_stack_location + "_%03d"%(rviper_iter - 1), partids_file_name, partstack[idx], mpi_rank, nproc)
			vol = do_volume(projdata, ali3d_options, 0, mpi_comm = comm)
			del projdata
			if mpi_rank == 0:
				vol.write_image(mainoutputdir + DIR_DELIM + NAME_OF_RUN_DIR + "%03d"%(i) + DIR_DELIM + "rotated_volume.hdf")
				line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " => "
				print(line + "Generated rec_ref_volume_run #%01d"%i + log_suffix)
			del vol

	if mpi_size > n_projs:
		# More processes than images: ranks below n_projs get color 0 and do
		# the reconstruction, the remaining ranks get color 1 and only wait
		# at the barrier.
		is_idle = int(not(mpi_rank < n_projs))
		mpi_subcomm = mpi_comm_split(mpi_comm, is_idle, mpi_rank - is_idle*n_projs)
		mpi_subsize = mpi_comm_size(mpi_subcomm)
		if mpi_rank < n_projs:
			_reconstruct_runs(mpi_subsize, mpi_subcomm, " \n")

		mpi_barrier(mpi_comm)
	else:
		_reconstruct_runs(mpi_size, mpi_comm, "")

	if mpi_rank == 0:
		# Align all rotated volumes, calculate their average and save as an overall result
		from utilities import get_params3D, set_params3D, get_im, model_circle
		# NOTE(review): on Python 3 the name "statistics" is also a stdlib
		# module -- confirm that the project module is the one imported here.
		from statistics import ave_var
		from applications import ali_vol

		vls = []
		for i in list_of_independent_viper_run_indices_used_for_outlier_elimination:
			loaded = get_im(mainoutputdir + DIR_DELIM + NAME_OF_RUN_DIR + "%03d"%(i) + DIR_DELIM + "rotated_volume.hdf")
			set_params3D(loaded,[0.,0.,0.,0.,0.,0.,0,1.0])
			vls.append(loaded)
		asa,sas = ave_var(vls)

		# do the alignment
		nx = asa.get_xsize()
		# Floor division keeps the value the same on Python 2 and Python 3.
		radius = nx//2 - .5
		st = Util.infomask(asa*asa, model_circle(radius,nx,nx,nx), True)
		goal = st[0]
		going = True
		# Keep re-aligning to the current average while st[0] keeps growing.
		while going:
			set_params3D(asa,[0.,0.,0.,0.,0.,0.,0,1.0])
			for single_vol in vls:
				o = ali_vol(single_vol,asa,7.0,5.,radius)  # range of angles and shifts, maybe should be adjusted
				p = get_params3D(o)
				del o
				set_params3D(single_vol,p)
			asa,sas = ave_var(vls)
			st = Util.infomask(asa*asa, model_circle(radius,nx,nx,nx), True)
			if st[0] > goal:
				goal = st[0]
			else:
				going = False
		# over and out
		asa.write_image(mainoutputdir + DIR_DELIM + "average_volume.hdf")
		sas.write_image(mainoutputdir + DIR_DELIM + "variance_volume.hdf")
	return
Ejemplo n.º 4
0
def main():
    def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
        """Convert 3D projection parameters into 2D alignment parameters.

        Composes the shift (s2x, s2y) with an in-plane rotation of
        ``540.0 - psi`` when ``mirror`` is truthy, or ``360.0 - psi``
        otherwise, via ``sp_utilities.compose_transform2``.  ``phi`` and
        ``theta`` are accepted but not used.

        Returns:
            Tuple ``(alpha, sx, sy, m)`` where ``m`` is 1 for the mirrored
            case and 0 otherwise.
        """
        # the final ali2d parameters already combine shifts operation first and rotation operation second for parameters converted from 3D
        if mirror:
            mirror_flag, rotation = 1, 540.0 - psi
        else:
            mirror_flag, rotation = 0, 360.0 - psi
        composed = sp_utilities.compose_transform2(
            0, s2x, s2y, 1.0, rotation, 0, 0, 1.0)
        # The fourth composed value (scale) is unpacked but not returned.
        alpha, sx, sy, _scale = composed
        return alpha, sx, sy, mirror_flag

    progname = optparse.os.path.basename(sys.argv[0])
    usage = (
        progname +
        " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=  --aa=   --sym=symmetry --CTF"
    )
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option("--output_dir",
                      type="string",
                      default="./",
                      help="Output directory")
    parser.add_option(
        "--ave2D",
        type="string",
        default=False,
        help="Write to the disk a stack of 2D averages",
    )
    parser.add_option(
        "--var2D",
        type="string",
        default=False,
        help="Write to the disk a stack of 2D variances",
    )
    parser.add_option(
        "--ave3D",
        type="string",
        default=False,
        help="Write to the disk reconstructed 3D average",
    )
    parser.add_option(
        "--var3D",
        type="string",
        default=False,
        help="Compute 3D variability (time consuming!)",
    )
    parser.add_option(
        "--img_per_grp",
        type="int",
        default=100,
        help="Number of neighbouring projections.(Default is 100)",
    )
    parser.add_option(
        "--no_norm",
        action="store_true",
        default=False,
        help="Do not use normalization.(Default is to apply normalization)",
    )
    # parser.add_option("--radius", 	    type="int"         ,	default=-1   ,				help="radius for 3D variability" )
    parser.add_option(
        "--npad",
        type="int",
        default=2,
        help=
        "Number of time to pad the original images.(Default is 2 times padding)",
    )
    parser.add_option("--sym",
                      type="string",
                      default="c1",
                      help="Symmetry. (Default is no symmetry)")
    parser.add_option(
        "--fl",
        type="float",
        default=0.0,
        help=
        "Low pass filter cutoff in absolute frequency (0.0 - 0.5) and is applied to decimated images. (Default - no filtration)",
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.02,
        help=
        "Fall off of the filter. Use default value if user has no clue about falloff (Default value is 0.02)",
    )
    parser.add_option(
        "--CTF",
        action="store_true",
        default=False,
        help="Use CFT correction.(Default is no CTF correction)",
    )
    # parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
    # parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
    # parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
    # parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
    # parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
    parser.add_option(
        "--VAR",
        action="store_true",
        default=False,
        help="Stack of input consists of 2D variances (Default False)",
    )
    parser.add_option(
        "--decimate",
        type="float",
        default=0.25,
        help="Image decimate rate, a number less than 1. (Default is 0.25)",
    )
    parser.add_option(
        "--window",
        type="int",
        default=0,
        help=
        "Target image size relative to original image size. (Default value is zero.)",
    )
    # parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
    # parser.add_option("--nvec",			type="int"         ,	default=0    ,				help="Number of eigenvectors, (Default = 0 meaning no PCA calculated)")
    parser.add_option(
        "--symmetrize",
        action="store_true",
        default=False,
        help="Prepare input stack for handling symmetry (Default False)",
    )
    parser.add_option("--overhead",
                      type="float",
                      default=0.5,
                      help="python overhead per CPU.")

    (options, args) = parser.parse_args()
    #####
    # from mpi import *

    #  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

    # Set up global variables related to bdb cache
    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    # Set up global variables related to ERROR function
    sp_global_def.BATCH = True

    # detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in optparse.os.environ
    if RUNNING_UNDER_MPI:
        sp_global_def.MPI = True
    if options.output_dir == "./":
        current_output_dir = optparse.os.path.abspath(options.output_dir)
    else:
        current_output_dir = options.output_dir
    if options.symmetrize:

        if mpi.mpi_comm_size(mpi.MPI_COMM_WORLD) > 1:
            sp_global_def.ERROR(
                "Cannot use more than one CPU for symmetry preparation")

        if not optparse.os.path.exists(current_output_dir):
            optparse.os.makedirs(current_output_dir)
            sp_global_def.write_command(current_output_dir)

        if optparse.os.path.exists(
                optparse.os.path.join(current_output_dir, "log.txt")):
            optparse.os.remove(
                optparse.os.path.join(current_output_dir, "log.txt"))
        log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
        log_main.prefix = optparse.os.path.join(current_output_dir, "./")

        instack = args[0]
        sym = options.sym.lower()
        if sym == "c1":
            sp_global_def.ERROR(
                "There is no need to symmetrize stack for C1 symmetry")

        line = ""
        for a in sys.argv:
            line += " " + a
        log_main.add(line)

        if instack[:4] != "bdb:":
            # if output_dir =="./": stack = "bdb:data"
            stack = "bdb:" + current_output_dir + "/data"
            sp_utilities.delete_bdb(stack)
            junk = sp_utilities.cmdexecute("sp_cpy.py  " + instack + "  " +
                                           stack)
        else:
            stack = instack

        qt = EMAN2_cppwrap.EMUtil.get_all_attributes(stack, "xform.projection")

        na = len(qt)
        ts = sp_utilities.get_symt(sym)
        ks = len(ts)
        angsa = [None] * na

        for k in range(ks):
            # Qfile = "Q%1d"%k
            # if options.output_dir!="./": Qfile = os.path.join(options.output_dir,"Q%1d"%k)
            Qfile = optparse.os.path.join(current_output_dir, "Q%1d" % k)
            # delete_bdb("bdb:Q%1d"%k)
            sp_utilities.delete_bdb("bdb:" + Qfile)
            # junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            junk = sp_utilities.cmdexecute("e2bdb.py  " + stack +
                                           "  --makevstack=bdb:" + Qfile)
            # DB = db_open_dict("bdb:Q%1d"%k)
            DB = EMAN2db.db_open_dict("bdb:" + Qfile)
            for i in range(na):
                ut = qt[i] * ts[k]
                DB.set_attr(i, "xform.projection", ut)
                # bt = ut.get_params("spider")
                # angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
            # write_text_row(angsa, 'ptsma%1d.txt'%k)
            # junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            # junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
            DB.close()
        # if options.output_dir =="./": delete_bdb("bdb:sdata")
        sp_utilities.delete_bdb("bdb:" + current_output_dir + "/" + "sdata")
        # junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
        sdata = "bdb:" + current_output_dir + "/" + "sdata"
        sp_global_def.sxprint(sdata)
        junk = sp_utilities.cmdexecute("e2bdb.py   " + current_output_dir +
                                       "  --makevstack=" + sdata + " --filt=Q")
        # junk = cmdexecute("ls  EMAN2DB/sdata*")
        # a = get_im("bdb:sdata")
        a = sp_utilities.get_im(sdata)
        a.set_attr("variabilitysymmetry", sym)
        # a.write_image("bdb:sdata")
        a.write_image(sdata)

    else:

        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        main_node = 0
        shared_comm = mpi.mpi_comm_split_type(mpi.MPI_COMM_WORLD,
                                              mpi.MPI_COMM_TYPE_SHARED, 0,
                                              mpi.MPI_INFO_NULL)
        myid_on_node = mpi.mpi_comm_rank(shared_comm)
        no_of_processes_per_group = mpi.mpi_comm_size(shared_comm)
        masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
            mpi.MPI_COMM_WORLD, main_node == myid_on_node, myid_on_node)
        color, no_of_groups, balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            main_node,
            mpi.MPI_COMM_WORLD,
            myid,
            shared_comm,
            myid_on_node,
            masters_from_groups_vs_everything_else_comm,
        )
        overhead_loading = options.overhead * number_of_proc
        # memory_per_node  = options.memory_per_node
        # if memory_per_node == -1.: memory_per_node = 2.*no_of_processes_per_group
        keepgoing = 1

        current_window = options.window
        current_decimate = options.decimate

        if len(args) == 1:
            stack = args[0]
        else:
            sp_global_def.sxprint("Usage: " + usage)
            sp_global_def.sxprint("Please run '" + progname +
                                  " -h' for detailed options")
            sp_global_def.ERROR(
                "Invalid number of parameters used. Please see usage information above."
            )
            return

        t0 = time.time()
        # obsolete flags
        options.MPI = True
        # options.nvec = 0
        options.radiuspca = -1
        options.iter = 40
        options.abs = 0.0
        options.squ = 0.0

        if options.fl > 0.0 and options.aa == 0.0:
            sp_global_def.ERROR(
                "Fall off has to be given for the low-pass filter", myid=myid)

        # if options.VAR and options.SND:
        # 	ERROR( "Only one of var and SND can be set!",myid=myid )

        if options.VAR and (options.ave2D or options.ave3D or options.var2D):
            sp_global_def.ERROR(
                "When VAR is set, the program cannot output ave2D, ave3D or var2D",
                myid=myid,
            )

        # if options.SND and (options.ave2D or options.ave3D):
        # 	ERROR( "When SND is set, the program cannot output ave2D or ave3D", myid=myid )

        # if options.nvec > 0 :
        # 	ERROR( "PCA option not implemented", myid=myid )

        # if options.nvec > 0 and options.ave3D == None:
        # 	ERROR( "When doing PCA analysis, one must set ave3D", myid=myid )

        if current_decimate > 1.0 or current_decimate < 0.0:
            sp_global_def.ERROR(
                "Decimate rate should be a value between 0.0 and 1.0",
                myid=myid)

        if current_window < 0.0:
            sp_global_def.ERROR(
                "Target window size should be always larger than zero",
                myid=myid)

        if myid == main_node:
            img = sp_utilities.get_image(stack, 0)
            nx = img.get_xsize()
            ny = img.get_ysize()
            if min(nx, ny) < current_window:
                keepgoing = 0
        keepgoing = sp_utilities.bcast_number_to_all(keepgoing, main_node,
                                                     mpi.MPI_COMM_WORLD)
        if keepgoing == 0:
            sp_global_def.ERROR(
                "The target window size cannot be larger than the size of decimated image",
                myid=myid,
            )

        options.sym = options.sym.lower()
        # if global_def.CACHE_DISABLE:
        # 	from utilities import disable_bdb_cache
        # 	disable_bdb_cache()
        # global_def.BATCH = True

        if myid == main_node:
            if not optparse.os.path.exists(current_output_dir):
                optparse.os.makedirs(
                    current_output_dir
                )  # Never delete output_dir in the program!

        img_per_grp = options.img_per_grp
        # nvec        = options.nvec
        radiuspca = options.radiuspca
        # if os.path.exists(os.path.join(options.output_dir, "log.txt")): os.remove(os.path.join(options.output_dir, "log.txt"))
        log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
        log_main.prefix = optparse.os.path.join(current_output_dir, "./")

        if myid == main_node:
            line = ""
            for a in sys.argv:
                line += " " + a
            log_main.add(line)
            log_main.add("-------->>>Settings given by all options<<<-------")
            log_main.add("Symmetry             : %s" % options.sym)
            log_main.add("Input stack          : %s" % stack)
            log_main.add("Output_dir           : %s" % current_output_dir)

            if options.ave3D:
                log_main.add("Ave3d                : %s" % options.ave3D)
            if options.var3D:
                log_main.add("Var3d                : %s" % options.var3D)
            if options.ave2D:
                log_main.add("Ave2D                : %s" % options.ave2D)
            if options.var2D:
                log_main.add("Var2D                : %s" % options.var2D)
            if options.VAR:
                log_main.add("VAR                  : True")
            else:
                log_main.add("VAR                  : False")
            if options.CTF:
                log_main.add("CTF correction       : True  ")
            else:
                log_main.add("CTF correction       : False ")

            log_main.add("Image per group      : %5d" % options.img_per_grp)
            log_main.add("Image decimate rate  : %4.3f" % current_decimate)
            log_main.add("Low pass filter      : %4.3f" % options.fl)
            current_fl = options.fl
            if current_fl == 0.0:
                current_fl = 0.5
            log_main.add(
                "Current low pass filter is equivalent to cutoff frequency %4.3f for original image size"
                % round((current_fl * current_decimate), 3))
            log_main.add("Window size          : %5d " % current_window)
            log_main.add("sx3dvariability begins")

        symbaselen = 0
        if myid == main_node:
            nima = EMAN2_cppwrap.EMUtil.get_image_count(stack)
            img = sp_utilities.get_image(stack)
            nx = img.get_xsize()
            ny = img.get_ysize()
            nnxo = nx
            nnyo = ny
            if options.sym != "c1":
                imgdata = sp_utilities.get_im(stack)
                try:
                    i = imgdata.get_attr("variabilitysymmetry").lower()
                    if i != options.sym:
                        sp_global_def.ERROR(
                            "The symmetry provided does not agree with the symmetry of the input stack",
                            myid=myid,
                        )
                except:
                    sp_global_def.ERROR(
                        "Input stack is not prepared for symmetry, please follow instructions",
                        myid=myid,
                    )
                i = len(sp_utilities.get_symt(options.sym))
                if (old_div(nima, i)) * i != nima:
                    sp_global_def.ERROR(
                        "The length of the input stack is incorrect for symmetry processing",
                        myid=myid,
                    )
                symbaselen = old_div(nima, i)
            else:
                symbaselen = nima
        else:
            nima = 0
            nx = 0
            ny = 0
            nnxo = 0
            nnyo = 0
        nima = sp_utilities.bcast_number_to_all(nima)
        nx = sp_utilities.bcast_number_to_all(nx)
        ny = sp_utilities.bcast_number_to_all(ny)
        nnxo = sp_utilities.bcast_number_to_all(nnxo)
        nnyo = sp_utilities.bcast_number_to_all(nnyo)
        if current_window > max(nx, ny):
            sp_global_def.ERROR(
                "Window size is larger than the original image size")

        if current_decimate == 1.0:
            if current_window != 0:
                nx = current_window
                ny = current_window
        else:
            if current_window == 0:
                nx = int(nx * current_decimate + 0.5)
                ny = int(ny * current_decimate + 0.5)
            else:
                nx = int(current_window * current_decimate + 0.5)
                ny = nx
        symbaselen = sp_utilities.bcast_number_to_all(symbaselen)

        # check FFT prime number
        is_fft_friendly = nx == sp_fundamentals.smallprime(nx)

        if not is_fft_friendly:
            if myid == main_node:
                log_main.add(
                    "The target image size is not a product of small prime numbers"
                )
                log_main.add("Program adjusts the input settings!")
            ### two cases
            if current_decimate == 1.0:
                nx = sp_fundamentals.smallprime(nx)
                ny = nx
                current_window = nx  # update
                if myid == main_node:
                    log_main.add("The window size is updated to %d." %
                                 current_window)
            else:
                if current_window == 0:
                    nx = sp_fundamentals.smallprime(
                        int(nx * current_decimate + 0.5))
                    current_decimate = old_div(float(nx), nnxo)
                    ny = nx
                    if myid == main_node:
                        log_main.add("The decimate rate is updated to %f." %
                                     current_decimate)
                else:
                    nx = sp_fundamentals.smallprime(
                        int(current_window * current_decimate + 0.5))
                    ny = nx
                    current_window = int(old_div(nx, current_decimate) + 0.5)
                    if myid == main_node:
                        log_main.add("The window size is updated to %d." %
                                     current_window)

        if myid == main_node:
            log_main.add("The target image size is %d" % nx)

        if radiuspca == -1:
            radiuspca = old_div(nx, 2) - 2
        if myid == main_node:
            log_main.add("%-70s:  %d\n" % ("Number of projection", nima))
        img_begin, img_end = sp_applications.MPI_start_end(
            nima, number_of_proc, myid)
        """Multiline Comment0"""
        """
        Comments from adnan, replace index_of_proj to index_of_particle, index_of_proj was not defined
        also varList is not defined not made an empty list there
        """

        if options.VAR:  # 2D variance images have no shifts
            varList = []
            # varList   = EMData.read_images(stack, range(img_begin, img_end))
            for index_of_particle in range(img_begin, img_end):
                image = sp_utilities.get_im(stack, index_of_particle)
                if current_window > 0:
                    varList.append(
                        sp_fundamentals.fdecimate(
                            sp_fundamentals.window2d(image, current_window,
                                                     current_window),
                            nx,
                            ny,
                        ))
                else:
                    varList.append(sp_fundamentals.fdecimate(image, nx, ny))

        else:
            if myid == main_node:
                t1 = time.time()
                proj_angles = []
                aveList = []
                tab = EMAN2_cppwrap.EMUtil.get_all_attributes(
                    stack, "xform.projection")
                for i in range(nima):
                    t = tab[i].get_params("spider")
                    phi = t["phi"]
                    theta = t["theta"]
                    psi = t["psi"]
                    x = theta
                    if x > 90.0:
                        x = 180.0 - x
                    x = x * 10000 + psi
                    proj_angles.append([x, t["phi"], t["theta"], t["psi"], i])
                t2 = time.time()
                log_main.add(
                    "%-70s:  %d\n" %
                    ("Number of neighboring projections", img_per_grp))
                log_main.add("...... Finding neighboring projections\n")
                log_main.add("Number of images per group: %d" % img_per_grp)
                log_main.add("Now grouping projections")
                proj_angles.sort()
                proj_angles_list = numpy.full((nima, 4),
                                              0.0,
                                              dtype=numpy.float32)
                for i in range(nima):
                    proj_angles_list[i][0] = proj_angles[i][1]
                    proj_angles_list[i][1] = proj_angles[i][2]
                    proj_angles_list[i][2] = proj_angles[i][3]
                    proj_angles_list[i][3] = proj_angles[i][4]
            else:
                proj_angles_list = 0
            proj_angles_list = sp_utilities.wrap_mpi_bcast(
                proj_angles_list, main_node, mpi.MPI_COMM_WORLD)
            proj_angles = []
            for i in range(nima):
                proj_angles.append([
                    proj_angles_list[i][0],
                    proj_angles_list[i][1],
                    proj_angles_list[i][2],
                    int(proj_angles_list[i][3]),
                ])
            del proj_angles_list
            proj_list, mirror_list = sp_utilities.nearest_proj(
                proj_angles, img_per_grp, range(img_begin, img_end))
            all_proj = []
            for im in proj_list:
                for jm in im:
                    all_proj.append(proj_angles[jm][3])
            all_proj = list(set(all_proj))
            index = {}
            for i in range(len(all_proj)):
                index[all_proj[i]] = i
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == main_node:
                log_main.add("%-70s:  %.2f\n" %
                             ("Finding neighboring projections lasted [s]",
                              time.time() - t2))
                log_main.add("%-70s:  %d\n" %
                             ("Number of groups processed on the main node",
                              len(proj_list)))
                log_main.add("Grouping projections took:  %12.1f [m]" %
                             (old_div((time.time() - t2), 60.0)))
                log_main.add("Number of groups on main node: ", len(proj_list))
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

            if myid == main_node:
                log_main.add("...... Calculating the stack of 2D variances \n")
            # Memory estimation. There are two memory consumption peaks
            # peak 1. Compute ave, var;
            # peak 2. Var volume reconstruction;
            # proj_params = [0.0]*(nima*5)
            aveList = []
            varList = []
            # if nvec > 0: eigList = [[] for i in range(nvec)]
            dnumber = len(
                all_proj)  # all neighborhood set for assigned to myid
            pnumber = len(proj_list) * 2.0 + img_per_grp  # aveList and varList
            tnumber = dnumber + pnumber
            vol_size2 = old_div(nx**3 * 4.0 * 8, 1.0e9)
            vol_size1 = old_div(2.0 * nnxo**3 * 4.0 * 8, 1.0e9)
            proj_size = old_div(nnxo * nnyo * len(proj_list) * 4.0 * 2.0,
                                1.0e9)  # both aveList and varList
            orig_data_size = old_div(nnxo * nnyo * 4.0 * tnumber, 1.0e9)
            reduced_data_size = old_div(nx * nx * 4.0 * tnumber, 1.0e9)
            full_data = numpy.full((number_of_proc, 2),
                                   -1.0,
                                   dtype=numpy.float16)
            full_data[myid] = orig_data_size, reduced_data_size
            if myid != main_node:
                sp_utilities.wrap_mpi_send(full_data, main_node,
                                           mpi.MPI_COMM_WORLD)
            if myid == main_node:
                for iproc in range(number_of_proc):
                    if iproc != main_node:
                        dummy = sp_utilities.wrap_mpi_recv(
                            iproc, mpi.MPI_COMM_WORLD)
                        full_data[numpy.where(dummy > -1)] = dummy[numpy.where(
                            dummy > -1)]
                del dummy
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            full_data = sp_utilities.wrap_mpi_bcast(full_data, main_node,
                                                    mpi.MPI_COMM_WORLD)
            # find the CPU with heaviest load
            minindx = numpy.argsort(full_data, 0)
            heavy_load_myid = minindx[-1][1]
            total_mem = sum(full_data)
            if myid == main_node:
                if current_window == 0:
                    log_main.add(
                        "Nx:   current image size = %d. Decimated by %f from %d"
                        % (nx, current_decimate, nnxo))
                else:
                    log_main.add(
                        "Nx:   current image size = %d. Windowed to %d, and decimated by %f from %d"
                        % (nx, current_window, current_decimate, nnxo))
                log_main.add("Nproj:       number of particle images.")
                log_main.add("Navg:        number of 2D average images.")
                log_main.add("Nvar:        number of 2D variance images.")
                log_main.add(
                    "Img_per_grp: user defined image per group for averaging = %d"
                    % img_per_grp)
                log_main.add(
                    "Overhead:    total python overhead memory consumption   = %f"
                    % overhead_loading)
                log_main.add(
                    "Total memory) = 4.0*nx^2*(nproj + navg +nvar+ img_per_grp)/1.0e9 + overhead: %12.3f [GB]"
                    % (total_mem[1] + overhead_loading))
            del full_data
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add(
                    "Begin reading and preprocessing images on processor. Wait... "
                )
                ttt = time.time()
            # imgdata = EMData.read_images(stack, all_proj)
            imgdata = [None for im in range(len(all_proj))]
            for index_of_proj in range(len(all_proj)):
                # image = get_im(stack, all_proj[index_of_proj])
                if current_window > 0:
                    imgdata[index_of_proj] = sp_fundamentals.fdecimate(
                        sp_fundamentals.window2d(
                            sp_utilities.get_im(stack,
                                                all_proj[index_of_proj]),
                            current_window,
                            current_window,
                        ),
                        nx,
                        ny,
                    )
                else:
                    imgdata[index_of_proj] = sp_fundamentals.fdecimate(
                        sp_utilities.get_im(stack, all_proj[index_of_proj]),
                        nx, ny)

                if current_decimate > 0.0 and options.CTF:
                    ctf = imgdata[index_of_proj].get_attr("ctf")
                    ctf.apix = old_div(ctf.apix, current_decimate)
                    imgdata[index_of_proj].set_attr("ctf", ctf)

                if myid == heavy_load_myid and index_of_proj % 100 == 0:
                    log_main.add(
                        " ...... %6.2f%% " %
                        (old_div(index_of_proj, float(len(all_proj))) * 100.0))
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add("All_proj preprocessing cost %7.2f m" % (old_div(
                    (time.time() - ttt), 60.0)))
                log_main.add("Wait untill reading on all CPUs done...")
            """Multiline Comment1"""
            if not options.no_norm:
                mask = sp_utilities.model_circle(old_div(nx, 2) - 2, nx, nx)
            if myid == heavy_load_myid:
                log_main.add("Start computing 2D aveList and varList. Wait...")
                ttt = time.time()
            inner = old_div(nx, 2) - 4
            outer = inner + 2
            xform_proj_for_2D = [None for i in range(len(proj_list))]
            for i in range(len(proj_list)):
                ki = proj_angles[proj_list[i][0]][3]
                if ki >= symbaselen:
                    continue
                mi = index[ki]
                dpar = EMAN2_cppwrap.Util.get_transform_params(
                    imgdata[mi], "xform.projection", "spider")
                phiM, thetaM, psiM, s2xM, s2yM = (
                    dpar["phi"],
                    dpar["theta"],
                    dpar["psi"],
                    -dpar["tx"] * current_decimate,
                    -dpar["ty"] * current_decimate,
                )
                grp_imgdata = []
                for j in range(img_per_grp):
                    mj = index[proj_angles[proj_list[i][j]][3]]
                    cpar = EMAN2_cppwrap.Util.get_transform_params(
                        imgdata[mj], "xform.projection", "spider")
                    alpha, sx, sy, mirror = params_3D_2D_NEW(
                        cpar["phi"],
                        cpar["theta"],
                        cpar["psi"],
                        -cpar["tx"] * current_decimate,
                        -cpar["ty"] * current_decimate,
                        mirror_list[i][j],
                    )
                    if thetaM <= 90:
                        if mirror == 0:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha, sx, sy, 1.0, phiM - cpar["phi"], 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha,
                                sx,
                                sy,
                                1.0,
                                180 - (phiM - cpar["phi"]),
                                0.0,
                                0.0,
                                1.0,
                            )
                    else:
                        if mirror == 0:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha, sx, sy, 1.0, -(phiM - cpar["phi"]), 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha,
                                sx,
                                sy,
                                1.0,
                                -(180 - (phiM - cpar["phi"])),
                                0.0,
                                0.0,
                                1.0,
                            )
                    imgdata[mj].set_attr(
                        "xform.align2d",
                        EMAN2_cppwrap.Transform({
                            "type": "2D",
                            "alpha": alpha,
                            "tx": sx,
                            "ty": sy,
                            "mirror": mirror,
                            "scale": 1.0,
                        }),
                    )
                    grp_imgdata.append(imgdata[mj])
                if not options.no_norm:
                    for k in range(img_per_grp):
                        ave, std, minn, maxx = EMAN2_cppwrap.Util.infomask(
                            grp_imgdata[k], mask, False)
                        grp_imgdata[k] -= ave
                        grp_imgdata[k] = old_div(grp_imgdata[k], std)
                if options.fl > 0.0:
                    for k in range(img_per_grp):
                        grp_imgdata[k] = sp_filter.filt_tanl(
                            grp_imgdata[k], options.fl, options.aa)

                #  Because of background issues, only linear option works.
                if options.CTF:
                    ave, var = sp_statistics.aves_wiener(
                        grp_imgdata, SNR=1.0e5, interpolation_method="linear")
                else:
                    ave, var = sp_statistics.ave_var(grp_imgdata)
                # Switch to std dev
                # The threshold is not strictly needed; it is applied just in case numerical inaccuracy makes some values negative.
                var = sp_morphology.square_root(sp_morphology.threshold(var))

                sp_utilities.set_params_proj(ave,
                                             [phiM, thetaM, 0.0, 0.0, 0.0])
                sp_utilities.set_params_proj(var,
                                             [phiM, thetaM, 0.0, 0.0, 0.0])

                aveList.append(ave)
                varList.append(var)
                xform_proj_for_2D[i] = [phiM, thetaM, 0.0, 0.0, 0.0]
                """Multiline Comment2"""
                if (myid == heavy_load_myid) and (i % 100 == 0):
                    log_main.add(" ......%6.2f%%  " %
                                 (old_div(i, float(len(proj_list))) * 100.0))
            del imgdata, grp_imgdata, cpar, dpar, all_proj, proj_angles, index
            if not options.no_norm:
                del mask
            if myid == main_node:
                del tab
            #  At this point, all averages and variances are computed
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

            if myid == heavy_load_myid:
                log_main.add("Computing aveList and varList took %12.1f [m]" %
                             (old_div((time.time() - ttt), 60.0)))

            xform_proj_for_2D = sp_utilities.wrap_mpi_gatherv(
                xform_proj_for_2D, main_node, mpi.MPI_COMM_WORLD)
            if myid == main_node:
                sp_utilities.write_text_row(
                    [str(entry) for entry in xform_proj_for_2D],
                    optparse.os.path.join(current_output_dir, "params.txt"),
                )
            del xform_proj_for_2D
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if options.ave2D:
                if myid == main_node:
                    log_main.add("Compute ave2D ... ")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(aveList)):
                                aveList[im].write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.ave2D),
                                    km,
                                )
                                km += 1
                        else:
                            nl = mpi.mpi_recv(
                                1,
                                mpi.MPI_INT,
                                i,
                                sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                mpi.MPI_COMM_WORLD,
                            )
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = sp_utilities.recv_EMData(
                                    i, im + i + 70000)
                                """Multiline Comment3"""
                                tmpvol = sp_fundamentals.fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.ave2D),
                                    km,
                                )
                                km += 1
                else:
                    mpi.mpi_send(
                        len(aveList),
                        1,
                        mpi.MPI_INT,
                        main_node,
                        sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                        mpi.MPI_COMM_WORLD,
                    )
                    for im in range(len(aveList)):
                        sp_utilities.send_EMData(aveList[im], main_node,
                                                 im + myid + 70000)
                        """Multiline Comment4"""
                if myid == main_node:
                    sp_applications.header(
                        optparse.os.path.join(current_output_dir,
                                              options.ave2D),
                        params="xform.projection",
                        fimport=optparse.os.path.join(current_output_dir,
                                                      "params.txt"),
                    )
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if options.ave3D:
                t5 = time.time()
                if myid == main_node:
                    log_main.add("Reconstruct ave3D ... ")
                ave3D = sp_reconstruction.recons3d_4nn_MPI(
                    myid, aveList, symmetry=options.sym, npad=options.npad)
                sp_utilities.bcast_EMData_to_all(ave3D, myid)
                if myid == main_node:
                    if current_decimate != 1.0:
                        ave3D = sp_fundamentals.resample(
                            ave3D, old_div(1.0, current_decimate))
                    ave3D = sp_fundamentals.fpol(
                        ave3D, nnxo, nnxo,
                        nnxo)  # always to the orignal image size
                    sp_utilities.set_pixel_size(ave3D, 1.0)
                    ave3D.write_image(
                        optparse.os.path.join(current_output_dir,
                                              options.ave3D))
                    log_main.add("Ave3D reconstruction took %12.1f [m]" %
                                 (old_div((time.time() - t5), 60.0)))
                    log_main.add("%-70s:  %s\n" %
                                 ("The reconstructed ave3D is saved as ",
                                  options.ave3D))

            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            del ave, var, proj_list, stack, alpha, sx, sy, mirror, aveList
            """Multiline Comment5"""

            if options.ave3D:
                del ave3D
            if options.var2D:
                if myid == main_node:
                    log_main.add("Compute var2D...")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(varList)):
                                tmpvol = sp_fundamentals.fpol(
                                    varList[im], nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.var2D),
                                    km,
                                )
                                km += 1
                        else:
                            nl = mpi.mpi_recv(
                                1,
                                mpi.MPI_INT,
                                i,
                                sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                mpi.MPI_COMM_WORLD,
                            )
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = sp_utilities.recv_EMData(
                                    i, im + i + 70000)
                                tmpvol = sp_fundamentals.fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.var2D),
                                    km,
                                )
                                km += 1
                else:
                    mpi.mpi_send(
                        len(varList),
                        1,
                        mpi.MPI_INT,
                        main_node,
                        sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                        mpi.MPI_COMM_WORLD,
                    )
                    for im in range(len(varList)):
                        sp_utilities.send_EMData(
                            varList[im], main_node,
                            im + myid + 70000)  # What with the attributes??
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
                if myid == main_node:
                    sp_applications.header(
                        optparse.os.path.join(current_output_dir,
                                              options.var2D),
                        params="xform.projection",
                        fimport=optparse.os.path.join(current_output_dir,
                                                      "params.txt"),
                    )
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        if options.var3D:
            if myid == main_node:
                log_main.add("Reconstruct var3D ...")
            t6 = time.time()
            # radiusvar = options.radius
            # if( radiusvar < 0 ):  radiusvar = nx//2 -3
            res = sp_reconstruction.recons3d_4nn_MPI(myid,
                                                     varList,
                                                     symmetry=options.sym,
                                                     npad=options.npad)
            # res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
            if myid == main_node:
                if current_decimate != 1.0:
                    res = sp_fundamentals.resample(
                        res, old_div(1.0, current_decimate))
                res = sp_fundamentals.fpol(res, nnxo, nnxo, nnxo)
                sp_utilities.set_pixel_size(res, 1.0)
                res.write_image(os.path.join(current_output_dir,
                                             options.var3D))
                log_main.add(
                    "%-70s:  %s\n" %
                    ("The reconstructed var3D is saved as ", options.var3D))
                log_main.add("Var3D reconstruction took %f12.1 [m]" % (old_div(
                    (time.time() - t6), 60.0)))
                log_main.add("Total computation time %f12.1 [m]" % (old_div(
                    (time.time() - t0), 60.0)))
                log_main.add("sx3dvariability finishes")

        if RUNNING_UNDER_MPI:
            sp_global_def.MPI = False

        sp_global_def.BATCH = False
Ejemplo n.º 5
0
def main():
    global Tracker, Blockdata
    progname = os.path.basename(sys.argv[0])
    usage = progname + " --output_dir=output_dir  --isac_dir=output_dir_of_isac "
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--pw_adjustment",
        type="string",
        default="analytical_model",
        help=
        "adjust power spectrum of 2-D averages to an analytic model. Other opions: no_adjustment; bfactor; a text file of 1D rotationally averaged PW",
    )
    #### Four options for --pw_adjustment:
    # 1> analytical_model(default);
    # 2> no_adjustment;
    # 3> bfactor;
    # 4> adjust_to_given_pw2(user has to provide a text file that contains 1D rotationally averaged PW)

    # options in common
    parser.add_option(
        "--isac_dir",
        type="string",
        default="",
        help="ISAC run output directory, input directory for this command",
    )
    parser.add_option(
        "--output_dir",
        type="string",
        default="",
        help="output directory where computed averages are saved",
    )
    parser.add_option(
        "--pixel_size",
        type="float",
        default=-1.0,
        help=
        "pixel_size of raw images. one can put 1.0 in case of negative stain data",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=-1.0,
        help=
        "low pass filter, = -1.0, not applied; =0.0, using FH1 (initial resolution), = 1.0 using FH2 (resolution after local alignment), or user provided value in absolute freqency [0.0:0.5]",
    )
    parser.add_option("--stack",
                      type="string",
                      default="",
                      help="data stack used in ISAC")
    parser.add_option("--radius", type="int", default=-1, help="radius")
    parser.add_option("--xr",
                      type="float",
                      default=-1.0,
                      help="local alignment search range")
    # parser.add_option("--ts",                    type   ="float",          default =1.0,    help= "local alignment search step")
    parser.add_option(
        "--fh",
        type="float",
        default=-1.0,
        help="local alignment high frequencies limit",
    )
    # parser.add_option("--maxit",                 type   ="int",            default =5,      help= "local alignment iterations")
    parser.add_option("--navg",
                      type="int",
                      default=1000000,
                      help="number of aveages")
    parser.add_option(
        "--local_alignment",
        action="store_true",
        default=False,
        help="do local alignment",
    )
    parser.add_option(
        "--noctf",
        action="store_true",
        default=False,
        help=
        "no ctf correction, useful for negative stained data. always ctf for cryo data",
    )
    parser.add_option(
        "--B_start",
        type="float",
        default=45.0,
        help=
        "start frequency (Angstrom) of power spectrum for B_factor estimation",
    )
    parser.add_option(
        "--Bfactor",
        type="float",
        default=-1.0,
        help=
        "User defined bactors (e.g. 25.0[A^2]). By default, the program automatically estimates B-factor. ",
    )

    (options, args) = parser.parse_args(sys.argv[1:])

    adjust_to_analytic_model = (True if options.pw_adjustment
                                == "analytical_model" else False)
    no_adjustment = True if options.pw_adjustment == "no_adjustment" else False
    B_enhance = True if options.pw_adjustment == "bfactor" else False
    adjust_to_given_pw2 = (
        True if not (adjust_to_analytic_model or no_adjustment or B_enhance)
        else False)

    # mpi
    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)

    Blockdata = {}
    Blockdata["nproc"] = nproc
    Blockdata["myid"] = myid
    Blockdata["main_node"] = 0
    Blockdata["shared_comm"] = mpi.mpi_comm_split_type(
        mpi.MPI_COMM_WORLD, mpi.MPI_COMM_TYPE_SHARED, 0, mpi.MPI_INFO_NULL)
    Blockdata["myid_on_node"] = mpi.mpi_comm_rank(Blockdata["shared_comm"])
    Blockdata["no_of_processes_per_group"] = mpi.mpi_comm_size(
        Blockdata["shared_comm"])
    masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
        mpi.MPI_COMM_WORLD,
        Blockdata["main_node"] == Blockdata["myid_on_node"],
        Blockdata["myid_on_node"],
    )
    Blockdata["color"], Blockdata[
        "no_of_groups"], balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            Blockdata["main_node"],
            mpi.MPI_COMM_WORLD,
            Blockdata["myid"],
            Blockdata["shared_comm"],
            Blockdata["myid_on_node"],
            masters_from_groups_vs_everything_else_comm,
        )
    #  We need three nodes for processing of volumes
    Blockdata["node_volume"] = [
        Blockdata["no_of_groups"] - 3,
        Blockdata["no_of_groups"] - 2,
        Blockdata["no_of_groups"] - 1,
    ]  # For 3D stuff take three last nodes
    #  We need three CPUs for processing of volumes; they are taken to be the main CPUs of the three selected nodes.
    #  We have to send the three myids to all nodes so we can identify the main CPUs of the three selected groups.
    Blockdata["nodes"] = [
        Blockdata["node_volume"][0] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][1] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][2] * Blockdata["no_of_processes_per_group"],
    ]
    # End of Blockdata: sorting requires at least three nodes, and the number of nodes used must be an integer multiple of three
    sp_global_def.BATCH = True
    sp_global_def.MPI = True

    if adjust_to_given_pw2:
        checking_flag = 0
        if Blockdata["myid"] == Blockdata["main_node"]:
            if not os.path.exists(options.pw_adjustment):
                checking_flag = 1
        checking_flag = sp_utilities.bcast_number_to_all(
            checking_flag, Blockdata["main_node"], mpi.MPI_COMM_WORLD)

        if checking_flag == 1:
            sp_global_def.ERROR("User provided power spectrum does not exist",
                                myid=Blockdata["myid"])

    Tracker = {}
    Constants = {}
    Constants["isac_dir"] = options.isac_dir
    Constants["masterdir"] = options.output_dir
    Constants["pixel_size"] = options.pixel_size
    Constants["orgstack"] = options.stack
    Constants["radius"] = options.radius
    Constants["xrange"] = options.xr
    Constants["FH"] = options.fh
    Constants["low_pass_filter"] = options.fl
    # Constants["maxit"]                        = options.maxit
    Constants["navg"] = options.navg
    Constants["B_start"] = options.B_start
    Constants["Bfactor"] = options.Bfactor

    if adjust_to_given_pw2:
        Constants["modelpw"] = options.pw_adjustment
    Tracker["constants"] = Constants
    # -------------------------------------------------------------
    #
    # Create and initialize Tracker dictionary with input options  # State Variables

    # <<<---------------------->>>imported functions<<<---------------------------------------------

    # x_range = max(Tracker["constants"]["xrange"], int(1./Tracker["ini_shrink"])+1)
    # y_range =  x_range

    ####-----------------------------------------------------------
    # Create Master directory and associated subdirectories
    line = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>"
    if Tracker["constants"]["masterdir"] == Tracker["constants"]["isac_dir"]:
        masterdir = os.path.join(Tracker["constants"]["isac_dir"], "sharpen")
    else:
        masterdir = Tracker["constants"]["masterdir"]

    if Blockdata["myid"] == Blockdata["main_node"]:
        msg = "Postprocessing ISAC 2D averages starts"
        sp_global_def.sxprint(line, "Postprocessing ISAC 2D averages starts")
        if not masterdir:
            timestring = time.strftime("_%d_%b_%Y_%H_%M_%S", time.localtime())
            masterdir = "sharpen_" + Tracker["constants"]["isac_dir"]
            os.makedirs(masterdir)
        else:
            if os.path.exists(masterdir):
                sp_global_def.sxprint("%s already exists" % masterdir)
            else:
                os.makedirs(masterdir)
        sp_global_def.write_command(masterdir)
        subdir_path = os.path.join(masterdir, "ali2d_local_params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        subdir_path = os.path.join(masterdir, "params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        li = len(masterdir)
    else:
        li = 0
    li = mpi.mpi_bcast(li, 1, mpi.MPI_INT, Blockdata["main_node"],
                       mpi.MPI_COMM_WORLD)[0]
    masterdir = mpi.mpi_bcast(masterdir, li, mpi.MPI_CHAR,
                              Blockdata["main_node"], mpi.MPI_COMM_WORLD)
    masterdir = b"".join(masterdir).decode('latin1')
    Tracker["constants"]["masterdir"] = masterdir
    log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
    log_main.prefix = Tracker["constants"]["masterdir"] + "/"

    while not os.path.exists(Tracker["constants"]["masterdir"]):
        sp_global_def.sxprint(
            "Node ",
            Blockdata["myid"],
            "  waiting...",
            Tracker["constants"]["masterdir"],
        )
        time.sleep(1)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if Blockdata["myid"] == Blockdata["main_node"]:
        init_dict = {}
        sp_global_def.sxprint(Tracker["constants"]["isac_dir"])
        Tracker["directory"] = os.path.join(Tracker["constants"]["isac_dir"],
                                            "2dalignment")
        core = sp_utilities.read_text_row(
            os.path.join(Tracker["directory"], "initial2Dparams.txt"))
        for im in range(len(core)):
            init_dict[im] = core[im]
        del core
    else:
        init_dict = 0
    init_dict = sp_utilities.wrap_mpi_bcast(init_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    ###
    do_ctf = True
    if options.noctf:
        do_ctf = False
    if Blockdata["myid"] == Blockdata["main_node"]:
        if do_ctf:
            sp_global_def.sxprint("CTF correction is on")
        else:
            sp_global_def.sxprint("CTF correction is off")
        if options.local_alignment:
            sp_global_def.sxprint("local refinement is on")
        else:
            sp_global_def.sxprint("local refinement is off")
        if B_enhance:
            sp_global_def.sxprint("Bfactor is to be applied on averages")
        elif adjust_to_given_pw2:
            sp_global_def.sxprint(
                "PW of averages is adjusted to a given 1D PW curve")
        elif adjust_to_analytic_model:
            sp_global_def.sxprint(
                "PW of averages is adjusted to analytical model")
        else:
            sp_global_def.sxprint("PW of averages is not adjusted")
        # Tracker["constants"]["orgstack"] = "bdb:"+ os.path.join(Tracker["constants"]["isac_dir"],"../","sparx_stack")
        image = sp_utilities.get_im(Tracker["constants"]["orgstack"], 0)
        Tracker["constants"]["nnxo"] = image.get_xsize()
        if Tracker["constants"]["pixel_size"] == -1.0:
            sp_global_def.sxprint(
                "Pixel size value is not provided by user. extracting it from ctf header entry of the original stack."
            )
            try:
                ctf_params = image.get_attr("ctf")
                Tracker["constants"]["pixel_size"] = ctf_params.apix
            except:
                sp_global_def.ERROR(
                    "Pixel size could not be extracted from the original stack.",
                    myid=Blockdata["myid"],
                )
        ## Now fill in low-pass filter

        isac_shrink_path = os.path.join(Tracker["constants"]["isac_dir"],
                                        "README_shrink_ratio.txt")

        if not os.path.exists(isac_shrink_path):
            sp_global_def.ERROR(
                "%s does not exist in the specified ISAC run output directory"
                % (isac_shrink_path),
                myid=Blockdata["myid"],
            )

        isac_shrink_file = open(isac_shrink_path, "r")
        isac_shrink_lines = isac_shrink_file.readlines()
        isac_shrink_ratio = float(
            isac_shrink_lines[5]
        )  # 6th line: shrink ratio (= [target particle radius]/[particle radius]) used in the ISAC run
        isac_radius = float(
            isac_shrink_lines[6]
        )  # 7th line: particle radius at original pixel size used in the ISAC run
        isac_shrink_file.close()
        print("Extracted parameter values")
        print("ISAC shrink ratio    : {0}".format(isac_shrink_ratio))
        print("ISAC particle radius : {0}".format(isac_radius))
        Tracker["ini_shrink"] = isac_shrink_ratio
    else:
        Tracker["ini_shrink"] = 0.0
    Tracker = sp_utilities.wrap_mpi_bcast(Tracker,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)

    # print(Tracker["constants"]["pixel_size"], "pixel_size")
    x_range = max(
        Tracker["constants"]["xrange"],
        int(old_div(1.0, Tracker["ini_shrink"]) + 0.99999),
    )
    a_range = y_range = x_range

    if Blockdata["myid"] == Blockdata["main_node"]:
        parameters = sp_utilities.read_text_row(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "all_parameters.txt"))
    else:
        parameters = 0
    parameters = sp_utilities.wrap_mpi_bcast(parameters,
                                             Blockdata["main_node"],
                                             communicator=mpi.MPI_COMM_WORLD)
    params_dict = {}
    list_dict = {}
    # prepare params_dict

    # navg = min(Tracker["constants"]["navg"]*Blockdata["nproc"], EMUtil.get_image_count(os.path.join(Tracker["constants"]["isac_dir"], "class_averages.hdf")))
    navg = min(
        Tracker["constants"]["navg"],
        EMAN2_cppwrap.EMUtil.get_image_count(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "class_averages.hdf")),
    )
    global_dict = {}
    ptl_list = []
    memlist = []
    if Blockdata["myid"] == Blockdata["main_node"]:
        sp_global_def.sxprint("Number of averages computed in this run is %d" %
                              navg)
        for iavg in range(navg):
            params_of_this_average = []
            image = sp_utilities.get_im(
                os.path.join(Tracker["constants"]["isac_dir"],
                             "class_averages.hdf"),
                iavg,
            )
            members = sorted(image.get_attr("members"))
            memlist.append(members)
            for im in range(len(members)):
                abs_id = members[im]
                global_dict[abs_id] = [iavg, im]
                P = sp_utilities.combine_params2(
                    init_dict[abs_id][0],
                    init_dict[abs_id][1],
                    init_dict[abs_id][2],
                    init_dict[abs_id][3],
                    parameters[abs_id][0],
                    old_div(parameters[abs_id][1], Tracker["ini_shrink"]),
                    old_div(parameters[abs_id][2], Tracker["ini_shrink"]),
                    parameters[abs_id][3],
                )
                if parameters[abs_id][3] == -1:
                    sp_global_def.sxprint(
                        "WARNING: Image #{0} is an unaccounted particle with invalid 2D alignment parameters and should not be the member of any classes. Please check the consitency of input dataset."
                        .format(abs_id)
                    )  # How to check what is wrong about mirror = -1 (Toshio 2018/01/11)
                params_of_this_average.append([P[0], P[1], P[2], P[3], 1.0])
                ptl_list.append(abs_id)
            params_dict[iavg] = params_of_this_average
            list_dict[iavg] = members
            sp_utilities.write_text_row(
                params_of_this_average,
                os.path.join(
                    Tracker["constants"]["masterdir"],
                    "params_avg",
                    "params_avg_%03d.txt" % iavg,
                ),
            )
        ptl_list.sort()
        init_params = [None for im in range(len(ptl_list))]
        for im in range(len(ptl_list)):
            init_params[im] = [ptl_list[im]] + params_dict[global_dict[
                ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
        sp_utilities.write_text_row(
            init_params,
            os.path.join(Tracker["constants"]["masterdir"],
                         "init_isac_params.txt"),
        )
    else:
        params_dict = 0
        list_dict = 0
        memlist = 0
    params_dict = sp_utilities.wrap_mpi_bcast(params_dict,
                                              Blockdata["main_node"],
                                              communicator=mpi.MPI_COMM_WORLD)
    list_dict = sp_utilities.wrap_mpi_bcast(list_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    memlist = sp_utilities.wrap_mpi_bcast(memlist,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)
    # Now computing!
    del init_dict
    tag_sharpen_avg = 1000
    ## always apply low pass filter to B_enhanced images to suppress noise in high frequencies
    enforced_to_H1 = False
    if B_enhance:
        if Tracker["constants"]["low_pass_filter"] == -1.0:
            enforced_to_H1 = True

    # distribute workload among mpi processes
    image_start, image_end = sp_applications.MPI_start_end(
        navg, Blockdata["nproc"], Blockdata["myid"])

    if Blockdata["myid"] == Blockdata["main_node"]:
        cpu_dict = {}
        for iproc in range(Blockdata["nproc"]):
            local_image_start, local_image_end = sp_applications.MPI_start_end(
                navg, Blockdata["nproc"], iproc)
            for im in range(local_image_start, local_image_end):
                cpu_dict[im] = iproc
    else:
        cpu_dict = 0

    cpu_dict = sp_utilities.wrap_mpi_bcast(cpu_dict,
                                           Blockdata["main_node"],
                                           communicator=mpi.MPI_COMM_WORLD)

    slist = [None for im in range(navg)]
    ini_list = [None for im in range(navg)]
    avg1_list = [None for im in range(navg)]
    avg2_list = [None for im in range(navg)]
    data_list = [None for im in range(navg)]
    plist_dict = {}

    if Blockdata["myid"] == Blockdata["main_node"]:
        if B_enhance:
            sp_global_def.sxprint(
                "Avg ID   B-factor  FH1(Res before ali) FH2(Res after ali)")
        else:
            sp_global_def.sxprint(
                "Avg ID   FH1(Res before ali)  FH2(Res after ali)")

    FH_list = [[0, 0.0, 0.0] for im in range(navg)]
    for iavg in range(image_start, image_end):

        mlist = EMAN2_cppwrap.EMData.read_images(
            Tracker["constants"]["orgstack"], list_dict[iavg])

        for im in range(len(mlist)):
            sp_utilities.set_params2D(mlist[im],
                                      params_dict[iavg][im],
                                      xform="xform.align2d")

        if options.local_alignment:
            new_avg, plist, FH2 = sp_applications.refinement_2d_local(
                mlist,
                Tracker["constants"]["radius"],
                a_range,
                x_range,
                y_range,
                CTF=do_ctf,
                SNR=1.0e10,
            )
            plist_dict[iavg] = plist
            FH1 = -1.0

        else:
            new_avg, frc, plist = compute_average(
                mlist, Tracker["constants"]["radius"], do_ctf)
            FH1 = get_optimistic_res(frc)
            FH2 = -1.0

        FH_list[iavg] = [iavg, FH1, FH2]

        if B_enhance:
            new_avg, gb = apply_enhancement(
                new_avg,
                Tracker["constants"]["B_start"],
                Tracker["constants"]["pixel_size"],
                Tracker["constants"]["Bfactor"],
            )
            sp_global_def.sxprint("  %6d      %6.3f  %4.3f  %4.3f" %
                                  (iavg, gb, FH1, FH2))

        elif adjust_to_given_pw2:
            roo = sp_utilities.read_text_file(Tracker["constants"]["modelpw"],
                                              -1)
            roo = roo[0]  # always on the first column
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         roo)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f  " %
                                  (iavg, FH1, FH2))

        elif adjust_to_analytic_model:
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         None)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f   " %
                                  (iavg, FH1, FH2))

        elif no_adjustment:
            pass

        if Tracker["constants"]["low_pass_filter"] != -1.0:
            if Tracker["constants"]["low_pass_filter"] == 0.0:
                low_pass_filter = FH1
            elif Tracker["constants"]["low_pass_filter"] == 1.0:
                low_pass_filter = FH2
                if not options.local_alignment:
                    low_pass_filter = FH1
            else:
                low_pass_filter = Tracker["constants"]["low_pass_filter"]
                if low_pass_filter >= 0.45:
                    low_pass_filter = 0.45
            new_avg = sp_filter.filt_tanl(new_avg, low_pass_filter, 0.02)
        else:  # No low pass filter but if enforced
            if enforced_to_H1:
                new_avg = sp_filter.filt_tanl(new_avg, FH1, 0.02)
        if B_enhance:
            new_avg = sp_fundamentals.fft(new_avg)

        new_avg.set_attr("members", list_dict[iavg])
        new_avg.set_attr("n_objects", len(list_dict[iavg]))
        slist[iavg] = new_avg
        sp_global_def.sxprint(
            time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>",
            "Refined average %7d" % iavg,
        )

    ## send to main node to write
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(navg):
        # avg
        if (cpu_dict[im] == Blockdata["myid"]
                and Blockdata["myid"] != Blockdata["main_node"]):
            sp_utilities.send_EMData(slist[im], Blockdata["main_node"],
                                     tag_sharpen_avg)

        elif (cpu_dict[im] == Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            slist[im].set_attr("members", memlist[im])
            slist[im].set_attr("n_objects", len(memlist[im]))
            slist[im].write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        elif (cpu_dict[im] != Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            new_avg_other_cpu = sp_utilities.recv_EMData(
                cpu_dict[im], tag_sharpen_avg)
            new_avg_other_cpu.set_attr("members", memlist[im])
            new_avg_other_cpu.set_attr("n_objects", len(memlist[im]))
            new_avg_other_cpu.write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        if options.local_alignment:
            if cpu_dict[im] == Blockdata["myid"]:
                sp_utilities.write_text_row(
                    plist_dict[im],
                    os.path.join(
                        Tracker["constants"]["masterdir"],
                        "ali2d_local_params_avg",
                        "ali2d_local_params_avg_%03d.txt" % im,
                    ),
                )

            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(plist_dict[im],
                                           Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                plist_dict[im] = dummy
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]
        else:
            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]

        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if options.local_alignment:
        if Blockdata["myid"] == Blockdata["main_node"]:
            ali3d_local_params = [None for im in range(len(ptl_list))]
            for im in range(len(ptl_list)):
                ali3d_local_params[im] = [ptl_list[im]] + plist_dict[
                    global_dict[ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
            sp_utilities.write_text_row(
                ali3d_local_params,
                os.path.join(Tracker["constants"]["masterdir"],
                             "ali2d_local_params.txt"),
            )
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))
    else:
        if Blockdata["myid"] == Blockdata["main_node"]:
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    target_xr = 3
    target_yr = 3
    if Blockdata["myid"] == 0:
        cmd = "{} {} {} {} {} {} {} {} {} {}".format(
            "sp_chains.py",
            os.path.join(Tracker["constants"]["masterdir"],
                         "class_averages.hdf"),
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"),
            os.path.join(Tracker["constants"]["masterdir"],
                         "ordered_class_averages.hdf"),
            "--circular",
            "--radius=%d" % Tracker["constants"]["radius"],
            "--xr=%d" % (target_xr + 1),
            "--yr=%d" % (target_yr + 1),
            "--align",
            ">/dev/null",
        )
        junk = sp_utilities.cmdexecute(cmd)
        cmd = "{} {}".format(
            "rm -rf",
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"))
        junk = sp_utilities.cmdexecute(cmd)

    return