Example #1
def prepref(data, maskfile, cnx, cny, numr, mode, maxrangex, maxrangey, step):
    # step = 1
    mashi = cnx - numr[-3] - 2
    nima = len(data)
    istep = int(old_div(1.0, step))
    dimage = [[[None for j in range(2 * maxrangey * istep + 1)]
               for i in range(2 * maxrangex * istep + 1)]
              for im in range(nima)]
    for im in range(nima):
        sts = EMAN2_cppwrap.Util.infomask(data[im], maskfile, False)
        data[im] -= sts[0]
        data[im] = old_div(data[im], sts[1])
        alpha, sx, sy, mirror, dummy = sp_utilities.get_params2D(data[im])
        # alpha, sx, sy, dummy         = combine_params2(alpha, sx, sy, mirror, 0.0, -cs[0], -cs[1], 0)
        alphai, sxi, syi, dummy = sp_utilities.combine_params2(
            0.0, sx, sy, 0, -alpha, 0, 0, 0)
        #  introduce constraints on parameters to accommodate use of cs centering
        sxi = min(max(sxi, -mashi), mashi)
        syi = min(max(syi, -mashi), mashi)
        for j in range(-maxrangey * istep, maxrangey * istep + 1):
            iy = j * step
            for i in range(-maxrangex * istep, maxrangex * istep + 1):
                ix = i * step
                dimage[im][i + maxrangex][j + maxrangey] = \
                    EMAN2_cppwrap.Util.Polar2Dm(
                        data[im], cnx + sxi + ix, cny + syi + iy, numr, mode)
                # print ' prepref  ',j,i,j+maxrangey,i+maxrangex
                EMAN2_cppwrap.Util.Frngs(
                    dimage[im][i + maxrangex][j + maxrangey], numr)
        dimage[im][0][0].set_attr("sxi", sxi)
        dimage[im][0][0].set_attr("syi", syi)

    return dimage
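The table returned above is indexed image-first, then by x offset, then by y offset. A minimal standalone sketch of that index layout (no EMAN2 calls; step is assumed to be 1, as the comment at the top of the function suggests):

# Standalone sketch (not part of SPHIRE): how prepref's offset grid is indexed.
# Assumes step == 1, so istep == 1 and the offsets i, j are plain integers.
maxrangex, maxrangey = 2, 1

# Same nesting as dimage[im]: first the x-offset slot, then the y-offset slot.
grid = [[None for j in range(2 * maxrangey + 1)]
        for i in range(2 * maxrangex + 1)]

for j in range(-maxrangey, maxrangey + 1):
    for i in range(-maxrangex, maxrangex + 1):
        # prepref stores a polar-transformed image here; we store the offsets.
        grid[i + maxrangex][j + maxrangey] = (i, j)

print(grid[maxrangex][maxrangey])  # (0, 0): the unshifted rings
print(grid[0][0])                  # (-2, -1): the slot that receives "sxi"/"syi"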
Example #2
def align_diff_params(ali_params1, ali_params2):
    '''
    This function determines the relative angle, shifts and mirror flag between
    two sets of alignment parameters.
    '''
    # from math import cos, sin, pi
    # from sp_utilities import combine_params2

    nima = len(ali_params1)
    nima2 = len(ali_params2)
    if nima2 != nima:
        sp_global_def.sxprint("Error: Number of images do not agree!")
        return 0.0, 0.0, 0.0, 0
    else:
        nima //= 4  # each image contributes four parameters (alpha, sx, sy, mirror)
        del nima2

    # Read the alignment parameters and determine the relative mirror flag
    mirror_same = 0
    for i in range(nima):
        if ali_params1[i * 4 + 3] == ali_params2[i * 4 + 3]: mirror_same += 1
    if mirror_same > nima / 2:
        mirror = 0
    else:
        mirror_same = nima - mirror_same
        mirror = 1

    # Determine the relative angle
    cosi = 0.0
    sini = 0.0
    angle1 = []
    angle2 = []
    for i in range(nima):
        mirror1 = ali_params1[i * 4 + 3]
        mirror2 = ali_params2[i * 4 + 3]
        if abs(mirror1 - mirror2) == mirror:
            alpha1 = ali_params1[i * 4]
            alpha2 = ali_params2[i * 4]
            if mirror1 == 1:
                alpha1 = -alpha1
                alpha2 = -alpha2
            angle1.append(alpha1)
            angle2.append(alpha2)
    alphai = angle_diff(angle1, angle2)

    # Determine the relative shift
    sxi = 0.0
    syi = 0.0
    for i in range(nima):
        mirror1 = ali_params1[i * 4 + 3]
        mirror2 = ali_params2[i * 4 + 3]
        if abs(mirror1 - mirror2) == mirror:
            alpha1 = ali_params1[i * 4]
            #alpha2 = ali_params2[i*4]
            sx1 = ali_params1[i * 4 + 1]
            sx2 = ali_params2[i * 4 + 1]
            sy1 = ali_params1[i * 4 + 2]
            sy2 = ali_params2[i * 4 + 2]
            alpha12, sx12, sy12, mirror12 = sp_utilities.combine_params2(
                alpha1, sx1, sy1, int(mirror1), alphai, 0.0, 0.0, 0)
            if mirror1 == 0: sxi += sx2 - sx12
            else: sxi -= sx2 - sx12
            syi += sy2 - sy12

    sxi /= mirror_same
    syi /= mirror_same

    return alphai, sxi, syi, mirror
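angle_diff is not shown in this snippet. As an illustration only, a common way to estimate a single relative in-plane rotation from two angle lists is to average the unit vectors of the pairwise differences; the helper below is a hypothetical stand-in, not the SPHIRE implementation:

import math

def angle_diff_sketch(angle1, angle2):
    """Hypothetical stand-in for angle_diff: estimate the average rotation
    (in degrees) that maps angle1 onto angle2 by summing the unit vectors
    of the pairwise differences."""
    cosi = 0.0
    sini = 0.0
    for a1, a2 in zip(angle1, angle2):
        d = math.radians(a2 - a1)
        cosi += math.cos(d)
        sini += math.sin(d)
    return math.degrees(math.atan2(sini, cosi)) % 360.0

# Two angle sets differing by roughly 30 degrees:
print(angle_diff_sketch([0.0, 90.0, 180.0], [30.0, 121.0, 209.0]))  # ~30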
Example #3
def main():
    global Tracker, Blockdata
    progname = os.path.basename(sys.argv[0])
    usage = progname + " --output_dir=output_dir  --isac_dir=output_dir_of_isac "
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--pw_adjustment",
        type="string",
        default="analytical_model",
        help=
        "adjust power spectrum of 2-D averages to an analytic model. Other opions: no_adjustment; bfactor; a text file of 1D rotationally averaged PW",
    )
    #### Four options for --pw_adjustment:
    # 1> analytical_model(default);
    # 2> no_adjustment;
    # 3> bfactor;
    # 4> adjust_to_given_pw2 (the user has to provide a text file that contains the 1D rotationally averaged PW)

    # options in common
    parser.add_option(
        "--isac_dir",
        type="string",
        default="",
        help="ISAC run output directory, input directory for this command",
    )
    parser.add_option(
        "--output_dir",
        type="string",
        default="",
        help="output directory where computed averages are saved",
    )
    parser.add_option(
        "--pixel_size",
        type="float",
        default=-1.0,
        help=
        "pixel_size of raw images. one can put 1.0 in case of negative stain data",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=-1.0,
        help=
        "low pass filter, = -1.0, not applied; =0.0, using FH1 (initial resolution), = 1.0 using FH2 (resolution after local alignment), or user provided value in absolute freqency [0.0:0.5]",
    )
    parser.add_option("--stack",
                      type="string",
                      default="",
                      help="data stack used in ISAC")
    parser.add_option("--radius", type="int", default=-1, help="radius")
    parser.add_option("--xr",
                      type="float",
                      default=-1.0,
                      help="local alignment search range")
    # parser.add_option("--ts",                    type   ="float",          default =1.0,    help= "local alignment search step")
    parser.add_option(
        "--fh",
        type="float",
        default=-1.0,
        help="local alignment high frequencies limit",
    )
    # parser.add_option("--maxit",                 type   ="int",            default =5,      help= "local alignment iterations")
    parser.add_option("--navg",
                      type="int",
                      default=1000000,
                      help="number of aveages")
    parser.add_option(
        "--local_alignment",
        action="store_true",
        default=False,
        help="do local alignment",
    )
    parser.add_option(
        "--noctf",
        action="store_true",
        default=False,
        help=
        "no ctf correction, useful for negative stained data. always ctf for cryo data",
    )
    parser.add_option(
        "--B_start",
        type="float",
        default=45.0,
        help=
        "start frequency (Angstrom) of power spectrum for B_factor estimation",
    )
    parser.add_option(
        "--Bfactor",
        type="float",
        default=-1.0,
        help=
        "User defined bactors (e.g. 25.0[A^2]). By default, the program automatically estimates B-factor. ",
    )

    (options, args) = parser.parse_args(sys.argv[1:])

    adjust_to_analytic_model = options.pw_adjustment == "analytical_model"
    no_adjustment = options.pw_adjustment == "no_adjustment"
    B_enhance = options.pw_adjustment == "bfactor"
    adjust_to_given_pw2 = not (adjust_to_analytic_model or no_adjustment
                               or B_enhance)

    # mpi
    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)

    Blockdata = {}
    Blockdata["nproc"] = nproc
    Blockdata["myid"] = myid
    Blockdata["main_node"] = 0
    Blockdata["shared_comm"] = mpi.mpi_comm_split_type(
        mpi.MPI_COMM_WORLD, mpi.MPI_COMM_TYPE_SHARED, 0, mpi.MPI_INFO_NULL)
    Blockdata["myid_on_node"] = mpi.mpi_comm_rank(Blockdata["shared_comm"])
    Blockdata["no_of_processes_per_group"] = mpi.mpi_comm_size(
        Blockdata["shared_comm"])
    masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
        mpi.MPI_COMM_WORLD,
        Blockdata["main_node"] == Blockdata["myid_on_node"],
        Blockdata["myid_on_node"],
    )
    Blockdata["color"], Blockdata[
        "no_of_groups"], balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            Blockdata["main_node"],
            mpi.MPI_COMM_WORLD,
            Blockdata["myid"],
            Blockdata["shared_comm"],
            Blockdata["myid_on_node"],
            masters_from_groups_vs_everything_else_comm,
        )
    #  We need three nodes for processing of volumes
    Blockdata["node_volume"] = [
        Blockdata["no_of_groups"] - 3,
        Blockdata["no_of_groups"] - 2,
        Blockdata["no_of_groups"] - 1,
    ]  # For 3D stuff take three last nodes
    #  We need three CPUs for processing of volumes; they are taken to be the main CPUs on each volume node
    #  We have to send the three myids to all nodes so we can identify the main nodes of the three selected groups.
    Blockdata["nodes"] = [
        Blockdata["node_volume"][0] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][1] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][2] * Blockdata["no_of_processes_per_group"],
    ]
    # End of Blockdata: sorting requires at least three nodes, and the number of nodes used must be a multiple of three
    sp_global_def.BATCH = True
    sp_global_def.MPI = True

    if adjust_to_given_pw2:
        checking_flag = 0
        if Blockdata["myid"] == Blockdata["main_node"]:
            if not os.path.exists(options.pw_adjustment):
                checking_flag = 1
        checking_flag = sp_utilities.bcast_number_to_all(
            checking_flag, Blockdata["main_node"], mpi.MPI_COMM_WORLD)

        if checking_flag == 1:
            sp_global_def.ERROR("User provided power spectrum does not exist",
                                myid=Blockdata["myid"])

    Tracker = {}
    Constants = {}
    Constants["isac_dir"] = options.isac_dir
    Constants["masterdir"] = options.output_dir
    Constants["pixel_size"] = options.pixel_size
    Constants["orgstack"] = options.stack
    Constants["radius"] = options.radius
    Constants["xrange"] = options.xr
    Constants["FH"] = options.fh
    Constants["low_pass_filter"] = options.fl
    # Constants["maxit"]                        = options.maxit
    Constants["navg"] = options.navg
    Constants["B_start"] = options.B_start
    Constants["Bfactor"] = options.Bfactor

    if adjust_to_given_pw2:
        Constants["modelpw"] = options.pw_adjustment
    Tracker["constants"] = Constants
    # -------------------------------------------------------------
    #
    # Create and initialize Tracker dictionary with input options  # State Variables

    # <<<---------------------->>>imported functions<<<---------------------------------------------

    # x_range = max(Tracker["constants"]["xrange"], int(1./Tracker["ini_shrink"])+1)
    # y_range =  x_range

    ####-----------------------------------------------------------
    # Create Master directory and associated subdirectories
    line = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>"
    if Tracker["constants"]["masterdir"] == Tracker["constants"]["isac_dir"]:
        masterdir = os.path.join(Tracker["constants"]["isac_dir"], "sharpen")
    else:
        masterdir = Tracker["constants"]["masterdir"]

    if Blockdata["myid"] == Blockdata["main_node"]:
        msg = "Postprocessing ISAC 2D averages starts"
        sp_global_def.sxprint(line, "Postprocessing ISAC 2D averages starts")
        if not masterdir:
            timestring = time.strftime("_%d_%b_%Y_%H_%M_%S", time.localtime())
            masterdir = "sharpen_" + Tracker["constants"]["isac_dir"]
            os.makedirs(masterdir)
        else:
            if os.path.exists(masterdir):
                sp_global_def.sxprint("%s already exists" % masterdir)
            else:
                os.makedirs(masterdir)
        sp_global_def.write_command(masterdir)
        subdir_path = os.path.join(masterdir, "ali2d_local_params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        subdir_path = os.path.join(masterdir, "params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        li = len(masterdir)
    else:
        li = 0
    li = mpi.mpi_bcast(li, 1, mpi.MPI_INT, Blockdata["main_node"],
                       mpi.MPI_COMM_WORLD)[0]
    masterdir = mpi.mpi_bcast(masterdir, li, mpi.MPI_CHAR,
                              Blockdata["main_node"], mpi.MPI_COMM_WORLD)
    masterdir = b"".join(masterdir).decode('latin1')
    Tracker["constants"]["masterdir"] = masterdir
    log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
    log_main.prefix = Tracker["constants"]["masterdir"] + "/"

    while not os.path.exists(Tracker["constants"]["masterdir"]):
        sp_global_def.sxprint(
            "Node ",
            Blockdata["myid"],
            "  waiting...",
            Tracker["constants"]["masterdir"],
        )
        time.sleep(1)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if Blockdata["myid"] == Blockdata["main_node"]:
        init_dict = {}
        sp_global_def.sxprint(Tracker["constants"]["isac_dir"])
        Tracker["directory"] = os.path.join(Tracker["constants"]["isac_dir"],
                                            "2dalignment")
        core = sp_utilities.read_text_row(
            os.path.join(Tracker["directory"], "initial2Dparams.txt"))
        for im in range(len(core)):
            init_dict[im] = core[im]
        del core
    else:
        init_dict = 0
    init_dict = sp_utilities.wrap_mpi_bcast(init_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    ###
    do_ctf = True
    if options.noctf:
        do_ctf = False
    if Blockdata["myid"] == Blockdata["main_node"]:
        if do_ctf:
            sp_global_def.sxprint("CTF correction is on")
        else:
            sp_global_def.sxprint("CTF correction is off")
        if options.local_alignment:
            sp_global_def.sxprint("local refinement is on")
        else:
            sp_global_def.sxprint("local refinement is off")
        if B_enhance:
            sp_global_def.sxprint("Bfactor is to be applied on averages")
        elif adjust_to_given_pw2:
            sp_global_def.sxprint(
                "PW of averages is adjusted to a given 1D PW curve")
        elif adjust_to_analytic_model:
            sp_global_def.sxprint(
                "PW of averages is adjusted to analytical model")
        else:
            sp_global_def.sxprint("PW of averages is not adjusted")
        # Tracker["constants"]["orgstack"] = "bdb:"+ os.path.join(Tracker["constants"]["isac_dir"],"../","sparx_stack")
        image = sp_utilities.get_im(Tracker["constants"]["orgstack"], 0)
        Tracker["constants"]["nnxo"] = image.get_xsize()
        if Tracker["constants"]["pixel_size"] == -1.0:
            sp_global_def.sxprint(
                "Pixel size value is not provided by user. extracting it from ctf header entry of the original stack."
            )
            try:
                ctf_params = image.get_attr("ctf")
                Tracker["constants"]["pixel_size"] = ctf_params.apix
            except:
                sp_global_def.ERROR(
                    "Pixel size could not be extracted from the original stack.",
                    myid=Blockdata["myid"],
                )
        ## Now fill in low-pass filter

        isac_shrink_path = os.path.join(Tracker["constants"]["isac_dir"],
                                        "README_shrink_ratio.txt")

        if not os.path.exists(isac_shrink_path):
            sp_global_def.ERROR(
                "%s does not exist in the specified ISAC run output directory"
                % (isac_shrink_path),
                myid=Blockdata["myid"],
            )

        isac_shrink_file = open(isac_shrink_path, "r")
        isac_shrink_lines = isac_shrink_file.readlines()
        isac_shrink_ratio = float(
            isac_shrink_lines[5]
        )  # 6th line: shrink ratio (= [target particle radius]/[particle radius]) used in the ISAC run
        isac_radius = float(
            isac_shrink_lines[6]
        )  # 7th line: particle radius at original pixel size used in the ISAC run
        isac_shrink_file.close()
        print("Extracted parameter values")
        print("ISAC shrink ratio    : {0}".format(isac_shrink_ratio))
        print("ISAC particle radius : {0}".format(isac_radius))
        Tracker["ini_shrink"] = isac_shrink_ratio
    else:
        Tracker["ini_shrink"] = 0.0
    Tracker = sp_utilities.wrap_mpi_bcast(Tracker,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)

    # print(Tracker["constants"]["pixel_size"], "pixel_size")
    x_range = max(
        Tracker["constants"]["xrange"],
        int(old_div(1.0, Tracker["ini_shrink"]) + 0.99999),
    )
    a_range = y_range = x_range

    if Blockdata["myid"] == Blockdata["main_node"]:
        parameters = sp_utilities.read_text_row(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "all_parameters.txt"))
    else:
        parameters = 0
    parameters = sp_utilities.wrap_mpi_bcast(parameters,
                                             Blockdata["main_node"],
                                             communicator=mpi.MPI_COMM_WORLD)
    params_dict = {}
    list_dict = {}
    # prepare params_dict

    # navg = min(Tracker["constants"]["navg"]*Blockdata["nproc"], EMUtil.get_image_count(os.path.join(Tracker["constants"]["isac_dir"], "class_averages.hdf")))
    navg = min(
        Tracker["constants"]["navg"],
        EMAN2_cppwrap.EMUtil.get_image_count(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "class_averages.hdf")),
    )
    global_dict = {}
    ptl_list = []
    memlist = []
    if Blockdata["myid"] == Blockdata["main_node"]:
        sp_global_def.sxprint("Number of averages computed in this run is %d" %
                              navg)
        for iavg in range(navg):
            params_of_this_average = []
            image = sp_utilities.get_im(
                os.path.join(Tracker["constants"]["isac_dir"],
                             "class_averages.hdf"),
                iavg,
            )
            members = sorted(image.get_attr("members"))
            memlist.append(members)
            for im in range(len(members)):
                abs_id = members[im]
                global_dict[abs_id] = [iavg, im]
                P = sp_utilities.combine_params2(
                    init_dict[abs_id][0],
                    init_dict[abs_id][1],
                    init_dict[abs_id][2],
                    init_dict[abs_id][3],
                    parameters[abs_id][0],
                    old_div(parameters[abs_id][1], Tracker["ini_shrink"]),
                    old_div(parameters[abs_id][2], Tracker["ini_shrink"]),
                    parameters[abs_id][3],
                )
                if parameters[abs_id][3] == -1:
                    sp_global_def.sxprint(
                        "WARNING: Image #{0} is an unaccounted particle with invalid 2D alignment parameters and should not be the member of any classes. Please check the consitency of input dataset."
                        .format(abs_id)
                    )  # How to check what is wrong about mirror = -1 (Toshio 2018/01/11)
                params_of_this_average.append([P[0], P[1], P[2], P[3], 1.0])
                ptl_list.append(abs_id)
            params_dict[iavg] = params_of_this_average
            list_dict[iavg] = members
            sp_utilities.write_text_row(
                params_of_this_average,
                os.path.join(
                    Tracker["constants"]["masterdir"],
                    "params_avg",
                    "params_avg_%03d.txt" % iavg,
                ),
            )
        ptl_list.sort()
        init_params = [None for im in range(len(ptl_list))]
        for im in range(len(ptl_list)):
            init_params[im] = [ptl_list[im]] + params_dict[global_dict[
                ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
        sp_utilities.write_text_row(
            init_params,
            os.path.join(Tracker["constants"]["masterdir"],
                         "init_isac_params.txt"),
        )
    else:
        params_dict = 0
        list_dict = 0
        memlist = 0
    params_dict = sp_utilities.wrap_mpi_bcast(params_dict,
                                              Blockdata["main_node"],
                                              communicator=mpi.MPI_COMM_WORLD)
    list_dict = sp_utilities.wrap_mpi_bcast(list_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    memlist = sp_utilities.wrap_mpi_bcast(memlist,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)
    # Now computing!
    del init_dict
    tag_sharpen_avg = 1000
    ## always apply a low-pass filter to B-factor-enhanced images to suppress noise at high frequencies
    enforced_to_H1 = False
    if B_enhance:
        if Tracker["constants"]["low_pass_filter"] == -1.0:
            enforced_to_H1 = True

    # distribute workload among mpi processes
    image_start, image_end = sp_applications.MPI_start_end(
        navg, Blockdata["nproc"], Blockdata["myid"])

    if Blockdata["myid"] == Blockdata["main_node"]:
        cpu_dict = {}
        for iproc in range(Blockdata["nproc"]):
            local_image_start, local_image_end = sp_applications.MPI_start_end(
                navg, Blockdata["nproc"], iproc)
            for im in range(local_image_start, local_image_end):
                cpu_dict[im] = iproc
    else:
        cpu_dict = 0

    cpu_dict = sp_utilities.wrap_mpi_bcast(cpu_dict,
                                           Blockdata["main_node"],
                                           communicator=mpi.MPI_COMM_WORLD)

    slist = [None for im in range(navg)]
    ini_list = [None for im in range(navg)]
    avg1_list = [None for im in range(navg)]
    avg2_list = [None for im in range(navg)]
    data_list = [None for im in range(navg)]
    plist_dict = {}

    if Blockdata["myid"] == Blockdata["main_node"]:
        if B_enhance:
            sp_global_def.sxprint(
                "Avg ID   B-factor  FH1(Res before ali) FH2(Res after ali)")
        else:
            sp_global_def.sxprint(
                "Avg ID   FH1(Res before ali)  FH2(Res after ali)")

    FH_list = [[0, 0.0, 0.0] for im in range(navg)]
    for iavg in range(image_start, image_end):

        mlist = EMAN2_cppwrap.EMData.read_images(
            Tracker["constants"]["orgstack"], list_dict[iavg])

        for im in range(len(mlist)):
            sp_utilities.set_params2D(mlist[im],
                                      params_dict[iavg][im],
                                      xform="xform.align2d")

        if options.local_alignment:
            new_avg, plist, FH2 = sp_applications.refinement_2d_local(
                mlist,
                Tracker["constants"]["radius"],
                a_range,
                x_range,
                y_range,
                CTF=do_ctf,
                SNR=1.0e10,
            )
            plist_dict[iavg] = plist
            FH1 = -1.0

        else:
            new_avg, frc, plist = compute_average(
                mlist, Tracker["constants"]["radius"], do_ctf)
            FH1 = get_optimistic_res(frc)
            FH2 = -1.0

        FH_list[iavg] = [iavg, FH1, FH2]

        if B_enhance:
            new_avg, gb = apply_enhancement(
                new_avg,
                Tracker["constants"]["B_start"],
                Tracker["constants"]["pixel_size"],
                Tracker["constants"]["Bfactor"],
            )
            sp_global_def.sxprint("  %6d      %6.3f  %4.3f  %4.3f" %
                                  (iavg, gb, FH1, FH2))

        elif adjust_to_given_pw2:
            roo = sp_utilities.read_text_file(Tracker["constants"]["modelpw"],
                                              -1)
            roo = roo[0]  # always on the first column
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         roo)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f  " %
                                  (iavg, FH1, FH2))

        elif adjust_to_analytic_model:
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         None)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f   " %
                                  (iavg, FH1, FH2))

        elif no_adjustment:
            pass

        if Tracker["constants"]["low_pass_filter"] != -1.0:
            if Tracker["constants"]["low_pass_filter"] == 0.0:
                low_pass_filter = FH1
            elif Tracker["constants"]["low_pass_filter"] == 1.0:
                low_pass_filter = FH2
                if not options.local_alignment:
                    low_pass_filter = FH1
            else:
                low_pass_filter = Tracker["constants"]["low_pass_filter"]
                if low_pass_filter >= 0.45:
                    low_pass_filter = 0.45
            new_avg = sp_filter.filt_tanl(new_avg, low_pass_filter, 0.02)
        else:  # no low-pass filter requested, but apply FH1 if it is enforced (B-factor case)
            if enforced_to_H1:
                new_avg = sp_filter.filt_tanl(new_avg, FH1, 0.02)
        if B_enhance:
            new_avg = sp_fundamentals.fft(new_avg)

        new_avg.set_attr("members", list_dict[iavg])
        new_avg.set_attr("n_objects", len(list_dict[iavg]))
        slist[iavg] = new_avg
        sp_global_def.sxprint(
            time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>",
            "Refined average %7d" % iavg,
        )

    ## send to main node to write
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(navg):
        # avg
        if (cpu_dict[im] == Blockdata["myid"]
                and Blockdata["myid"] != Blockdata["main_node"]):
            sp_utilities.send_EMData(slist[im], Blockdata["main_node"],
                                     tag_sharpen_avg)

        elif (cpu_dict[im] == Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            slist[im].set_attr("members", memlist[im])
            slist[im].set_attr("n_objects", len(memlist[im]))
            slist[im].write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        elif (cpu_dict[im] != Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            new_avg_other_cpu = sp_utilities.recv_EMData(
                cpu_dict[im], tag_sharpen_avg)
            new_avg_other_cpu.set_attr("members", memlist[im])
            new_avg_other_cpu.set_attr("n_objects", len(memlist[im]))
            new_avg_other_cpu.write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        if options.local_alignment:
            if cpu_dict[im] == Blockdata["myid"]:
                sp_utilities.write_text_row(
                    plist_dict[im],
                    os.path.join(
                        Tracker["constants"]["masterdir"],
                        "ali2d_local_params_avg",
                        "ali2d_local_params_avg_%03d.txt" % im,
                    ),
                )

            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(plist_dict[im],
                                           Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                plist_dict[im] = dummy
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]
        else:
            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]

        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if options.local_alignment:
        if Blockdata["myid"] == Blockdata["main_node"]:
            ali3d_local_params = [None for im in range(len(ptl_list))]
            for im in range(len(ptl_list)):
                ali3d_local_params[im] = [ptl_list[im]] + plist_dict[
                    global_dict[ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
            sp_utilities.write_text_row(
                ali3d_local_params,
                os.path.join(Tracker["constants"]["masterdir"],
                             "ali2d_local_params.txt"),
            )
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))
    else:
        if Blockdata["myid"] == Blockdata["main_node"]:
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    target_xr = 3
    target_yr = 3
    if Blockdata["myid"] == 0:
        cmd = "{} {} {} {} {} {} {} {} {} {}".format(
            "sp_chains.py",
            os.path.join(Tracker["constants"]["masterdir"],
                         "class_averages.hdf"),
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"),
            os.path.join(Tracker["constants"]["masterdir"],
                         "ordered_class_averages.hdf"),
            "--circular",
            "--radius=%d" % Tracker["constants"]["radius"],
            "--xr=%d" % (target_xr + 1),
            "--yr=%d" % (target_yr + 1),
            "--align",
            ">/dev/null",
        )
        junk = sp_utilities.cmdexecute(cmd)
        cmd = "{} {}".format(
            "rm -rf",
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"))
        junk = sp_utilities.cmdexecute(cmd)

    return
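cpu_dict above maps every average index to the MPI rank that will process it. Below is a minimal, MPI-free sketch of that bookkeeping; mpi_start_end_sketch is a hypothetical stand-in that assumes MPI_start_end hands out contiguous, nearly equal chunks (the actual SPHIRE helper may differ in detail):

def mpi_start_end_sketch(nima, nproc, myid):
    # Hypothetical even, contiguous split of nima items across nproc ranks.
    base, rest = divmod(nima, nproc)
    start = myid * base + min(myid, rest)
    end = start + base + (1 if myid < rest else 0)
    return start, end

navg, nproc = 10, 3
cpu_dict = {}
for iproc in range(nproc):
    lo, hi = mpi_start_end_sketch(navg, nproc, iproc)
    for im in range(lo, hi):
        cpu_dict[im] = iproc

print(cpu_dict)  # {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 2, 8: 2, 9: 2}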
Example #4
def ali2d_single_iter(
    data,
    numr,
    wr,
    cs,
    tavg,
    cnx,
    cny,
    xrng,
    yrng,
    step,
    nomirror=False,
    mode="F",
    CTF=False,
    random_method="",
    T=1.0,
    ali_params="xform.align2d",
    delta=0.0,
):
    """
		single iteration of 2D alignment using ormq
		if CTF = True, apply CTF to data (not to reference!)
	"""

    maxrin = numr[-1]  #  length
    ou = numr[-3]  #  maximum radius
    if random_method == "SCF":
        frotim = [sp_fundamentals.fft(tavg)]
        xrng = int(xrng + 0.5)
        yrng = int(yrng + 0.5)
        cimage = EMAN2_cppwrap.Util.Polar2Dm(sp_fundamentals.scf(tavg), cnx,
                                             cny, numr, mode)
        EMAN2_cppwrap.Util.Frngs(cimage, numr)
        EMAN2_cppwrap.Util.Applyws(cimage, numr, wr)
    else:
        # 2D alignment using rotational ccf in polar coords and quadratic interpolation
        cimage = EMAN2_cppwrap.Util.Polar2Dm(tavg, cnx, cny, numr, mode)
        EMAN2_cppwrap.Util.Frngs(cimage, numr)
        EMAN2_cppwrap.Util.Applyws(cimage, numr, wr)

    sx_sum = 0.0
    sy_sum = 0.0
    sxn = 0.0
    syn = 0.0
    mn = 0
    nope = 0
    mashi = cnx - ou - 2
    for im in range(len(data)):
        if CTF:
            # Apply CTF to image
            ctf_params = data[im].get_attr("ctf")
            ima = sp_filter.filt_ctf(data[im], ctf_params, True)
        else:
            ima = data[im]

        if random_method == "PCP":
            sxi = data[im][0][0].get_attr("sxi")
            syi = data[im][0][0].get_attr("syi")
            nx = ny = data[im][0][0].get_attr("inx")
        else:
            nx = ima.get_xsize()
            ny = ima.get_ysize()
            alpha, sx, sy, mirror, dummy = sp_utilities.get_params2D(
                data[im], ali_params)
            alpha, sx, sy, dummy = sp_utilities.combine_params2(
                alpha, sx, sy, mirror, 0.0, -cs[0], -cs[1], 0)
            alphai, sxi, syi, scalei = sp_utilities.inverse_transform2(
                alpha, sx, sy)
            #  introduce constraints on parameters to accommodate use of cs centering
            sxi = min(max(sxi, -mashi), mashi)
            syi = min(max(syi, -mashi), mashi)

        #  The search-range procedure was adjusted for 3D searches; since the order of operations is inverted in 2D, the ranges have to be swapped
        txrng = search_range(nx, ou, sxi, xrng, "ali2d_single_iter")
        txrng = [txrng[1], txrng[0]]
        tyrng = search_range(ny, ou, syi, yrng, "ali2d_single_iter")
        tyrng = [tyrng[1], tyrng[0]]
        # print im, "B",cnx,sxi,syi,txrng, tyrng
        # align current image to the reference
        if random_method == "SHC":
            """Multiline Comment0"""
            #  For shc combining of shifts is problematic as the image may randomly slide away and never come back.
            #  A possibility would be to reject moves that results in too large departure from the center.
            #  On the other hand, one cannot simply do searches around the proper center all the time,
            #    as if xr is decreased, the image cannot be brought back if the established shifts are further than new range
            olo = EMAN2_cppwrap.Util.shc(
                ima,
                [cimage],
                txrng,
                tyrng,
                step,
                -1.0,
                mode,
                numr,
                cnx + sxi,
                cny + syi,
                "c1",
            )
            ##olo = Util.shc(ima, [cimage], xrng, yrng, step, -1.0, mode, numr, cnx, cny, "c1")
            if data[im].get_attr("previousmax") < olo[5]:
                # [angt, sxst, syst, mirrort, peakt] = ormq(ima, cimage, xrng, yrng, step, mode, numr, cnx+sxi, cny+syi, delta)
                # print  angt, sxst, syst, mirrort, peakt,olo
                angt = olo[0]
                sxst = olo[1]
                syst = olo[2]
                mirrort = int(olo[3])
                # combine parameters and set them to the header, ignore previous angle and mirror
                [alphan, sxn, syn,
                 mn] = sp_utilities.combine_params2(0.0, -sxi, -syi, 0, angt,
                                                    sxst, syst, mirrort)
                sp_utilities.set_params2D(data[im],
                                          [alphan, sxn, syn, mn, 1.0],
                                          ali_params)
                ##set_params2D(data[im], [angt, sxst, syst, mirrort, 1.0], ali_params)
                data[im].set_attr("previousmax", olo[5])
            else:
                # Did not find a better peak, but we have to set the shifted parameters, as the average has shifted
                sp_utilities.set_params2D(data[im],
                                          [alpha, sx, sy, mirror, 1.0],
                                          ali_params)
                nope += 1
                mn = 0
                sxn = 0.0
                syn = 0.0
        elif random_method == "SCF":
            sxst, syst, iref, angt, mirrort, totpeak = multalign2d_scf(
                data[im], [cimage], frotim, numr, xrng, yrng, ou=ou)
            [alphan, sxn, syn,
             mn] = sp_utilities.combine_params2(0.0, -sxi, -syi, 0, angt, sxst,
                                                syst, mirrort)
            sp_utilities.set_params2D(data[im], [alphan, sxn, syn, mn, 1.0],
                                      ali_params)
        elif random_method == "PCP":
            [angt, sxst, syst, mirrort,
             peakt] = ormq_fast(data[im], cimage, txrng, tyrng, step, numr,
                                mode, delta)
            sxst = rings[0][0][0].get_attr("sxi")
            syst = rings[0][0][0].get_attr("syi")
            sp_global_def.sxprint(sxst, syst, sx, sy)
            dummy, sxs, sys, dummy = sp_utilities.inverse_transform2(
                -angt, sx + sxst, sy + syst)
            sp_utilities.set_params2D(data[im][0][0],
                                      [angt, sxs, sys, mirrort, 1.0],
                                      ali_params)
        else:
            if nomirror:
                [angt, sxst, syst, mirrort,
                 peakt] = ornq(ima, cimage, txrng, tyrng, step, mode, numr,
                               cnx + sxi, cny + syi)
            else:
                [angt, sxst, syst, mirrort, peakt] = ormq(
                    ima,
                    cimage,
                    txrng,
                    tyrng,
                    step,
                    mode,
                    numr,
                    cnx + sxi,
                    cny + syi,
                    delta,
                )
            # combine parameters and set them to the header, ignore previous angle and mirror
            [alphan, sxn, syn,
             mn] = sp_utilities.combine_params2(0.0, -sxi, -syi, 0, angt, sxst,
                                                syst, mirrort)
            sp_utilities.set_params2D(data[im], [alphan, sxn, syn, mn, 1.0],
                                      ali_params)

        if mn == 0:
            sx_sum += sxn
        else:
            sx_sum -= sxn
        sy_sum += syn

    return sx_sum, sy_sum, nope
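Two details of the per-image loop above are easy to miss: incoming shifts are clamped to +/-mashi so the polar-transform origin stays inside the frame, and mirrored images (mn == 1) contribute their x shift to sx_sum with the sign flipped. A standalone sketch with made-up numbers:

def clamp(v, lo, hi):
    # Same constraint as sxi = min(max(sxi, -mashi), mashi) above.
    return min(max(v, lo), hi)

cnx, ou = 38, 29
mashi = cnx - ou - 2                 # largest shift that keeps the rings inside
print(clamp(11.4, -mashi, mashi))    # 7: the shift is truncated to the limit

# Accumulating the average shift: mirrored images (mn == 1) contribute -sxn.
results = [(2.0, 1.0, 0), (1.5, -0.5, 1)]   # hypothetical (sxn, syn, mn) per image
sx_sum = sum(sxn if mn == 0 else -sxn for sxn, syn, mn in results)
sy_sum = sum(syn for _, syn, _ in results)
print(sx_sum, sy_sum)                # 0.5 0.5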
Example #5
def combine_isac_params(isac_dir,
                        classavgstack,
                        chains_params_file,
                        old_combined_parts,
                        classdoc,
                        combined_params_file,
                        log=False,
                        verbose=False):
    """
	Combines initial and all_params from ISAC.
	
	Arguments:
		isac_dir : ISAC directory
		classavgstack : Input image stack
		chains_params_file : Input alignment parameters applied to averages in sp_chains
		old_combined_parts
		classdoc
		combined_params_file : Output combined alignment parameters
		log : instance of Logger class
		verbose : (boolean) Whether to write to screen
	"""

    from sp_utilities import combine_params2, read_text_row

    # File-handling
    init_params_file = os.path.join(isac_dir, "2dalignment",
                                    "initial2Dparams.txt")
    all_params_file = os.path.join(isac_dir, "all_parameters.txt")
    init_params_list = read_text_row(init_params_file)
    all_params = read_text_row(all_params_file)
    isac_shrink_path = os.path.join(isac_dir, "README_shrink_ratio.txt")
    isac_shrink_file = open(isac_shrink_path, "r")
    isac_shrink_lines = isac_shrink_file.readlines()
    isac_shrink_ratio = float(isac_shrink_lines[5])
    """
	Three cases:
		1) Using class_averages.hdf
		2) Using ordered_class_averages.hdf, but chains_params.txt doesn't exist
		3) Using ordered_class_averages.hdf and chains_params.txt 
	"""

    msg = "Combining alignment parameters from %s and %s, dividing by %s, and writing to %s" % \
     (init_params_file, all_params_file, isac_shrink_ratio, combined_params_file)

    # Check if using ordered class averages and whether chains_params exists
    if os.path.basename(classavgstack) == 'ordered_class_averages.hdf':
        if not os.path.exists(chains_params_file):
            msg += "WARNING: '%s' does not exist. " % chains_params_file
            msg += "         Using '%s' but alignment parameters correspond to 'class_averages.hdf'.\n" % classavgstack
        else:
            msg = "Combining alignment parameters from %s, %s, and %s, dividing by %s, and writing to %s" % \
             (init_params_file, all_params_file, chains_params_file, isac_shrink_ratio, combined_params_file)

    print_log_msg(msg, log, verbose)

    if os.path.basename(
            classavgstack) == 'ordered_class_averages.hdf' and os.path.exists(
                chains_params_file):
        chains_params_list = read_text_row(chains_params_file)
        old_combined_list = read_text_row(old_combined_parts)
        num_classes = EMUtil.get_image_count(classavgstack)
        tmp_combined = []

        # Loop through classes
        for class_num in range(num_classes):
            # Extract members
            image = get_im(classavgstack, class_num)
            members = sorted(image.get_attr("members"))
            old_class_list = read_text_row(classdoc.format(class_num))
            new_class_list = []

            # Loop through particles
            for idx, im in enumerate(members):
                tmp_par = combine_params2(
                    init_params_list[im][0], init_params_list[im][1],
                    init_params_list[im][2], init_params_list[im][3],
                    all_params[im][0], all_params[im][1] / isac_shrink_ratio,
                    all_params[im][2] / isac_shrink_ratio, all_params[im][3])

                # Combine with class-average parameters
                P = combine_params2(tmp_par[0], tmp_par[1], tmp_par[2],
                                    tmp_par[3],
                                    chains_params_list[class_num][2],
                                    chains_params_list[class_num][3],
                                    chains_params_list[class_num][4],
                                    chains_params_list[class_num][5])

                tmp_combined.append([im, P[0], P[1], P[2], P[3]])

                # Need to update class number in class docs
                old_part_num = old_class_list[idx]

                try:
                    new_part_num = old_combined_list.index(old_part_num)
                except ValueError:
                    # new_part_num is not defined here, so do not reference it in the message
                    print("Couldn't find particle: class_num %s, old_part_num %s"
                          % (class_num, old_part_num[0]))

                new_class_list.append(new_part_num)
            # End particle-loop

            # Overwrite pre-existing class doc
            write_text_row(new_class_list, classdoc.format(class_num))
        # End class-loop

        # Sort by particle number
        combined_params = sorted(tmp_combined, key=itemgetter(0))

        # Remove first column
        for row in combined_params:
            del row[0]

    # Not applying alignments of ordered_class_averages
    else:
        combined_params = []

        # Loop through images
        for im in range(len(all_params)):
            P = combine_params2(
                init_params_list[im][0], init_params_list[im][1],
                init_params_list[im][2], init_params_list[im][3],
                all_params[im][0], all_params[im][1] / isac_shrink_ratio,
                all_params[im][2] / isac_shrink_ratio, all_params[im][3])
            combined_params.append([P[0], P[1], P[2], P[3], 1.0])

    write_text_row(combined_params, combined_params_file)
    print_log_msg(
        'Wrote %s entries to %s\n' %
        (len(combined_params), combined_params_file), log, verbose)

    return combined_params
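ISAC aligns shrunken images, so the translations in all_parameters.txt are expressed at the reduced pixel size; dividing them by the shrink ratio, as done above, maps them back to the original sampling. A small numeric sketch with made-up values:

# Hypothetical values: ISAC ran at a shrink ratio of 0.25 (4x downsampling).
isac_shrink_ratio = 0.25
sx_shrunken, sy_shrunken = 1.5, -2.0      # shifts as stored in all_parameters.txt

# Rescale to the original pixel size before combining with initial2Dparams.
sx_full = sx_shrunken / isac_shrink_ratio
sy_full = sy_shrunken / isac_shrink_ratio
print(sx_full, sy_full)                   # 6.0 -8.0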
Example #6
def mref_ali2d_MPI(stack,
                   refim,
                   outdir,
                   maskfile=None,
                   ir=1,
                   ou=-1,
                   rs=1,
                   xrng=0,
                   yrng=0,
                   step=1,
                   center=1,
                   maxit=10,
                   CTF=False,
                   snr=1.0,
                   user_func_name="ref_ali2d",
                   rand_seed=1000):
    # 2D multi-reference alignment using rotational ccf in polar coordinates and quadratic interpolation

    from sp_utilities import model_circle, combine_params2, inverse_transform2, drop_image, get_image, get_im
    from sp_utilities import reduce_EMData_to_root, bcast_EMData_to_all, bcast_number_to_all
    from sp_utilities import send_attr_dict
    from sp_utilities import center_2D
    from sp_statistics import fsc_mask
    from sp_alignment import Numrinit, ringwe, search_range
    from sp_fundamentals import rot_shift2D, fshift
    from sp_utilities import get_params2D, set_params2D
    from random import seed, randint
    from sp_morphology import ctf_2
    from sp_filter import filt_btwl, filt_params
    from numpy import reshape, shape
    from sp_utilities import print_msg, print_begin_msg, print_end_msg
    import os
    import sys
    import shutil
    from sp_applications import MPI_start_end
    from mpi import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi import mpi_reduce, mpi_bcast, mpi_barrier, mpi_recv, mpi_send
    from mpi import MPI_SUM, MPI_FLOAT, MPI_INT

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    # create the output directory, if it does not exist

    if os.path.exists(outdir):
        ERROR(
            'Output directory exists, please change the name and restart the program',
            "mref_ali2d_MPI ", 1, myid)
    mpi_barrier(MPI_COMM_WORLD)

    import sp_global_def
    if myid == main_node:
        os.mkdir(outdir)
        sp_global_def.LOGFILE = os.path.join(outdir, sp_global_def.LOGFILE)
        print_begin_msg("mref_ali2d_MPI")

    nima = EMUtil.get_image_count(stack)

    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)

    ima = EMData()
    ima.read_image(stack, image_start)

    first_ring = int(ir)
    last_ring = int(ou)
    rstep = int(rs)
    max_iter = int(maxit)

    if max_iter == 0:
        max_iter = 10
        auto_stop = True
    else:
        auto_stop = False

    if myid == main_node:
        print_msg("Input stack                 : %s\n" % (stack))
        print_msg("Reference stack             : %s\n" % (refim))
        print_msg("Output directory            : %s\n" % (outdir))
        print_msg("Maskfile                    : %s\n" % (maskfile))
        print_msg("Inner radius                : %i\n" % (first_ring))

    nx = ima.get_xsize()
    # default value for the last ring
    if last_ring == -1: last_ring = nx // 2 - 2

    if myid == main_node:
        print_msg("Outer radius                : %i\n" % (last_ring))
        print_msg("Ring step                   : %i\n" % (rstep))
        print_msg("X search range              : %f\n" % (xrng))
        print_msg("Y search range              : %f\n" % (yrng))
        print_msg("Translational step          : %f\n" % (step))
        print_msg("Center type                 : %i\n" % (center))
        print_msg("Maximum iteration           : %i\n" % (max_iter))
        print_msg("CTF correction              : %s\n" % (CTF))
        print_msg("Signal-to-Noise Ratio       : %f\n" % (snr))
        print_msg("Random seed                 : %i\n\n" % (rand_seed))
        print_msg("User function               : %s\n" % (user_func_name))
    import sp_user_functions
    user_func = sp_user_functions.factory[user_func_name]

    if maskfile:
        import types
        if type(maskfile) is bytes: mask = get_image(maskfile)
        else: mask = maskfile
    else: mask = model_circle(last_ring, nx, nx)
    #  references, do them on all processors...
    refi = []
    numref = EMUtil.get_image_count(refim)

    # IMAGES ARE SQUARES! center is in SPIDER convention
    cnx = nx // 2 + 1
    cny = cnx

    mode = "F"
    #precalculate rings
    numr = Numrinit(first_ring, last_ring, rstep, mode)
    wr = ringwe(numr, mode)

    # prepare reference images on all nodes
    ima.to_zero()
    for j in range(numref):
        #  even-image sum, odd-image sum, number of images assigned; after FRC they are combined into the total average
        refi.append([get_im(refim, j), ima.copy(), 0])
    #  for each node read its share of data
    data = EMData.read_images(stack, list(range(image_start, image_end)))
    for im in range(image_start, image_end):
        data[im - image_start].set_attr('ID', im)

    if myid == main_node: seed(rand_seed)

    a0 = -1.0
    again = True
    Iter = 0

    ref_data = [mask, center, None, None]

    while Iter < max_iter and again:
        ringref = []
        mashi = cnx - last_ring - 2
        for j in range(numref):
            refi[j][0].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 1
            })  # normalize reference images to N(0,1)
            cimage = Util.Polar2Dm(refi[j][0], cnx, cny, numr, mode)
            Util.Frngs(cimage, numr)
            Util.Applyws(cimage, numr, wr)
            ringref.append(cimage)
            # zero refi
            refi[j][0].to_zero()
            refi[j][1].to_zero()
            refi[j][2] = 0

        assign = [[] for i in range(numref)]
        # begin MPI section
        for im in range(image_start, image_end):
            alpha, sx, sy, mirror, scale = get_params2D(data[im - image_start])
            #  Why inverse?  07/11/2015 PAP
            alphai, sxi, syi, scalei = inverse_transform2(alpha, sx, sy)
            # normalize
            data[im - image_start].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 0
            })  # subtract average under the mask
            # If shifts are outside of the permissible range, reset them
            if (abs(sxi) > mashi or abs(syi) > mashi):
                sxi = 0.0
                syi = 0.0
                set_params2D(data[im - image_start], [0.0, 0.0, 0.0, 0, 1.0])
            ny = nx
            txrng = search_range(nx, last_ring, sxi, xrng, "mref_ali2d_MPI")
            txrng = [txrng[1], txrng[0]]
            tyrng = search_range(ny, last_ring, syi, yrng, "mref_ali2d_MPI")
            tyrng = [tyrng[1], tyrng[0]]
            # align current image to the reference
            [angt, sxst, syst, mirrort, xiref,
             peakt] = Util.multiref_polar_ali_2d(data[im - image_start],
                                                 ringref, txrng, tyrng, step,
                                                 mode, numr, cnx + sxi,
                                                 cny + syi)

            iref = int(xiref)
            # combine parameters and set them to the header, ignore previous angle and mirror
            [alphan, sxn, syn,
             mn] = combine_params2(0.0, -sxi, -syi, 0, angt, sxst, syst,
                                   (int)(mirrort))
            set_params2D(data[im - image_start],
                         [alphan, sxn, syn, int(mn), scale])
            data[im - image_start].set_attr('assign', iref)
            # apply current parameters and add to the average
            temp = rot_shift2D(data[im - image_start], alphan, sxn, syn, mn)
            it = im % 2
            Util.add_img(refi[iref][it], temp)
            assign[iref].append(im)
            #assign[im] = iref
            refi[iref][2] += 1.0
        del ringref
        # end MPI section, bring partial things together, calculate new reference images, broadcast them back

        for j in range(numref):
            reduce_EMData_to_root(refi[j][0], myid, main_node)
            reduce_EMData_to_root(refi[j][1], myid, main_node)
            refi[j][2] = mpi_reduce(refi[j][2], 1, MPI_FLOAT, MPI_SUM,
                                    main_node, MPI_COMM_WORLD)
            if (myid == main_node): refi[j][2] = int(refi[j][2][0])
        # gather assignments
        for j in range(numref):
            if myid == main_node:
                for n in range(number_of_proc):
                    if n != main_node:
                        import sp_global_def
                        ln = mpi_recv(1, MPI_INT, n,
                                      sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                      MPI_COMM_WORLD)
                        lis = mpi_recv(ln[0], MPI_INT, n,
                                       sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                       MPI_COMM_WORLD)
                        for l in range(ln[0]):
                            assign[j].append(int(lis[l]))
            else:
                import sp_global_def
                mpi_send(len(assign[j]), 1, MPI_INT, main_node,
                         sp_global_def.SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                mpi_send(assign[j], len(assign[j]), MPI_INT, main_node,
                         sp_global_def.SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)

        if myid == main_node:
            # replace the name of the stack with reference with the current one
            refim = os.path.join(outdir, "aqm%03d.hdf" % Iter)
            a1 = 0.0
            ave_fsc = []
            for j in range(numref):
                if refi[j][2] < 4:
                    #ERROR("One of the references vanished","mref_ali2d_MPI",1)
                    #  if vanished, put a random image (only from main node!) there
                    assign[j] = []
                    assign[j].append(
                        randint(image_start, image_end - 1) - image_start)
                    refi[j][0] = data[assign[j][0]].copy()
                    #print 'ERROR', j
                else:
                    #frsc = fsc_mask(refi[j][0], refi[j][1], mask, 1.0, os.path.join(outdir,"drm%03d%04d"%(Iter, j)))
                    from sp_statistics import fsc
                    frsc = fsc(
                        refi[j][0], refi[j][1], 1.0,
                        os.path.join(outdir, "drm%03d%04d.txt" % (Iter, j)))
                    Util.add_img(refi[j][0], refi[j][1])
                    Util.mul_scalar(refi[j][0], 1.0 / float(refi[j][2]))

                    if ave_fsc == []:
                        for i in range(len(frsc[1])):
                            ave_fsc.append(frsc[1][i])
                        c_fsc = 1
                    else:
                        for i in range(len(frsc[1])):
                            ave_fsc[i] += frsc[1][i]
                        c_fsc += 1
                    #print 'OK', j, len(frsc[1]), frsc[1][0:5], ave_fsc[0:5]

            #print 'sum', sum(ave_fsc)
            if sum(ave_fsc) != 0:
                for i in range(len(ave_fsc)):
                    ave_fsc[i] /= float(c_fsc)
                    frsc[1][i] = ave_fsc[i]

            for j in range(numref):
                ref_data[2] = refi[j][0]
                ref_data[3] = frsc
                refi[j][0], cs = user_func(ref_data)

                # write the current average
                TMP = []
                for i_tmp in range(len(assign[j])):
                    TMP.append(float(assign[j][i_tmp]))
                TMP.sort()
                refi[j][0].set_attr_dict({'ave_n': refi[j][2], 'members': TMP})
                del TMP
                refi[j][0].process_inplace("normalize.mask", {
                    "mask": mask,
                    "no_sigma": 1
                })
                refi[j][0].write_image(refim, j)

            Iter += 1
            msg = "ITERATION #%3d        %d\n\n" % (Iter, again)
            print_msg(msg)
            for j in range(numref):
                msg = "   group #%3d   number of particles = %7d\n" % (
                    j, refi[j][2])
                print_msg(msg)
        Iter = bcast_number_to_all(Iter, main_node)  # need to tell all
        if again:
            for j in range(numref):
                bcast_EMData_to_all(refi[j][0], myid, main_node)

    #  clean up
    del assign
    # write out headers and STOP; under MPI, writing has to be done sequentially (time-consuming)
    mpi_barrier(MPI_COMM_WORLD)
    if CTF and data_had_ctf == 0:
        for im in range(len(data)):
            data[im].set_attr('ctf_applied', 0)
    par_str = ['xform.align2d', 'assign', 'ID']
    if myid == main_node:
        from sp_utilities import file_type
        if (file_type(stack) == "bdb"):
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)
    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("mref_ali2d_MPI")
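
The worker/main-node exchange in mref_ali2d_MPI above follows a simple length-then-payload pattern: each worker first sends the length of its assignment list, then the list itself, and the main node concatenates what it receives. The sketch below restates that pattern in isolation; it assumes the same EMAN2/SPARX mpi wrappers (mpi_send/mpi_recv, MPI_INT, MPI_COMM_WORLD) and the SPARX_MPI_TAG_UNIVERSAL tag used in the listing, and the helper name gather_int_lists is illustrative only, not part of the original code.

# Sketch only: assumes the EMAN2/SPARX mpi wrappers; the helper name is hypothetical.
from mpi import mpi_send, mpi_recv, MPI_INT, MPI_COMM_WORLD
import sp_global_def

def gather_int_lists(my_list, myid, main_node, number_of_proc):
    # every worker sends its integer list to main_node, which returns the
    # concatenation of all lists; workers return their own list unchanged
    tag = sp_global_def.SPARX_MPI_TAG_UNIVERSAL
    if myid == main_node:
        merged = list(my_list)
        for n in range(number_of_proc):
            if n == main_node:
                continue
            ln = mpi_recv(1, MPI_INT, n, tag, MPI_COMM_WORLD)            # length first
            lis = mpi_recv(int(ln[0]), MPI_INT, n, tag, MPI_COMM_WORLD)  # then the payload
            merged += [int(lis[l]) for l in range(int(ln[0]))]
        return merged
    else:
        mpi_send(len(my_list), 1, MPI_INT, main_node, tag, MPI_COMM_WORLD)
        mpi_send(my_list, len(my_list), MPI_INT, main_node, tag, MPI_COMM_WORLD)
        return my_list
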
Example #7
0
def mref_ali2d(stack,
               refim,
               outdir,
               maskfile=None,
               ir=1,
               ou=-1,
               rs=1,
               xrng=0,
               yrng=0,
               step=1,
               center=1,
               maxit=0,
               CTF=False,
               snr=1.0,
               user_func_name="ref_ali2d",
               rand_seed=1000,
               MPI=False):
    """
        Name
            mref_ali2d - Perform 2-D multi-reference alignment of an image series
        Input
            stack: set of 2-D images in a stack file; images have to be square
            refim: set of initial reference 2-D images in a stack file
            maskfile: optional mask to be used during the alignment
            ir: inner radius for rotational correlation > 0
            ou: outer radius for rotational correlation < nx/2-1
            rs: step between rings in rotational correlation > 0
            xrng: range for translation search in x direction, search is +/-xrng
            yrng: range for translation search in y direction, search is +/-yrng
            step: step of translation search in both directions
            center: center the average
            maxit: maximum number of iterations the program will perform
            CTF: if this flag is set, the program will use CTF information provided in file headers
            snr: signal-to-noise ratio of the data
            user_func_name: name of the user function applied to the reference averages
            rand_seed: the seed used for generating random numbers
            MPI: whether to use the MPI version
        Output
            outdir: directory into which the output files will be written
            header: the alignment parameters are stored in the headers of the input files as 'xform.align2d'
    """
    # 2D multi-reference alignment using rotational ccf in polar coordinates and quadratic interpolation
    if MPI:
        mref_ali2d_MPI(stack, refim, outdir, maskfile, ir, ou, rs, xrng, yrng,
                       step, center, maxit, CTF, snr, user_func_name,
                       rand_seed)
        return

    from sp_utilities import model_circle, combine_params2, inverse_transform2, drop_image, get_image
    from sp_utilities import center_2D, get_im, get_params2D, set_params2D
    from sp_statistics import fsc
    from sp_alignment import Numrinit, ringwe, fine_2D_refinement, search_range
    from sp_fundamentals import rot_shift2D, fshift
    from random import seed, randint
    import os
    import sys

    from sp_utilities import print_begin_msg, print_end_msg, print_msg
    import shutil

    # create the output directory, if it does not exist
    if os.path.exists(outdir):
        shutil.rmtree(
            outdir
        )  #ERROR('Output directory exists, please change the name and restart the program', "mref_ali2d", 1)
    os.mkdir(outdir)
    import sp_global_def
    sp_global_def.LOGFILE = os.path.join(outdir, sp_global_def.LOGFILE)

    first_ring = int(ir)
    last_ring = int(ou)
    rstep = int(rs)
    max_iter = int(maxit)

    print_begin_msg("mref_ali2d")

    print_msg("Input stack                 : %s\n" % (stack))
    print_msg("Reference stack             : %s\n" % (refim))
    print_msg("Output directory            : %s\n" % (outdir))
    print_msg("Maskfile                    : %s\n" % (maskfile))
    print_msg("Inner radius                : %i\n" % (first_ring))

    ima = EMData()
    ima.read_image(stack, 0)
    nx = ima.get_xsize()
    # default value for the last ring
    if last_ring == -1: last_ring = nx // 2 - 2

    print_msg("Outer radius                : %i\n" % (last_ring))
    print_msg("Ring step                   : %i\n" % (rstep))
    print_msg("X search range              : %i\n" % (xrng))
    print_msg("Y search range              : %i\n" % (yrng))
    print_msg("Translational step          : %i\n" % (step))
    print_msg("Center type                 : %i\n" % (center))
    print_msg("Maximum iteration           : %i\n" % (max_iter))
    print_msg("CTF correction              : %s\n" % (CTF))
    print_msg("Signal-to-Noise Ratio       : %f\n" % (snr))
    print_msg("Random seed                 : %i\n\n" % (rand_seed))
    print_msg("User function               : %s\n" % (user_func_name))
    output = sys.stdout

    import sp_user_functions
    user_func = sp_user_functions.factory[user_func_name]

    if maskfile:
        # a mask can be given either as a file name or as an image object
        if isinstance(maskfile, str): mask = get_image(maskfile)
        else: mask = maskfile
    else: mask = model_circle(last_ring, nx, nx)
    #  references
    refi = []
    numref = EMUtil.get_image_count(refim)

    # IMAGES ARE SQUARES! center is in SPIDER convention
    cnx = nx // 2 + 1
    cny = cnx

    mode = "F"
    #precalculate rings
    numr = Numrinit(first_ring, last_ring, rstep, mode)
    wr = ringwe(numr, mode)
    # reference images
    params = []
    #read all data
    data = EMData.read_images(stack)
    nima = len(data)
    # prepare the reference
    ima.to_zero()
    for j in range(numref):
        temp = EMData()
        temp.read_image(refim, j)
        #  even-half average, odd-half average, number of images assigned; after the FRC the total average is kept in the first slot
        refi.append([temp, ima.copy(), 0])

    seed(rand_seed)
    again = True

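    # layout expected by the user function: [mask, center flag, current average (set later), FSC curve (set later)]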
    ref_data = [mask, center, None, None]

    Iter = 0

    while Iter < max_iter and again:
        ringref = []
        #print "numref",numref

        ### Reference ###
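        # mashi is the largest shift that keeps the outermost alignment ring inside the image frame;
        # each reference is normalized under the mask, resampled onto polar rings and ring-weighted for the rotational CCF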
        mashi = cnx - last_ring - 2
        for j in range(numref):
            refi[j][0].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 1
            })
            cimage = Util.Polar2Dm(refi[j][0], cnx, cny, numr, mode)
            Util.Frngs(cimage, numr)
            Util.Applyws(cimage, numr, wr)
            ringref.append(cimage)

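        # per-reference particle lists and shift accumulators (the latter are used only for the optional centering of the averages)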
        assign = [[] for i in range(numref)]
        sx_sum = [0.0] * numref
        sy_sum = [0.0] * numref
        for im in range(nima):
            alpha, sx, sy, mirror, scale = get_params2D(data[im])
            #  Why inverse?  07/11/2015  PAP
            alphai, sxi, syi, scalei = inverse_transform2(alpha, sx, sy)
            # normalize
            data[im].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 0
            })
            # If shifts are outside of the permissible range, reset them
            if (abs(sxi) > mashi or abs(syi) > mashi):
                sxi = 0.0
                syi = 0.0
                set_params2D(data[im], [0.0, 0.0, 0.0, 0, 1.0])
            ny = nx
            txrng = search_range(nx, last_ring, sxi, xrng, "mref_ali2d")
            txrng = [txrng[1], txrng[0]]
            tyrng = search_range(ny, last_ring, syi, yrng, "mref_ali2d")
            tyrng = [tyrng[1], tyrng[0]]
            # align current image to the reference
            #[angt, sxst, syst, mirrort, xiref, peakt] = Util.multiref_polar_ali_2d_p(data[im],
            #    ringref, txrng, tyrng, step, mode, numr, cnx+sxi, cny+syi)
            #print(angt, sxst, syst, mirrort, xiref, peakt)
            [angt, sxst, syst, mirrort, xiref,
             peakt] = Util.multiref_polar_ali_2d(data[im], ringref, txrng,
                                                 tyrng, step, mode, numr,
                                                 cnx + sxi, cny + syi)

            iref = int(xiref)
            # combine parameters and set them to the header, ignore previous angle and mirror
            [alphan, sxn, syn, mn] = combine_params2(0.0, -sxi, -syi, 0, angt,
                                                     sxst, syst, int(mirrort))
            set_params2D(data[im], [alphan, sxn, syn, int(mn), scale])
            if mn == 0: sx_sum[iref] += sxn
            else: sx_sum[iref] -= sxn
            sy_sum[iref] += syn
            data[im].set_attr('assign', iref)
            # apply current parameters and add to the average
            temp = rot_shift2D(data[im], alphan, sxn, syn, mn)
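            # interleave images into even/odd half-sets; refi[iref][0] and refi[iref][1] are later compared by FSC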
            it = im % 2
            Util.add_img(refi[iref][it], temp)

            assign[iref].append(im)
            refi[iref][2] += 1
        del ringref
        if again:
            a1 = 0.0
            for j in range(numref):
                msg = "   group #%3d   number of particles = %7d\n" % (
                    j, refi[j][2])
                print_msg(msg)
                if refi[j][2] < 4:
                    #ERROR("One of the references vanished","mref_ali2d",1)
                    #  if vanished, put a random image there
                    assign[j] = []
                    assign[j].append(randint(0, nima - 1))
                    refi[j][0] = data[assign[j][0]].copy()
                else:
                    max_inter = 0  # switch off fine refinement
                    br = 1.75
                    #  the loop has to run at least once
                    for INter in range(max_inter + 1):
                        # Calculate averages at least once, even if no within-group refinement was requested
                        frsc = fsc(
                            refi[j][0], refi[j][1], 1.0,
                            os.path.join(outdir,
                                         "drm_%03d_%04d.txt" % (Iter, j)))
                        Util.add_img(refi[j][0], refi[j][1])
                        Util.mul_scalar(refi[j][0], 1.0 / float(refi[j][2]))

                        ref_data[2] = refi[j][0]
                        ref_data[3] = frsc
                        refi[j][0], cs = user_func(ref_data)
                        if center == -1:
                            cs[0] = sx_sum[j] / len(assign[j])
                            cs[1] = sy_sum[j] / len(assign[j])
                            refi[j][0] = fshift(refi[j][0], -cs[0], -cs[1])
                        for i in range(len(assign[j])):
                            im = assign[j][i]
                            alpha, sx, sy, mirror, scale = get_params2D(
                                data[im])
                            alphan, sxn, syn, mirrorn = combine_params2(
                                alpha, sx, sy, mirror, 0.0, -cs[0], -cs[1], 0)
                            set_params2D(
                                data[im],
                                [alphan, sxn, syn,
                                 int(mirrorn), scale])
                        # refine images within the group
                        #  Do the refinement only if max_inter>0, but skip it for the last iteration.
                        if INter < max_inter:
                            fine_2D_refinement(data, br, mask, refi[j][0], j)
                            #  Calculate updated average
                            refi[j][0].to_zero()
                            refi[j][1].to_zero()
                            for i in range(len(assign[j])):
                                im = assign[j][i]
                                alpha, sx, sy, mirror, scale = get_params2D(
                                    data[im])
                                # apply current parameters and add to the average
                                temp = rot_shift2D(data[im], alpha, sx, sy, mirror)
                                it = im % 2
                                Util.add_img(refi[j][it], temp)
                # write the current average
                TMP = []
                for i_tmp in range(len(assign[j])):
                    TMP.append(float(assign[j][i_tmp]))
                TMP.sort()
                refi[j][0].set_attr_dict({'ave_n': refi[j][2], 'members': TMP})
                del TMP
                # replace the name of the stack with reference with the current one
                newrefim = os.path.join(outdir, "aqm%03d.hdf" % Iter)
                refi[j][0].write_image(newrefim, j)
            Iter += 1
            msg = "ITERATION #%3d        \n" % (Iter)
            print_msg(msg)

    newrefim = os.path.join(outdir, "multi_ref.hdf")
    for j in range(numref):
        refi[j][0].write_image(newrefim, j)
    from sp_utilities import write_headers
    write_headers(stack, data, list(range(nima)))
    print_end_msg("mref_ali2d")
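
A minimal call sketch for the non-MPI path follows; the stack, reference and output names are placeholders, and maxit must be set above its default of 0 for the iteration loop to run at all.

# Hypothetical usage sketch; "particles.hdf", "refs.hdf" and "mref_out" are placeholder names.
if __name__ == "__main__":
    mref_ali2d("particles.hdf", "refs.hdf", "mref_out",
               maskfile=None, ir=1, ou=-1, rs=1,
               xrng=2, yrng=2, step=1,
               center=1, maxit=30, CTF=False, snr=1.0,
               user_func_name="ref_ali2d", rand_seed=1000, MPI=False)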