Ejemplo n.º 1
0
def filterlocal(ui, vi, m, falloff, myid, main_node, number_of_proc):
    """Filter a volume locally according to a local-resolution map (MPI version).

    For every two-digit resolution value found in ``ui`` inside the mask ``m``,
    the whole volume ``vi`` is filtered with ``filt_tanl`` at that cutoff (and
    the given falloff).  The filtered values are copied into the output only at
    masked voxels whose rounded resolution equals that cutoff.  The z-slices
    are split round-robin among the MPI ranks, and the partial volumes are
    summed onto ``main_node`` with ``reduce_EMData_to_root``.

    Args:
        ui: local-resolution volume.  Only the copy on ``main_node`` is used;
            there its values are rounded to two digits in place before being
            broadcast.  Other ranks get a fresh volume of the broadcast size.
        vi: volume to be filtered.  Only the copy on ``main_node`` is used.
            It is passed through ``fftip`` (presumably an in-place FFT).
        m: mask volume, expected to be present on every rank; voxels with
            value > 0.5 are filtered.
        falloff: falloff of the tanl filter; the value on ``main_node`` is
            broadcast to all ranks.
        myid: rank of this process.
        main_node: rank holding the input volumes and receiving the result.
        number_of_proc: number of MPI ranks.

    Returns:
        The filtered volume; the complete sum is assembled on ``main_node``.
    """
    if myid == main_node:

        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()
        #  Round all resolution numbers to two digits so that they compare
        #  equal to the two-digit cutoffs generated in the loop below.
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
        dis = [nx, ny, nz]
    else:
        falloff = 0.0
        dis = [0, 0, 0]
    falloff = sp_utilities.bcast_number_to_all(falloff, main_node)
    dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

    if myid != main_node:
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])

        # Placeholders of the right size that receive the broadcast data.
        vi = sp_utilities.model_blank(nx, ny, nz)
        ui = sp_utilities.model_blank(nx, ny, nz)

    sp_utilities.bcast_EMData_to_all(vi, myid, main_node)
    sp_utilities.bcast_EMData_to_all(ui, myid, main_node)

    sp_fundamentals.fftip(vi)  #  volume to be filtered

    # st[2] / st[3] are used as the min / max of ui inside the mask.
    st = EMAN2_cppwrap.Util.infomask(ui, m, True)

    filteredvol = sp_utilities.model_blank(nx, ny, nz)
    cutoff = max(st[2] - 0.01, 0.0)
    while cutoff < st[3]:
        cutoff = round(cutoff + 0.01, 2)
        # Skip the (expensive) filtering when no masked voxel has this value.
        pt = EMAN2_cppwrap.Util.infomask(
            sp_morphology.threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005),
            m,
            True,
        )  # Ideally, one would want to check only slices in question...
        if pt[0] != 0.0:
            vovo = sp_fundamentals.fft(filt_tanl(vi, cutoff, falloff))
            # Each rank fills only its own z-slices; summed on main_node below.
            for z in range(myid, nz, number_of_proc):
                for x in range(nx):
                    for y in range(ny):
                        if m.get_value_at(x, y, z) > 0.5:
                            if round(ui.get_value_at(x, y, z), 2) == cutoff:
                                filteredvol.set_value_at_fast(
                                    x, y, z, vovo.get_value_at(x, y, z)
                                )

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    sp_utilities.reduce_EMData_to_root(filteredvol, myid, main_node, mpi.MPI_COMM_WORLD)
    return filteredvol
Ejemplo n.º 2
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """Refine x-shifts of helical segments filament by filament (MPI version).

    Segments are grouped into filaments with ``ordersegments`` (using the
    ``filament`` and ``ptcl_source_coord`` header attributes).  Whole
    filaments are then distributed over the MPI ranks with
    ``chunks_distribution``.

    In every iteration the currently shifted images are summed into a global
    average.  Each segment is cross-correlated with it, and for each filament
    ``Util.helixshiftali`` returns a shift ``sib`` for the central segment and
    a slope ``bang``.  Every segment's x-shift grows linearly with its
    distance from the central segment.  At the end the x-shifts are composed
    with the stored ``xform.align2d`` and written back to the headers of
    ``stack``.

    Args:
        stack: image stack holding the segments (bdb or file).
        maskfile: optional mask file.  By default
            ``pad(model_blank(2*mrad+1, ny, 1, 1.0), nx, ny, 1, 0.0)`` with
            ``mrad = min(nx, ny)//2 - 2`` is used.
        maxit: number of iterations; all are executed (there is no early stop).
        CTF: if True, CTF-multiply the images and Wiener-filter the average.
        snr: signal-to-noise ratio; ``1/snr`` is added to the CTF^2 sum.
        Fourvar: if True, divide the average by the Fourier variance.
        search_rng: search range passed to ``Util.helixshiftali``.

    Side effects:
        On the main node, writes the masked average of each iteration to
        ``tavg.hdf``.  Updates ``xform.align2d`` and ``ID`` in the headers of
        ``stack``.
    """
    # Imported before its first use: a function-local import of file_type
    # placed further down would make the name local to the whole function and
    # the first call would raise UnboundLocalError.
    from sp_utilities import file_type

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [len(fil) for fil in filaments]
        linidl = sum(inidl)
        tfilaments = []
        for fil in filaments:
            tfilaments += fil
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    # The filament table is broadcast in flattened form: per-filament lengths
    # (inidl) plus the concatenated particle ids (tfilaments).
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    nima = linidl  # total number of segments over all filaments
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    #  balanced load: whole filaments are distributed, weighted by length
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # Flatten this rank's filaments; indcs[f] = [start, end) of filament f
    # within data.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # Subtract the masked mean and move every image to Fourier space
    # (CTF-multiplied first unless the CTF was already applied).
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf == 1:
                # CTF already applied: the image must still be Fourier
                # transformed like all the others, because it is later
                # shifted and added into the complex average.
                data[im] = fft(data[im])
            else:
                ERROR('Incorrectly set qctf flag', myid=myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Wiener term: add the complex value (1/snr, 0) to every element
            # of the summed CTF^2.
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Global average of the currently shifted images (Fourier space).
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # Diagnostic output: one average per iteration.
            tavg.write_image("tavg.hdf", Iter)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            seg_start, seg_end = indcs[ifil]
            # Calculate 1D ccf between each segment and the global average
            nsegms = seg_end - seg_start
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(seg_start, seg_end):
                ctx[im - seg_start] = Util.window(ccf(tavg, data[im]), nx, 1)
                pcoords[im - seg_start] = data[im].get_attr('ptcl_source_coord')
            # search for best x-shift
            cents = nsegms // 2

            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # ## C code for alignment. @ming
            # Util.helixshiftali presumably ports the Python reference search
            # kept below as comments:
            # qm = -1.e23
            #
            # for six in xrange(-search_rng, search_rng+1,1):
            #     q0 = ctx[cents].get_value_at(six+nxc)
            #     for incline in xrange(kang+1):
            #         qt = q0
            #         qu = q0
            #         if(kang>0):  tang = tan(maxincline/kang*incline)
            #         else:        tang = 0.0
            #         for kim in xrange(cents+1,nsegms):
            #             dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
            #             xl = dst*tang+six+nxc
            #             ixl = int(xl)
            #             dxl = xl - ixl
            #             qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            #             xl = -dst*tang+six+nxc
            #             ixl = int(xl)
            #             dxl = xl - ixl
            #             qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            #         for kim in xrange(cents):
            #             dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
            #             xl = -dst*tang+six+nxc
            #             ixl = int(xl)
            #             dxl = xl - ixl
            #             qt += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            #             xl =  dst*tang+six+nxc
            #             ixl = int(xl)
            #             dxl = xl - ixl
            #             qu += (1.0-dxl)*ctx[kim].get_value_at(ixl) + dxl*ctx[kim].get_value_at(ixl+1)
            #         if( qt > qm ):
            #             qm = qt
            #             sib = six
            #             bang = tang
            #         if( qu > qm ):
            #             qm = qu
            #             sib = six
            #             bang = -tang
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]

            # Shift of each segment: sib at the central segment, changing
            # linearly with its distance from it (slope bang, sign by side).
            for im in range(seg_start, seg_end):
                kim = im - seg_start
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if kim < cents:
                    xl = -dst * bang + sib
                else:
                    xl = dst * bang + sib
                shift_x[im] = xl

            # Average shift
            sx_sum += shift_x[seg_start + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM,
                                main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): this discards the average shift computed above, so the
        # subtraction below is a no-op.  Kept as in the original; removing
        # this line would re-enable centring of the shifts.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine t0 and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node:
        print_end_msg("helical-shiftali_MPI")
Ejemplo n.º 3
0
def shiftali_MPI(stack,
                 maskfile=None,
                 maxit=100,
                 CTF=False,
                 snr=1.0,
                 Fourvar=False,
                 search_rng=-1,
                 oneDx=False,
                 search_rng_y=-1):
    """Translationally align a stack against its own average (MPI version).

    Images are split over MPI ranks with ``MPI_start_end``.  Each image is
    first prepared: the stored ``xform.align2d`` is applied, the masked mean
    is subtracted, and the image is Fourier transformed (and CTF-multiplied if
    ``CTF``).

    Each iteration then:
      * forms the average (Wiener-filtered when ``CTF``);
      * finds the integer cross-correlation peak offset of every image against
        that average;
      * re-centres the offsets by their rounded global mean and shifts the
        images accordingly.

    The loop stops early when every image's peak offset is zero.  The
    accumulated shifts are composed with the original parameters and written
    to the headers of ``stack``.

    Args:
        stack: input stack (bdb or file).
        maskfile: optional mask; default ``model_circle(min(nx, ny)//2 - 2, nx, ny)``.
        maxit: maximum number of iterations.
        CTF: if True, CTF-multiply the images and Wiener-filter the average;
            the data must not be CTF-applied already.
        snr: signal-to-noise ratio; ``1/snr`` is added to the CTF^2 sum.
        Fourvar: if True, divide the average by the Fourier variance.
        search_rng: half-width of the x search window (full width if <= 0).
        oneDx: if True, search x-shifts only, on a single-row ccf window.
        search_rng_y: half-width of the y search window (full height if <= 0).
    """
    # Imported before its first use: a function-local import of file_type
    # placed further down would make the name local to the whole function and
    # the first call would raise UnboundLocalError.
    from sp_utilities import file_type

    number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # read nx and ctf_app (if CTF) and broadcast to all nodes
    if myid == main_node:
        ima = EMData()
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF:
            ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF:
            ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", myid=myid)

    if maskfile is None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from sp_global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # Ranks take turns; for bdb stacks a barrier after each turn
        # serialises the reads.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb":
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Wiener term: add 1/snr to the real part of every complex element
            # of the summed CTF^2 (imaginary parts of the fresh image stay 0).
            # The loop bound is taken from the image itself so that the full
            # (padded) Fourier width is covered.
            temp = EMData(nx, ny, 1, False)
            tsnr = 1.0 / snr
            for i in range(0, temp.get_xsize(), 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im],
                               p['alpha'],
                               sx=p['tx'],
                               sy=p['ty'],
                               mirror=p['mirror'],
                               scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum_total = 0
    sy_sum_total = 0
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))

            # normalize and mask tavg in real space, then return to Fourier space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)
            tavg = fft(tavg)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        sx_sum = 0
        sy_sum = 0
        nwx = 2 * search_rng + 1 if search_rng > 0 else nx
        nwy = 2 * search_rng_y + 1 if search_rng_y > 0 else ny

        # not_zero becomes 1 as soon as any local image has a non-zero offset.
        not_zero = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if ishift_x[im] != 0.0 or ishift_y[im] != 0.0:
                not_zero = 1

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node,
                                mpi.MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi.mpi_reduce(sy_sum, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                    main_node, mpi.MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total,
                                               source_node=main_node)

        # Re-centre: remove the (rounded) mean offset before shifting.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y
        # stop if all shifts are zero
        not_zero = mpi.mpi_reduce(not_zero, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                  main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))
            start_time = time()

        if not_zero_all == 0:
            break

    # Only header information is used below, so the images are not
    # transformed back to real space.
    # combine shifts found with the original parameters
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        # combine t0 and t1
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)

    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("shiftali_MPI")
Ejemplo n.º 4
0
def main():
    """Command-line entry point: locally filter a volume by a local-resolution map.

    Positional arguments are ``inputvolume locresvolume [maskfile] outputfile``.
    Without a mask file, ``model_circle`` with radius ``--radius`` is used as
    the mask (default ``min(nx, ny, nz)//2 - 1``).  With ``--MPI`` the work is
    delegated to ``sp_filter.filterlocal``; otherwise the same
    rounding/threshold/filter loop runs serially here.
    """
    arglist = list(sys.argv)
    progname = optparse.os.path.basename(arglist[0])
    usage = progname + """ inputvolume  locresvolume maskfile outputfile   --radius --falloff  --MPI

	    Locally filter a volume based on local resolution volume (sxlocres.py) within area outlined by the maskfile
	"""
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "if there is no maskfile, sphere with r=radius will be used, by default the radius is nx/2-1"
    )
    parser.add_option("--falloff",
                      type="float",
                      default=0.1,
                      help="falloff of tanl filter (default 0.1)")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        sp_global_def.sxprint("See usage " + usage)
        sp_global_def.ERROR(
            "Wrong number of parameters. Please see usage information above.")
        return

    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    if options.MPI:
        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        main_node = 0

        if myid == main_node:
            # Use the parsed positional arguments rather than sys.argv:
            # options may precede the file names on the command line.
            vi = sp_utilities.get_im(args[0])
            ui = sp_utilities.get_im(args[1])
            radius = options.radius
            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            radius = 0
            dis = [0, 0, 0]
            # filterlocal only reads the volumes on main_node.
            vi = None
            ui = None
        dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])
        radius = sp_utilities.bcast_number_to_all(radius, main_node)
        if len(args) == 3:
            if radius == -1:
                radius = min(nx, ny, nz) // 2 - 1
            m = sp_utilities.model_circle(radius, nx, ny, nz)
            outvol = args[2]

        elif len(args) == 4:
            if myid == main_node:
                m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            else:
                m = sp_utilities.model_blank(nx, ny, nz)
            outvol = args[3]
            sp_utilities.bcast_EMData_to_all(m, myid, main_node)

        filteredvol = sp_filter.filterlocal(ui, vi, m, options.falloff, myid,
                                            main_node, number_of_proc)

        if myid == main_node:
            filteredvol.write_image(outvol)

    else:
        vi = sp_utilities.get_im(args[0])
        # resolution volume, values are assumed to be from 0 to 0.5
        ui = sp_utilities.get_im(args[1])

        # Use all three dimensions (the volume need not be cubic), matching
        # the MPI branch above.
        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()

        falloff = options.falloff

        if len(args) == 3:
            radius = options.radius
            if radius == -1:
                radius = min(nx, ny, nz) // 2 - 1
            m = sp_utilities.model_circle(radius, nx, ny, nz)
            outvol = args[2]

        elif len(args) == 4:
            m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            outvol = args[3]

        sp_fundamentals.fftip(vi)  # this is the volume to be filtered

        #  Round all resolution numbers to two digits so that they compare
        #  equal to the two-digit cutoffs generated below.
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    ui.set_value_at_fast(x, y, z,
                                         round(ui.get_value_at(x, y, z), 2))
        st = EMAN2_cppwrap.Util.infomask(ui, m, True)

        filteredvol = sp_utilities.model_blank(nx, ny, nz)
        cutoff = max(st[2] - 0.01, 0.0)
        while cutoff < st[3]:
            cutoff = round(cutoff + 0.01, 2)
            # Skip the filtering when no masked voxel has this resolution.
            pt = EMAN2_cppwrap.Util.infomask(
                sp_morphology.threshold_outside(ui, cutoff - 0.00501,
                                                cutoff + 0.005), m, True)
            if pt[0] != 0.0:
                vovo = sp_fundamentals.fft(
                    sp_filter.filt_tanl(vi, cutoff, falloff))
                for x in range(nx):
                    for y in range(ny):
                        for z in range(nz):
                            if m.get_value_at(x, y, z) > 0.5:
                                if round(ui.get_value_at(x, y, z),
                                         2) == cutoff:
                                    filteredvol.set_value_at_fast(
                                        x, y, z, vovo.get_value_at(x, y, z))

        sp_global_def.write_command(optparse.os.path.dirname(outvol))
        filteredvol.write_image(outvol)
Ejemplo n.º 5
0
def main():
    """Measure the 2-D alignment stability of projection groups (MPI only).

    Command line: ``proj_stack output_averages --MPI``.

    The projections of ``proj_stack`` are split into groups according to
    ``--grouping``:

    * ``GRP`` - groups are built on the main node by
      ``group_proj_by_phitheta`` and handed out round-robin to the ranks;
      the last (variable-size) groups stay on the main node.
    * ``GEV`` - for every reference direction from ``even_angles(--delta)``
      the ``--img_per_group`` nearest projections (``Util.nearestk_to_refdir``)
      form a group; reference directions are split between ranks with
      ``MPI_start_end``.
    * ``PPR`` - groups of nearest projections obtained from ``nearest_proj``.

    Every group is re-aligned ``--num_ali`` times and
    ``multi_align_stability`` decides which members are stable.  For groups
    with more than five stable members the average of the stable images is
    kept.  Each group image carries the header attributes ``members``,
    ``pixerr`` and ``refprojdir``.  The main node receives the group images
    of all other ranks and writes everything to ``output_averages``.
    """

    progname = os.path.basename(sys.argv[0])
    usage = progname + " proj_stack output_averages --MPI"
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--img_per_group",
                      type="int",
                      default=100,
                      help="number of images per group")
    parser.add_option("--radius",
                      type="int",
                      default=-1,
                      help="radius for alignment")
    parser.add_option(
        "--xr",
        type="string",
        default="2 1",
        help="range for translation search in x direction, search is +/xr")
    parser.add_option(
        "--yr",
        type="string",
        default="-1",
        help=
        "range for translation search in y direction, search is +/yr (default = same as xr)"
    )
    parser.add_option(
        "--ts",
        type="string",
        default="1 0.5",
        help=
        "step size of the translation search in both directions, search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional"
    )
    parser.add_option(
        "--iter",
        type="int",
        default=30,
        help="number of iterations within alignment (default = 30)")
    parser.add_option(
        "--num_ali",
        type="int",
        default=5,
        help="number of alignments performed for stability (default = 5)")
    # The help text used to claim a default of 1.732, which did not match
    # the actual default below.
    parser.add_option("--thld_err",
                      type="float",
                      default=1.0,
                      help="threshold of pixel error (default = 1.0)")
    parser.add_option(
        "--grouping",
        type="string",
        default="GRP",
        help=
        "do grouping of projections: PPR - per projection, GRP - different size groups, exclusive (default), GEV - grouping equal size"
    )
    parser.add_option(
        "--delta",
        type="float",
        default=-1.0,
        help="angular step for reference projections (required for GEV method)"
    )
    parser.add_option(
        "--fl",
        type="float",
        default=0.3,
        help="cut-off frequency of hyperbolic tangent low-pass Fourier filter")
    parser.add_option(
        "--aa",
        type="float",
        default=0.2,
        help="fall-off of hyperbolic tangent low-pass Fourier filter")
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="Consider CTF correction during the alignment ")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args()

    myid = mpi.mpi_comm_rank(MPI_COMM_WORLD)
    number_of_proc = mpi.mpi_comm_size(MPI_COMM_WORLD)
    main_node = 0

    if len(args) == 2:
        stack = args[0]
        outdir = args[1]
    else:
        sp_global_def.ERROR("Incomplete list of arguments",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return
    if not options.MPI:
        sp_global_def.ERROR("Non-MPI not supported!",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return

    if sp_global_def.CACHE_DISABLE:
        from sp_utilities import disable_bdb_cache
        disable_bdb_cache()
    sp_global_def.BATCH = True

    img_per_grp = options.img_per_group
    radius = options.radius
    ite = options.iter
    num_ali = options.num_ali
    thld_err = options.thld_err

    xrng = get_input_from_string(options.xr)
    if options.yr == "-1":
        yrng = xrng
    else:
        yrng = get_input_from_string(options.yr)

    step = get_input_from_string(options.ts)

    # Only the main node reads the stack header; sizes are then broadcast.
    if myid == main_node:
        nima = EMUtil.get_image_count(stack)
        img = get_image(stack)
        nx = img.get_xsize()
        ny = img.get_ysize()
    else:
        nima = 0
        nx = 0
        ny = 0
    nima = bcast_number_to_all(nima)
    nx = bcast_number_to_all(nx)
    ny = bcast_number_to_all(ny)
    # Floor division: under Python 3 "/" would turn the default radius into
    # a float, which then propagates into ``ou`` and the stability diameter.
    if radius == -1:
        radius = int(nx) // 2 - 2
    mask = model_circle(radius, nx, nx)

    st = time()
    if options.grouping == "GRP":
        if myid == main_node:
            sxprint("  A  ", myid, "  ", time() - st)
            proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
            proj_params = []
            for i in range(nima):
                dp = proj_attr[i].get_params("spider")
                phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp[
                    "psi"], -dp["tx"], -dp["ty"]
                proj_params.append([phi, theta, psi, s2x, s2y])

            # group_proj_by_phitheta returns:
            # proj_list  : list of lists of particle numbers; each holds
            #              img_per_grp particles except possibly the last,
            #              which depends on how many particles are left over.
            # angle_list : per group (phi, theta, delta): the projection
            #              direction of the group centre and its range.
            # mirror_list: per group, img_per_grp booleans telling whether
            #              the mirrored position should be taken.
            # Only proj_list is used here.
            proj_list_all, angle_list, mirror_list = group_proj_by_phitheta(
                proj_params, img_per_grp=img_per_grp)
            del proj_params
            sxprint("  B  number of groups  ", myid, "  ", len(proj_list_all),
                    time() - st)
        mpi_barrier(MPI_COMM_WORLD)

        # Number of groups distributed round-robin.  There may be one or two
        # more groups (the size of the remainder varies); those are simply
        # kept on the main node.  Floor division keeps n_grp an int, which
        # range() requires under Python 3.
        n_grp = int(nima) // img_per_grp - 1

        # Distribute proj_list_all between the ranks; each rank keeps its
        # share in proj_list.
        proj_list = []
        for i in range(n_grp):
            proc_to_stay = i % number_of_proc
            if proc_to_stay == main_node:
                if myid == main_node:
                    proj_list.append(proj_list_all[i])
            elif myid == main_node:
                mpi_send(len(proj_list_all[i]), 1, MPI_INT, proc_to_stay,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                mpi_send(proj_list_all[i], len(proj_list_all[i]), MPI_INT,
                         proc_to_stay, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            elif myid == proc_to_stay:
                img_per_grp = mpi_recv(1, MPI_INT, main_node,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                img_per_grp = int(img_per_grp[0])
                temp = mpi_recv(img_per_grp, MPI_INT, main_node,
                                SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                proj_list.append(list(map(int, temp)))
                del temp
            mpi_barrier(MPI_COMM_WORLD)
        sxprint("  C  ", myid, "  ", time() - st)
        if myid == main_node:
            # Assign the remaining groups to main_node
            for i in range(n_grp, len(proj_list_all)):
                proj_list.append(proj_list_all[i])
            del proj_list_all, angle_list, mirror_list

    #   Compute stability per projection direction, equal number assigned, thus overlaps
    elif options.grouping == "GEV":

        if options.delta == -1.0:
            sp_global_def.ERROR(
                "Angular step for reference projections is required for GEV method",
                "sxproj_stability.main",
                1,
                myid=myid)
            return

        from sp_utilities import even_angles, nearestk_to_refdir, getvec
        refproj = even_angles(options.delta)
        img_begin, img_end = MPI_start_end(len(refproj), number_of_proc, myid)
        # Now each processor keeps its own share of reference projections
        refprojdir = refproj[img_begin:img_end]
        del refproj

        # NOTE(review): every entry uses refprojdir[0] (with a small theta
        # offset) rather than refprojdir[i]; kept as-is, confirm intent.
        ref_ang = [0.0] * (len(refprojdir) * 2)
        for i in range(len(refprojdir)):
            ref_ang[i * 2] = refprojdir[0][0]
            ref_ang[i * 2 + 1] = refprojdir[0][1] + i * 0.1

        sxprint("  A  ", myid, "  ", time() - st)
        # Every rank reads all attributes at once.  Reading them one rank at
        # a time (separated by barriers) is very slow; only do that if
        # concurrent I/O turns out to be a problem.
        proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
        sxprint("  B  ", myid, "  ", time() - st)

        proj_ang = [0.0] * (nima * 2)
        for i in range(nima):
            dp = proj_attr[i].get_params("spider")
            proj_ang[i * 2] = dp["phi"]
            proj_ang[i * 2 + 1] = dp["theta"]
        sxprint("  C  ", myid, "  ", time() - st)
        asi = Util.nearestk_to_refdir(proj_ang, ref_ang, img_per_grp)
        del proj_ang, ref_ang
        proj_list = []
        for i in range(len(refprojdir)):
            proj_list.append(asi[i * img_per_grp:(i + 1) * img_per_grp])
        del asi
        sxprint("  D  ", myid, "  ", time() - st)

    #   Compute stability per projection
    elif options.grouping == "PPR":
        sxprint("  A  ", myid, "  ", time() - st)
        proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
        sxprint("  B  ", myid, "  ", time() - st)
        proj_params = []
        for i in range(nima):
            dp = proj_attr[i].get_params("spider")
            phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp[
                "psi"], -dp["tx"], -dp["ty"]
            proj_params.append([phi, theta, psi, s2x, s2y])
        img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
        sxprint("  C  ", myid, "  ", time() - st)
        from sp_utilities import nearest_proj
        # NOTE(review): only the first projection of this rank's share is
        # used; the full share would be range(img_begin, img_end).
        proj_list, mirror_list = nearest_proj(
            proj_params, img_per_grp,
            list(range(img_begin, img_begin + 1)))
        refprojdir = proj_params[img_begin:img_end]
        del proj_params, mirror_list
        sxprint("  D  ", myid, "  ", time() - st)

    else:
        sp_global_def.ERROR("Incorrect projection grouping option",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return

    ###########################################################################################################
    # Begin stability test
    from sp_utilities import get_params_proj, read_text_file

    from sp_utilities import model_blank
    # One independent blank image per group.  ``[img] * n`` would alias a
    # single EMData, so all empty groups would share (and overwrite) the
    # same 'members' / 'refprojdir' / 'pixerr' attributes.
    aveList = [model_blank(nx, ny) for _ in range(len(proj_list))]
    if options.grouping == "GRP":
        refprojdir = [[0.0, 0.0, -1.0]] * len(proj_list)
    for i in range(len(proj_list)):
        sxprint("  E  ", myid, "  ", time() - st)
        class_data = EMData.read_images(stack, proj_list[i])
        if options.CTF:
            from sp_filter import filt_ctf
            for im in range(len(class_data)):  #  MEM LEAK!!
                atemp = class_data[im].copy()
                btemp = filt_ctf(atemp, atemp.get_attr("ctf"), binary=1)
                class_data[im] = btemp
        # Make sure every image has 2-D alignment parameters: keep existing
        # xform.align2d, else derive shifts from xform.projection, else zero.
        for im in class_data:
            try:
                t = im.get_attr(
                    "xform.align2d")  # if they are there, no need to set them!
            except Exception:
                try:
                    t = im.get_attr("xform.projection")
                    d = t.get_params("spider")
                    set_params2D(im, [0.0, -d["tx"], -d["ty"], 0, 1.0])
                except Exception:
                    set_params2D(im, [0.0, 0.0, 0.0, 0, 1.0])
        # Here, we perform realignment num_ali times
        all_ali_params = []
        for j in range(num_ali):
            if (xrng[0] == 0.0 and yrng[0] == 0.0):
                avet = ali2d_ras(class_data,
                                 randomize=True,
                                 ir=1,
                                 ou=radius,
                                 rs=1,
                                 step=1.0,
                                 dst=90.0,
                                 maxit=ite,
                                 check_mirror=True,
                                 FH=options.fl,
                                 FF=options.aa)
            else:
                avet = within_group_refinement(class_data, mask, True, 1,
                                               radius, 1, xrng, yrng, step,
                                               90.0, ite, options.fl,
                                               options.aa)
            ali_params = []
            for im in range(len(class_data)):
                alpha, sx, sy, mirror, scale = get_params2D(class_data[im])
                ali_params.extend([alpha, sx, sy, mirror])
            all_ali_params.append(ali_params)
        del ali_params
        # Determine the stability of this group.
        # stable_set holds the particles deemed stable as a list of lists;
        # s[0] is the pixel error, s[1] the index within this group and
        # s[2] the alignment parameters used for averaging below.
        stable_set, mir_stab_rate, average_pix_err = multi_align_stability(
            all_ali_params, 0.0, 10000.0, thld_err, False, 2 * radius + 1)
        if (len(stable_set) > 5):
            stable_set_id = []
            members = []
            pix_err = []
            # First put the stable members into attr 'members' and 'pix_err'
            for s in stable_set:
                # s[1] - number in this subset
                stable_set_id.append(s[1])
                # the original image number
                members.append(proj_list[i][s[1]])
                pix_err.append(s[0])
            # Then put the unstable members into attr 'members' and 'pix_err'
            from sp_fundamentals import rot_shift2D
            avet.to_zero()
            if options.grouping == "GRP":
                aphi = 0.0
                atht = 0.0
                vphi = 0.0
                vtht = 0.0
            l = -1
            for j in range(len(proj_list[i])):
                #  Here it will only work if stable_set_id is sorted in the increasing number, see how l progresses
                if j in stable_set_id:
                    l += 1
                    avet += rot_shift2D(class_data[j], stable_set[l][2][0],
                                        stable_set[l][2][1],
                                        stable_set[l][2][2],
                                        stable_set[l][2][3])
                    if options.grouping == "GRP":
                        phi, theta, psi, sxs, sy_s = get_params_proj(
                            class_data[j])
                        if (theta > 90.0):
                            phi = (phi + 540.0) % 360.0
                            theta = 180.0 - theta
                        aphi += phi
                        atht += theta
                        vphi += phi * phi
                        vtht += theta * theta
                else:
                    members.append(proj_list[i][j])
                    pix_err.append(99999.99)
            aveList[i] = avet.copy()
            if l > 1:
                # l is the last stable index, so the number of stable images is l + 1
                l += 1
                aveList[i] /= l
                if options.grouping == "GRP":
                    aphi /= l
                    atht /= l
                    vphi = (vphi - l * aphi * aphi) / l
                    vtht = (vtht - l * atht * atht) / l
                    from math import sqrt
                    refprojdir[i] = [
                        aphi, atht,
                        (sqrt(max(vphi, 0.0)) + sqrt(max(vtht, 0.0))) / 2.0
                    ]

            # Here more information has to be stored, PARTICULARLY WHAT IS THE REFERENCE DIRECTION
            aveList[i].set_attr('members', members)
            aveList[i].set_attr('refprojdir', refprojdir[i])
            aveList[i].set_attr('pixerr', pix_err)
        else:
            sxprint(" empty group ", i, refprojdir[i])
            aveList[i].set_attr('members', [-1])
            aveList[i].set_attr('refprojdir', refprojdir[i])
            aveList[i].set_attr('pixerr', [99999.])
        # Free the group images before reading the next group.  Deleting
        # here (rather than after the loop) also avoids a NameError on ranks
        # that were assigned no groups at all.
        del class_data

    # Gather: the main node writes its own averages, then receives and
    # writes those of every other rank in rank order.
    if myid == main_node:
        km = 0
        for i in range(number_of_proc):
            if i == main_node:
                for im in range(len(aveList)):
                    aveList[im].write_image(args[1], km)
                    km += 1
            else:
                nl = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                              MPI_COMM_WORLD)
                nl = int(nl[0])
                for im in range(nl):
                    ave = recv_EMData(i, im + i + 70000)
                    nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                                  MPI_COMM_WORLD)
                    nm = int(nm[0])
                    members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                                       MPI_COMM_WORLD)
                    ave.set_attr('members', list(map(int, members)))
                    members = mpi_recv(nm, MPI_FLOAT, i,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    ave.set_attr('pixerr', list(map(float, members)))
                    members = mpi_recv(3, MPI_FLOAT, i,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    ave.set_attr('refprojdir', list(map(float, members)))
                    ave.write_image(args[1], km)
                    km += 1
    else:
        mpi_send(len(aveList), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL,
                 MPI_COMM_WORLD)
        for im in range(len(aveList)):
            send_EMData(aveList[im], main_node, im + myid + 70000)
            members = aveList[im].get_attr('members')
            mpi_send(len(members), 1, MPI_INT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            mpi_send(members, len(members), MPI_INT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            members = aveList[im].get_attr('pixerr')
            mpi_send(members, len(members), MPI_FLOAT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            try:
                members = aveList[im].get_attr('refprojdir')
                mpi_send(members, 3, MPI_FLOAT, main_node,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            except Exception:
                mpi_send([-999.0, -999.0, -999.0], 3, MPI_FLOAT, main_node,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)

    sp_global_def.BATCH = False
    mpi_barrier(MPI_COMM_WORLD)
Ejemplo n.º 6
0
def filterlocal(ui, vi, m, falloff, myid, main_node, number_of_proc):
	"""Low-pass filter a volume locally according to a resolution map (MPI).

	On the main node the values of ``ui`` are rounded in place to two
	decimals; ``ui``, ``vi``, their dimensions and ``falloff`` are then
	broadcast to all ranks (non-main ranks may pass placeholders for ``ui``,
	``vi`` and ``falloff``).  For every rounded resolution value ``cutoff``
	present inside the mask ``m``, ``vi`` is filtered with
	``filt_tanl(vi, cutoff, falloff)`` and voxels inside the mask whose
	rounded ``ui`` equals ``cutoff`` take the filtered value.  Z-slices are
	distributed round-robin over the ranks and the partial results are
	combined with ``reduce_EMData_to_root``.

	Args:
		ui: resolution volume (only meaningful on the main node).
		vi: volume to be filtered (only meaningful on the main node).
		m: mask volume; voxels with value > 0.5 are processed.  Must be
			available on every rank.
		falloff: fall-off of the hyperbolic-tangent filter (taken from the
			main node).
		myid, main_node, number_of_proc: MPI rank, root rank and rank count.

	Returns:
		The locally filtered volume after reduction to ``main_node``
		(presumably only complete on the root - see reduce_EMData_to_root).
	"""
	from mpi import mpi_barrier, MPI_COMM_WORLD
	from sp_utilities import bcast_number_to_all, bcast_list_to_all, model_blank, bcast_EMData_to_all, reduce_EMData_to_root
	from sp_morphology import threshold_outside
	from sp_filter import filt_tanl
	from sp_fundamentals import fft, fftip

	if myid == main_node:
		nx = vi.get_xsize()
		ny = vi.get_ysize()
		nz = vi.get_zsize()
		#  Round all resolution numbers to two digits so they can be
		#  compared exactly against the rounded cutoff values below.
		for x in range(nx):
			for y in range(ny):
				for z in range(nz):
					ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
		dis = [nx, ny, nz]
	else:
		falloff = 0.0
		dis = [0, 0, 0]
	falloff = bcast_number_to_all(falloff, main_node)
	dis = bcast_list_to_all(dis, myid, source_node = main_node)

	if myid != main_node:
		nx = int(dis[0])
		ny = int(dis[1])
		nz = int(dis[2])

		# Receive buffers for the broadcast volumes.
		vi = model_blank(nx, ny, nz)
		ui = model_blank(nx, ny, nz)

	bcast_EMData_to_all(vi, myid, main_node)
	bcast_EMData_to_all(ui, myid, main_node)

	fftip(vi)  #  volume to be filtered

	st = Util.infomask(ui, m, True)

	filteredvol = model_blank(nx, ny, nz)
	# Sweep cutoff in steps of 0.01 from just below st[2] up to st[3]
	# (st[2]/st[3] presumably the min/max of ui inside the mask).
	cutoff = max(st[2] - 0.01, 0.0)
	while cutoff < st[3]:
		cutoff = round(cutoff + 0.01, 2)
		# Skip cutoff values that no masked voxel has.
		pt = Util.infomask(threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005), m, True)  # Ideally, one would want to check only slices in question...
		if pt[0] != 0.0:
			vovo = fft(filt_tanl(vi, cutoff, falloff))
			# Each rank handles every number_of_proc-th z-slice.
			for z in range(myid, nz, number_of_proc):
				for x in range(nx):
					for y in range(ny):
						if m.get_value_at(x, y, z) > 0.5:
							if round(ui.get_value_at(x, y, z), 2) == cutoff:
								filteredvol.set_value_at_fast(x, y, z, vovo.get_value_at(x, y, z))

	mpi_barrier(MPI_COMM_WORLD)
	reduce_EMData_to_root(filteredvol, myid, main_node, MPI_COMM_WORLD)
	return filteredvol
Ejemplo n.º 7
0
def main():
	"""Command-line driver for 2-D reference-free alignment.

	Expects ``stack outdir [maskfile]`` (``outdir`` may be the literal
	string ``None``).  With ``--rotational`` the run is delegated to
	``ali2d_rotationaltop``.  Otherwise, with ``--MPI``, the main node
	creates the output directory (an already existing/uncreatable directory
	is reported as an error) and every rank calls ``ali2d_base``.  The
	non-MPI path is retired and only prints a notice.
	"""
	progname = os.path.basename(sys.argv[0])
	# The usage line lists only options that are actually defined below.
	usage = progname + " stack outdir <maskfile> --ir=inner_radius --ou=outer_radius --rs=ring_step --xr=x_range --yr=y_range --ts=translation_step --dst=delta --center=center --maxit=max_iteration --CTF --snr=SNR --Fourvar=Fourier_variance --function=user_function_name --MPI"
	parser = OptionParser(usage,version=SPARXVERSION)
	parser.add_option("--ir",       type="float",  default=1,             help="inner radius for rotational correlation > 0 (set to 1)")
	parser.add_option("--ou",       type="float",  default=-1,            help="outer radius for rotational correlation < nx/2-1 (set to the radius of the particle)")
	parser.add_option("--rs",       type="float",  default=1,             help="step between rings in rotational correlation > 0 (set to 1)" ) 
	parser.add_option("--xr",       type="string", default="4 2 1 1",     help="range for translation search in x direction, search is +/xr ")
	parser.add_option("--yr",       type="string", default="-1",          help="range for translation search in y direction, search is +/yr ")
	parser.add_option("--ts",       type="string", default="2 1 0.5 0.25",help="step of translation search in both directions")
	parser.add_option("--nomirror", action="store_true", default=False,   help="Disable checking mirror orientations of images (default False)")
	parser.add_option("--dst",      type="float",  default=0.0,           help="delta")
	parser.add_option("--center",   type="float",  default=-1,            help="-1.average center method; 0.not centered; 1.phase approximation; 2.cc with Gaussian function; 3.cc with donut-shaped image 4.cc with user-defined reference 5.cc with self-rotated average")
	parser.add_option("--maxit",    type="float",  default=0,             help="maximum number of iterations (0 means the maximum iterations is 10, but it will automatically stop should the criterion falls")
	parser.add_option("--CTF",      action="store_true", default=False,   help="use CTF correction during alignment")
	parser.add_option("--snr",      type="float",  default=1.0,           help="signal-to-noise ratio of the data (set to 1.0)")
	parser.add_option("--Fourvar",  action="store_true", default=False,   help="compute Fourier variance")
	parser.add_option("--function", type="string",       default="ref_ali2d",  help="name of the reference preparation function (default ref_ali2d)")
	parser.add_option("--MPI",      action="store_true", default=False,   help="use MPI version ")
	parser.add_option("--rotational", action="store_true", default=False, help="rotational alignment with optional limited in-plane angle, the parameters are: ir, ou, rs, psi_max, mode(F or H), maxit, orient, randomize")
	parser.add_option("--psi_max",  type="float",        default=180.0,   help="psi_max")
	parser.add_option("--mode",     type="string",       default="F",     help="Full or Half rings, default F")
	parser.add_option("--randomize",action="store_true", default=False,   help="randomize initial rotations (suboption of friedel, default False)")
	parser.add_option("--orient",   action="store_true", default=False,   help="orient images such that the average is symmetric about x-axis, for layer lines (suboption of friedel, default False)")
	parser.add_option("--template", type="string",       default=None,    help="2D alignment will be initialized using the template provided (only non-MPI version, default None)")
	parser.add_option("--random_method",   type="string", default="",   help="use SHC or SCF (default standard method)")

	(options, args) = parser.parse_args()

	if len(args) < 2 or len(args) > 3:
		sxprint( "Usage: " + usage )
		sxprint( "Please run \'" + progname + " -h\' for detailed options" )
		sp_global_def.ERROR( "Invalid number of parameters used. Please see usage information above." )
		return

	elif options.rotational:
		from sp_applications import ali2d_rotationaltop
		sp_global_def.BATCH = True
		ali2d_rotationaltop(args[1], args[0], options.randomize, options.orient, options.ir, options.ou, options.rs, options.psi_max, options.mode, options.maxit)
	else:
		if args[1] == 'None':
			outdir = None
		else:
			outdir = args[1]

		if len(args) == 2:
			mask = None
		else:
			mask = args[2]

		if sp_global_def.CACHE_DISABLE:
			from sp_utilities import disable_bdb_cache
			disable_bdb_cache()

		sp_global_def.BATCH = True
		if options.MPI:
			from sp_applications import ali2d_base
			from mpi import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD

			number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
			myid = mpi_comm_rank(MPI_COMM_WORLD)
			main_node = 0

			if myid == main_node:
				from sp_logger import Logger, BaseLogger_Files
				#  Create output directory.  os.mkdir is used instead of a
				#  "mkdir" shell command so that directory names containing
				#  spaces or shell metacharacters are handled correctly.
				log = Logger(BaseLogger_Files())
				log.prefix = os.path.join(outdir)
				try:
					os.mkdir(log.prefix)
					outcome = 0
				except OSError:
					# Any failure (most commonly: directory already exists)
					# is reported as 1, as the former shell mkdir did.
					outcome = 1
				log.prefix += "/"
			else:
				outcome = 0
				log = None
			from sp_utilities       import bcast_number_to_all
			outcome  = bcast_number_to_all(outcome, source_node = main_node)
			if outcome == 1:
				sp_global_def.ERROR( "Output directory exists, please change the name and restart the program", myid=myid )

			dummy = ali2d_base(args[0], outdir, mask, options.ir, options.ou, options.rs, options.xr, options.yr, \
				options.ts, options.nomirror, options.dst, \
				options.center, options.maxit, options.CTF, options.snr, options.Fourvar, \
				options.function, random_method = options.random_method, log = log, \
				number_of_proc = number_of_proc, myid = myid, main_node = main_node, mpi_comm = MPI_COMM_WORLD,\
				write_headers = True)
		else:
			# The serial ali2d path has been retired; only MPI is supported.
			sxprint( " Non-MPI is no more in use, try MPI option, please." )
		sp_global_def.BATCH = False
Ejemplo n.º 8
0
def main(args):
    """Command-line entry point of sxviper.

    Parses ``args`` with optparse, validates the options, creates the output
    directory on MPI rank 0 and runs ``sp_multi_shc.multi_shc`` on the input
    stack over ``mpi.MPI_COMM_WORLD``.

    Every MPI rank must call this function.  Rank 0 performs the validation
    and reads the input stack.  Its error flag is broadcast, so all ranks
    return together when validation fails instead of some of them blocking
    forever.

    Expected positional arguments: ``stack output_directory``.  A third
    positional argument is accepted but currently ignored.

    Args:
        args: Command-line arguments, excluding the program name.

    Returns:
        The ``optparse.OptionParser`` when ``--return_options`` is given,
        otherwise ``None``.
    """
    import os

    # Single source of truth for the GA population default; it is also
    # quoted in the process-count error message below.
    default_nruns = 6

    progname = os.path.basename(sys.argv[0])
    usage = (
        progname +
        " stack  [output_directory] --ir=inner_radius --rs=ring_step --xr=x_range --yr=y_range  --ts=translational_search_step  --delta=angular_step --center=center_type --maxit1=max_iter1 --maxit2=max_iter2 --L2threshold=0.1 --ref_a=S --sym=c1"
    )
    usage += """

stack			2D images in a stack file: (default required string)
directory		output directory name: into which the results will be written (if it does not exist, it will be created, if it does exist, the results will be written possibly overwriting previous results) (default required string)
"""

    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--radius",
        type="int",
        default=29,
        help=
        "radius of the particle: has to be less than < int(nx/2)-1 (default 29)",
    )

    parser.add_option(
        "--xr",
        type="string",
        default="0",
        help=
        "range for translation search in x direction: search is +/xr in pixels (default '0')",
    )
    parser.add_option(
        "--yr",
        type="string",
        default="0",
        help=
        "range for translation search in y direction: if omitted will be set to xr, search is +/yr in pixels (default '0')",
    )
    parser.add_option("--mask3D",
                      type="string",
                      default=None,
                      help="3D mask file: (default sphere)")
    parser.add_option(
        "--moon_elimination",
        type="string",
        default="",
        help=
        "elimination of disconnected pieces: two arguments: mass in KDa and pixel size in px/A separated by comma, no space (default none)",
    )
    parser.add_option(
        "--ir",
        type="int",
        default=1,
        help="inner radius for rotational search: > 0 (default 1)",
    )

    # 'radius' and 'ou' are the same as per Pawel's request; 'ou' is hidden
    # from the user.  The 'ou' variable is not renamed to 'radius' in the
    # 'sparx' program; the rename exists at interface level only for sxviper.
    # NOTE: this option is not listed in the user documentation.
    parser.add_option("--ou",
                      type="int",
                      default=-1,
                      help=optparse.SUPPRESS_HELP)
    parser.add_option(
        "--rs",
        type="int",
        default=1,
        help="step between rings in rotational search: >0 (default 1)",
    )
    parser.add_option(
        "--ts",
        type="string",
        default="1.0",
        help=
        "step size of the translation search in x-y directions: search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional (default '1.0')",
    )
    parser.add_option(
        "--delta",
        type="string",
        default="2.0",
        help="angular step of reference projections: (default '2.0')",
    )
    parser.add_option(
        "--center",
        type="float",
        default=-1.0,
        help=
        "centering of 3D template: average shift method; 0: no centering; 1: center of gravity (default -1.0)",
    )
    parser.add_option(
        "--maxit1",
        type="int",
        default=400,
        help=
        "maximum number of iterations performed for the GA part: (default 400)",
    )
    parser.add_option(
        "--maxit2",
        type="int",
        default=50,
        help=
        "maximum number of iterations performed for the finishing up part: (default 50)",
    )
    parser.add_option(
        "--L2threshold",
        type="float",
        default=0.03,
        help=
        "stopping criterion of GA: given as a maximum relative dispersion of volumes' L2 norms: (default 0.03)",
    )
    parser.add_option(
        "--ref_a",
        type="string",
        default="S",
        help=
        "method for generating the quasi-uniformly distributed projection directions: (default S)",
    )
    parser.add_option(
        "--sym",
        type="string",
        default="c1",
        help="point-group symmetry of the structure: (default c1)",
    )

    # Name of the reference preparation function, looked up in
    # sp_user_functions.factory below.
    # NOTE: this option is not listed in the user documentation.
    parser.add_option("--function",
                      type="string",
                      default="ref_ali3d",
                      help=optparse.SUPPRESS_HELP)

    parser.add_option(
        "--nruns",
        type="int",
        default=default_nruns,
        help=
        "GA population: aka number of quasi-independent volumes (default 6)",
    )
    parser.add_option(
        "--doga",
        type="float",
        default=0.1,
        help=
        "do GA when fraction of orientation changes less than 1.0 degrees is at least doga: (default 0.1)",
    )
    # NOTE: this option is not listed in the user documentation.
    parser.add_option(
        "--npad",
        type="int",
        default=2,
        help="padding size for 3D reconstruction (default=2)",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=0.25,
        help=
        "cut-off frequency applied to the template volume: using a hyperbolic tangent low-pass filter (default 0.25)",
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.1,
        help="fall-off of hyperbolic tangent low-pass filter: (default 0.1)",
    )
    parser.add_option(
        "--pwreference",
        type="string",
        default="",
        help="text file with a reference power spectrum: (default none)",
    )
    parser.add_option(
        "--debug",
        action="store_true",
        default=False,
        help="debug info printout: (default False)",
    )

    # Lets callers obtain the configured parser without running anything.
    # NOTE: this option is not listed in the user documentation.
    parser.add_option(
        "--return_options",
        action="store_true",
        dest="return_options",
        default=False,
        help=optparse.SUPPRESS_HELP,
    )

    required_option_list = ["radius"]
    (options, args) = parser.parse_args(args)

    if options.return_options:
        return parser

    # "--moon_elimination=mass,pixel_size" -> [mass, pixel_size]; empty -> [].
    if options.moon_elimination == "":
        options.moon_elimination = []
    else:
        options.moon_elimination = list(
            map(float, options.moon_elimination.split(",")))

    # Make sure all required options appeared (a falsy value counts as missing).
    for required_option in required_option_list:
        if not getattr(options, required_option):
            sp_global_def.sxprint("\n ==%s== mandatory option is missing.\n" %
                                  required_option)
            sp_global_def.sxprint("Please run '" + progname +
                                  " -h' for detailed options")
            sp_global_def.ERROR("Missing parameter. Please see above")
            return

    if len(args) < 2 or len(args) > 3:
        sp_global_def.sxprint("Usage: " + usage)
        sp_global_def.sxprint("Please run '" + progname +
                              " -h' for detailed options")
        sp_global_def.ERROR(
            "Invalid number of parameters used. Please see usage information above."
        )
        return

    log = sp_logger.Logger(sp_logger.BaseLogger_Files())

    # 'radius' and 'ou' are the same as per Pawel's request; 'ou' is hidden
    # from the user, so copy the user-facing value into the internal one.
    options.ou = options.radius
    runs_count = options.nruns
    mpi_rank = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    # Total number of processes, passed by the --np option of mpirun.
    mpi_size = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)

    outdir = args[1]
    error = 0
    if mpi_rank == 0:
        # Reject a non-positive run count before using it as a modulus: a
        # ZeroDivisionError here would kill rank 0 only and leave the other
        # ranks blocked in the barrier below.
        if options.nruns < 1:
            sp_global_def.ERROR(
                "Number of runs (--nruns) must be a positive integer.",
                action=0,
            )
            error = 1
        elif mpi_size % options.nruns != 0:
            sp_global_def.ERROR(
                "Number of processes needs to be a multiple of total number of runs. Total runs by default are %d, you can change it by specifying --nruns option."
                % default_nruns,
                action=0,
            )
            error = 1

        if os.path.exists(outdir):
            sp_global_def.ERROR(
                "Output directory '%s' exists, please change the name and restart the program"
                % outdir,
                action=0,
            )
            error = 1
        sp_global_def.LOGFILE = os.path.join(outdir, sp_global_def.LOGFILE)

    # Only rank 0 validated; share its verdict so every rank returns together.
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    error = sp_utilities.bcast_number_to_all(error,
                                             source_node=0,
                                             mpi_comm=mpi.MPI_COMM_WORLD)
    if error == 1:
        return

    # Read the (possibly large) input stack only after validation succeeded.
    # Only rank 0 holds the images; the other ranks pass None to multi_shc.
    if mpi_rank == 0:
        all_projs = EMAN2_cppwrap.EMData.read_images(args[0])
        subset = list(range(len(all_projs)))
    else:
        all_projs = None
        subset = None

    if mpi_rank == 0:
        os.makedirs(outdir)
        sp_global_def.write_command(outdir)

    if outdir[-1] != "/":
        outdir += "/"
    log.prefix = outdir

    options.user_func = sp_user_functions.factory[options.function]

    # Options that multi_shc reads but sxviper does not expose to the user.
    options.CTF = False
    options.snr = 1.0
    options.an = -1.0
    out_params, out_vol, out_peaks = sp_multi_shc.multi_shc(
        all_projs,
        subset,
        runs_count,
        options,
        mpi_comm=mpi.MPI_COMM_WORLD,
        log=log)
Ejemplo n.º 9
0
def main():
    def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
        # the final ali2d parameters already combine shifts operation first and rotation operation second for parameters converted from 3D
        if mirror:
            m = 1
            alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0,
                                                       540.0 - psi, 0, 0, 1.0)
        else:
            m = 0
            alpha, sx, sy, scalen = compose_transform2(0, s2x, s2y, 1.0,
                                                       360.0 - psi, 0, 0, 1.0)
        return alpha, sx, sy, m

    progname = os.path.basename(sys.argv[0])
    usage = progname + " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=  --aa=   --sym=symmetry --CTF"
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--output_dir",
                      type="string",
                      default="./",
                      help="Output directory")
    parser.add_option("--ave2D",
                      type="string",
                      default=False,
                      help="Write to the disk a stack of 2D averages")
    parser.add_option("--var2D",
                      type="string",
                      default=False,
                      help="Write to the disk a stack of 2D variances")
    parser.add_option("--ave3D",
                      type="string",
                      default=False,
                      help="Write to the disk reconstructed 3D average")
    parser.add_option("--var3D",
                      type="string",
                      default=False,
                      help="Compute 3D variability (time consuming!)")
    parser.add_option(
        "--img_per_grp",
        type="int",
        default=100,
        help="Number of neighbouring projections.(Default is 100)")
    parser.add_option(
        "--no_norm",
        action="store_true",
        default=False,
        help="Do not use normalization.(Default is to apply normalization)")
    #parser.add_option("--radius", 	    type="int"         ,	default=-1   ,				help="radius for 3D variability" )
    parser.add_option(
        "--npad",
        type="int",
        default=2,
        help=
        "Number of time to pad the original images.(Default is 2 times padding)"
    )
    parser.add_option("--sym",
                      type="string",
                      default="c1",
                      help="Symmetry. (Default is no symmetry)")
    parser.add_option(
        "--fl",
        type="float",
        default=0.0,
        help=
        "Low pass filter cutoff in absolute frequency (0.0 - 0.5) and is applied to decimated images. (Default - no filtration)"
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.02,
        help=
        "Fall off of the filter. Use default value if user has no clue about falloff (Default value is 0.02)"
    )
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="Use CFT correction.(Default is no CTF correction)")
    #parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
    #parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
    #parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
    #parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
    #parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
    parser.add_option(
        "--VAR",
        action="store_true",
        default=False,
        help="Stack of input consists of 2D variances (Default False)")
    parser.add_option(
        "--decimate",
        type="float",
        default=0.25,
        help="Image decimate rate, a number less than 1. (Default is 0.25)")
    parser.add_option(
        "--window",
        type="int",
        default=0,
        help=
        "Target image size relative to original image size. (Default value is zero.)"
    )
    #parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
    #parser.add_option("--nvec",			type="int"         ,	default=0    ,				help="Number of eigenvectors, (Default = 0 meaning no PCA calculated)")
    parser.add_option(
        "--symmetrize",
        action="store_true",
        default=False,
        help="Prepare input stack for handling symmetry (Default False)")
    parser.add_option("--overhead",
                      type="float",
                      default=0.5,
                      help="python overhead per CPU.")

    (options, args) = parser.parse_args()
    #####
    from mpi import mpi_comm_rank, mpi_comm_size, mpi_recv, MPI_COMM_WORLD
    from mpi import mpi_barrier, mpi_reduce, mpi_bcast, mpi_send, MPI_FLOAT, MPI_SUM, MPI_INT, MPI_MAX
    #from mpi import *
    from sp_applications import MPI_start_end
    from sp_reconstruction import recons3d_em, recons3d_em_MPI
    from sp_reconstruction import recons3d_4nn_MPI, recons3d_4nn_ctf_MPI
    from sp_utilities import print_begin_msg, print_end_msg, print_msg
    from sp_utilities import read_text_row, get_image, get_im, wrap_mpi_send, wrap_mpi_recv
    from sp_utilities import bcast_EMData_to_all, bcast_number_to_all
    from sp_utilities import get_symt

    #  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

    from EMAN2db import db_open_dict

    # Set up global variables related to bdb cache
    if sp_global_def.CACHE_DISABLE:
        from sp_utilities import disable_bdb_cache
        disable_bdb_cache()

    # Set up global variables related to ERROR function
    sp_global_def.BATCH = True

    # detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in os.environ
    if RUNNING_UNDER_MPI: sp_global_def.MPI = True
    if options.output_dir == "./":
        current_output_dir = os.path.abspath(options.output_dir)
    else:
        current_output_dir = options.output_dir
    if options.symmetrize:

        if mpi.mpi_comm_size(MPI_COMM_WORLD) > 1:
            ERROR("Cannot use more than one CPU for symmetry preparation")

        if not os.path.exists(current_output_dir):
            os.makedirs(current_output_dir)
            sp_global_def.write_command(current_output_dir)

        from sp_logger import Logger, BaseLogger_Files
        if os.path.exists(os.path.join(current_output_dir, "log.txt")):
            os.remove(os.path.join(current_output_dir, "log.txt"))
        log_main = Logger(BaseLogger_Files())
        log_main.prefix = os.path.join(current_output_dir, "./")

        instack = args[0]
        sym = options.sym.lower()
        if (sym == "c1"):
            ERROR("There is no need to symmetrize stack for C1 symmetry")

        line = ""
        for a in sys.argv:
            line += " " + a
        log_main.add(line)

        if (instack[:4] != "bdb:"):
            #if output_dir =="./": stack = "bdb:data"
            stack = "bdb:" + current_output_dir + "/data"
            delete_bdb(stack)
            junk = cmdexecute("sp_cpy.py  " + instack + "  " + stack)
        else:
            stack = instack

        qt = EMUtil.get_all_attributes(stack, 'xform.projection')

        na = len(qt)
        ts = get_symt(sym)
        ks = len(ts)
        angsa = [None] * na

        for k in range(ks):
            #Qfile = "Q%1d"%k
            #if options.output_dir!="./": Qfile = os.path.join(options.output_dir,"Q%1d"%k)
            Qfile = os.path.join(current_output_dir, "Q%1d" % k)
            #delete_bdb("bdb:Q%1d"%k)
            delete_bdb("bdb:" + Qfile)
            #junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            junk = cmdexecute("e2bdb.py  " + stack + "  --makevstack=bdb:" +
                              Qfile)
            #DB = db_open_dict("bdb:Q%1d"%k)
            DB = db_open_dict("bdb:" + Qfile)
            for i in range(na):
                ut = qt[i] * ts[k]
                DB.set_attr(i, "xform.projection", ut)
                #bt = ut.get_params("spider")
                #angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
            #write_text_row(angsa, 'ptsma%1d.txt'%k)
            #junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            #junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
            DB.close()
        #if options.output_dir =="./": delete_bdb("bdb:sdata")
        delete_bdb("bdb:" + current_output_dir + "/" + "sdata")
        #junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
        sdata = "bdb:" + current_output_dir + "/" + "sdata"
        sxprint(sdata)
        junk = cmdexecute("e2bdb.py   " + current_output_dir +
                          "  --makevstack=" + sdata + " --filt=Q")
        #junk = cmdexecute("ls  EMAN2DB/sdata*")
        #a = get_im("bdb:sdata")
        a = get_im(sdata)
        a.set_attr("variabilitysymmetry", sym)
        #a.write_image("bdb:sdata")
        a.write_image(sdata)

    else:

        from sp_fundamentals import window2d
        myid = mpi_comm_rank(MPI_COMM_WORLD)
        number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
        main_node = 0
        shared_comm = mpi_comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED,
                                          0, MPI_INFO_NULL)
        myid_on_node = mpi_comm_rank(shared_comm)
        no_of_processes_per_group = mpi_comm_size(shared_comm)
        masters_from_groups_vs_everything_else_comm = mpi_comm_split(
            MPI_COMM_WORLD, main_node == myid_on_node, myid_on_node)
        color, no_of_groups, balanced_processor_load_on_nodes = get_colors_and_subsets(main_node, MPI_COMM_WORLD, myid, \
            shared_comm, myid_on_node, masters_from_groups_vs_everything_else_comm)
        overhead_loading = options.overhead * number_of_proc
        #memory_per_node  = options.memory_per_node
        #if memory_per_node == -1.: memory_per_node = 2.*no_of_processes_per_group
        keepgoing = 1

        current_window = options.window
        current_decimate = options.decimate

        if len(args) == 1: stack = args[0]
        else:
            sxprint("Usage: " + usage)
            sxprint("Please run \'" + progname + " -h\' for detailed options")
            ERROR(
                "Invalid number of parameters used. Please see usage information above."
            )
            return

        t0 = time()
        # obsolete flags
        options.MPI = True
        #options.nvec = 0
        options.radiuspca = -1
        options.iter = 40
        options.abs = 0.0
        options.squ = 0.0

        if options.fl > 0.0 and options.aa == 0.0:
            ERROR("Fall off has to be given for the low-pass filter",
                  myid=myid)

        #if options.VAR and options.SND:
        #	ERROR( "Only one of var and SND can be set!",myid=myid )

        if options.VAR and (options.ave2D or options.ave3D or options.var2D):
            ERROR(
                "When VAR is set, the program cannot output ave2D, ave3D or var2D",
                myid=myid)

        #if options.SND and (options.ave2D or options.ave3D):
        #	ERROR( "When SND is set, the program cannot output ave2D or ave3D", myid=myid )

        #if options.nvec > 0 :
        #	ERROR( "PCA option not implemented", myid=myid )

        #if options.nvec > 0 and options.ave3D == None:
        #	ERROR( "When doing PCA analysis, one must set ave3D", myid=myid )

        if current_decimate > 1.0 or current_decimate < 0.0:
            ERROR("Decimate rate should be a value between 0.0 and 1.0",
                  myid=myid)

        if current_window < 0.0:
            ERROR("Target window size should be always larger than zero",
                  myid=myid)

        if myid == main_node:
            img = get_image(stack, 0)
            nx = img.get_xsize()
            ny = img.get_ysize()
            if (min(nx, ny) < current_window): keepgoing = 0
        keepgoing = bcast_number_to_all(keepgoing, main_node, MPI_COMM_WORLD)
        if keepgoing == 0:
            ERROR(
                "The target window size cannot be larger than the size of decimated image",
                myid=myid)

        import string
        options.sym = options.sym.lower()
        # if global_def.CACHE_DISABLE:
        # 	from utilities import disable_bdb_cache
        # 	disable_bdb_cache()
        # global_def.BATCH = True

        if myid == main_node:
            if not os.path.exists(current_output_dir):
                os.makedirs(current_output_dir
                            )  # Never delete output_dir in the program!

        img_per_grp = options.img_per_grp
        #nvec        = options.nvec
        radiuspca = options.radiuspca
        from sp_logger import Logger, BaseLogger_Files
        #if os.path.exists(os.path.join(options.output_dir, "log.txt")): os.remove(os.path.join(options.output_dir, "log.txt"))
        log_main = Logger(BaseLogger_Files())
        log_main.prefix = os.path.join(current_output_dir, "./")

        if myid == main_node:
            line = ""
            for a in sys.argv:
                line += " " + a
            log_main.add(line)
            log_main.add("-------->>>Settings given by all options<<<-------")
            log_main.add("Symmetry             : %s" % options.sym)
            log_main.add("Input stack          : %s" % stack)
            log_main.add("Output_dir           : %s" % current_output_dir)

            if options.ave3D:
                log_main.add("Ave3d                : %s" % options.ave3D)
            if options.var3D:
                log_main.add("Var3d                : %s" % options.var3D)
            if options.ave2D:
                log_main.add("Ave2D                : %s" % options.ave2D)
            if options.var2D:
                log_main.add("Var2D                : %s" % options.var2D)
            if options.VAR: log_main.add("VAR                  : True")
            else: log_main.add("VAR                  : False")
            if options.CTF: log_main.add("CTF correction       : True  ")
            else: log_main.add("CTF correction       : False ")

            log_main.add("Image per group      : %5d" % options.img_per_grp)
            log_main.add("Image decimate rate  : %4.3f" % current_decimate)
            log_main.add("Low pass filter      : %4.3f" % options.fl)
            current_fl = options.fl
            if current_fl == 0.0: current_fl = 0.5
            log_main.add(
                "Current low pass filter is equivalent to cutoff frequency %4.3f for original image size"
                % round((current_fl * current_decimate), 3))
            log_main.add("Window size          : %5d " % current_window)
            log_main.add("sx3dvariability begins")

        symbaselen = 0
        if myid == main_node:
            nima = EMUtil.get_image_count(stack)
            img = get_image(stack)
            nx = img.get_xsize()
            ny = img.get_ysize()
            nnxo = nx
            nnyo = ny
            if options.sym != "c1":
                imgdata = get_im(stack)
                try:
                    i = imgdata.get_attr("variabilitysymmetry").lower()
                    if (i != options.sym):
                        ERROR(
                            "The symmetry provided does not agree with the symmetry of the input stack",
                            myid=myid)
                except:
                    ERROR(
                        "Input stack is not prepared for symmetry, please follow instructions",
                        myid=myid)
                from sp_utilities import get_symt
                i = len(get_symt(options.sym))
                if ((nima / i) * i != nima):
                    ERROR(
                        "The length of the input stack is incorrect for symmetry processing",
                        myid=myid)
                symbaselen = nima / i
            else:
                symbaselen = nima
        else:
            nima = 0
            nx = 0
            ny = 0
            nnxo = 0
            nnyo = 0
        nima = bcast_number_to_all(nima)
        nx = bcast_number_to_all(nx)
        ny = bcast_number_to_all(ny)
        nnxo = bcast_number_to_all(nnxo)
        nnyo = bcast_number_to_all(nnyo)
        if current_window > max(nx, ny):
            ERROR("Window size is larger than the original image size")

        if current_decimate == 1.:
            if current_window != 0:
                nx = current_window
                ny = current_window
        else:
            if current_window == 0:
                nx = int(nx * current_decimate + 0.5)
                ny = int(ny * current_decimate + 0.5)
            else:
                nx = int(current_window * current_decimate + 0.5)
                ny = nx
        symbaselen = bcast_number_to_all(symbaselen)

        # check FFT prime number
        from sp_fundamentals import smallprime
        is_fft_friendly = (nx == smallprime(nx))

        if not is_fft_friendly:
            if myid == main_node:
                log_main.add(
                    "The target image size is not a product of small prime numbers"
                )
                log_main.add("Program adjusts the input settings!")
            ### two cases
            if current_decimate == 1.:
                nx = smallprime(nx)
                ny = nx
                current_window = nx  # update
                if myid == main_node:
                    log_main.add("The window size is updated to %d." %
                                 current_window)
            else:
                if current_window == 0:
                    nx = smallprime(int(nx * current_decimate + 0.5))
                    current_decimate = float(nx) / nnxo
                    ny = nx
                    if (myid == main_node):
                        log_main.add("The decimate rate is updated to %f." %
                                     current_decimate)
                else:
                    nx = smallprime(
                        int(current_window * current_decimate + 0.5))
                    ny = nx
                    current_window = int(nx / current_decimate + 0.5)
                    if (myid == main_node):
                        log_main.add("The window size is updated to %d." %
                                     current_window)

        if myid == main_node:
            log_main.add("The target image size is %d" % nx)

        if radiuspca == -1: radiuspca = nx / 2 - 2
        if myid == main_node:
            log_main.add("%-70s:  %d\n" % ("Number of projection", nima))
        img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
        """
		if options.SND:
			from sp_projection		import prep_vol, prgs
			from sp_statistics		import im_diff
			from sp_utilities		import get_im, model_circle, get_params_proj, set_params_proj
			from sp_utilities		import get_ctf, generate_ctf
			from sp_filter			import filt_ctf
		
			imgdata = EMData.read_images(stack, range(img_begin, img_end))

			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)

			bcast_EMData_to_all(vol, myid)
			volft, kb = prep_vol(vol)

			mask = model_circle(nx/2-2, nx, ny)
			varList = []
			for i in xrange(img_begin, img_end):
				phi, theta, psi, s2x, s2y = get_params_proj(imgdata[i-img_begin])
				ref_prj = prgs(volft, kb, [phi, theta, psi, -s2x, -s2y])
				if options.CTF:
					ctf_params = get_ctf(imgdata[i-img_begin])
					ref_prj = filt_ctf(ref_prj, generate_ctf(ctf_params))
				diff, A, B = im_diff(ref_prj, imgdata[i-img_begin], mask)
				diff2 = diff*diff
				set_params_proj(diff2, [phi, theta, psi, s2x, s2y])
				varList.append(diff2)
			mpi_barrier(MPI_COMM_WORLD)
		"""

        if options.VAR:  # 2D variance images have no shifts
            #varList   = EMData.read_images(stack, range(img_begin, img_end))
            from EMAN2 import Region
            for index_of_particle in range(img_begin, img_end):
                image = get_im(stack, index_of_proj)
                if current_window > 0:
                    varList.append(
                        fdecimate(
                            window2d(image, current_window, current_window),
                            nx, ny))
                else:
                    varList.append(fdecimate(image, nx, ny))

        else:
            from sp_utilities import bcast_number_to_all, bcast_list_to_all, send_EMData, recv_EMData
            from sp_utilities import set_params_proj, get_params_proj, params_3D_2D, get_params2D, set_params2D, compose_transform2
            from sp_utilities import model_blank, nearest_proj, model_circle, write_text_row, wrap_mpi_gatherv
            from sp_applications import pca
            from sp_statistics import avgvar, avgvar_ctf, ccc
            from sp_filter import filt_tanl
            from sp_morphology import threshold, square_root
            from sp_projection import project, prep_vol, prgs
            from sets import Set
            from sp_utilities import wrap_mpi_recv, wrap_mpi_bcast, wrap_mpi_send
            import numpy as np
            if myid == main_node:
                t1 = time()
                proj_angles = []
                aveList = []
                tab = EMUtil.get_all_attributes(stack, 'xform.projection')
                for i in range(nima):
                    t = tab[i].get_params('spider')
                    phi = t['phi']
                    theta = t['theta']
                    psi = t['psi']
                    x = theta
                    if x > 90.0: x = 180.0 - x
                    x = x * 10000 + psi
                    proj_angles.append([x, t['phi'], t['theta'], t['psi'], i])
                t2 = time()
                log_main.add(
                    "%-70s:  %d\n" %
                    ("Number of neighboring projections", img_per_grp))
                log_main.add("...... Finding neighboring projections\n")
                log_main.add("Number of images per group: %d" % img_per_grp)
                log_main.add("Now grouping projections")
                proj_angles.sort()
                proj_angles_list = np.full((nima, 4), 0.0, dtype=np.float32)
                for i in range(nima):
                    proj_angles_list[i][0] = proj_angles[i][1]
                    proj_angles_list[i][1] = proj_angles[i][2]
                    proj_angles_list[i][2] = proj_angles[i][3]
                    proj_angles_list[i][3] = proj_angles[i][4]
            else:
                proj_angles_list = 0
            proj_angles_list = wrap_mpi_bcast(proj_angles_list, main_node,
                                              MPI_COMM_WORLD)
            proj_angles = []
            for i in range(nima):
                proj_angles.append([
                    proj_angles_list[i][0], proj_angles_list[i][1],
                    proj_angles_list[i][2],
                    int(proj_angles_list[i][3])
                ])
            del proj_angles_list
            proj_list, mirror_list = nearest_proj(proj_angles, img_per_grp,
                                                  range(img_begin, img_end))
            all_proj = Set()
            for im in proj_list:
                for jm in im:
                    all_proj.add(proj_angles[jm][3])
            all_proj = list(all_proj)
            index = {}
            for i in range(len(all_proj)):
                index[all_proj[i]] = i
            mpi_barrier(MPI_COMM_WORLD)
            if myid == main_node:
                log_main.add("%-70s:  %.2f\n" %
                             ("Finding neighboring projections lasted [s]",
                              time() - t2))
                log_main.add("%-70s:  %d\n" %
                             ("Number of groups processed on the main node",
                              len(proj_list)))
                log_main.add("Grouping projections took:  %12.1f [m]" %
                             ((time() - t2) / 60.))
                log_main.add("Number of groups on main node: ", len(proj_list))
            mpi_barrier(MPI_COMM_WORLD)

            if myid == main_node:
                log_main.add("...... Calculating the stack of 2D variances \n")
            # Memory estimation. There are two memory consumption peaks
            # peak 1. Compute ave, var;
            # peak 2. Var volume reconstruction;
            # proj_params = [0.0]*(nima*5)
            aveList = []
            varList = []
            #if nvec > 0: eigList = [[] for i in range(nvec)]
            dnumber = len(
                all_proj)  # all neighborhood set for assigned to myid
            pnumber = len(proj_list) * 2. + img_per_grp  # aveList and varList
            tnumber = dnumber + pnumber
            vol_size2 = nx**3 * 4. * 8 / 1.e9
            vol_size1 = 2. * nnxo**3 * 4. * 8 / 1.e9
            proj_size = nnxo * nnyo * len(
                proj_list) * 4. * 2. / 1.e9  # both aveList and varList
            orig_data_size = nnxo * nnyo * 4. * tnumber / 1.e9
            reduced_data_size = nx * nx * 4. * tnumber / 1.e9
            full_data = np.full((number_of_proc, 2), -1., dtype=np.float16)
            full_data[myid] = orig_data_size, reduced_data_size
            if myid != main_node:
                wrap_mpi_send(full_data, main_node, MPI_COMM_WORLD)
            if myid == main_node:
                for iproc in range(number_of_proc):
                    if iproc != main_node:
                        dummy = wrap_mpi_recv(iproc, MPI_COMM_WORLD)
                        full_data[np.where(dummy > -1)] = dummy[np.where(
                            dummy > -1)]
                del dummy
            mpi_barrier(MPI_COMM_WORLD)
            full_data = wrap_mpi_bcast(full_data, main_node, MPI_COMM_WORLD)
            # find the CPU with heaviest load
            minindx = np.argsort(full_data, 0)
            heavy_load_myid = minindx[-1][1]
            total_mem = sum(full_data)
            if myid == main_node:
                if current_window == 0:
                    log_main.add(
                        "Nx:   current image size = %d. Decimated by %f from %d"
                        % (nx, current_decimate, nnxo))
                else:
                    log_main.add(
                        "Nx:   current image size = %d. Windowed to %d, and decimated by %f from %d"
                        % (nx, current_window, current_decimate, nnxo))
                log_main.add("Nproj:       number of particle images.")
                log_main.add("Navg:        number of 2D average images.")
                log_main.add("Nvar:        number of 2D variance images.")
                log_main.add(
                    "Img_per_grp: user defined image per group for averaging = %d"
                    % img_per_grp)
                log_main.add(
                    "Overhead:    total python overhead memory consumption   = %f"
                    % overhead_loading)
                log_main.add("Total memory) = 4.0*nx^2*(nproj + navg +nvar+ img_per_grp)/1.0e9 + overhead: %12.3f [GB]"%\
                   (total_mem[1] + overhead_loading))
            del full_data
            mpi_barrier(MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add(
                    "Begin reading and preprocessing images on processor. Wait... "
                )
                ttt = time()
            #imgdata = EMData.read_images(stack, all_proj)
            imgdata = [None for im in range(len(all_proj))]
            for index_of_proj in range(len(all_proj)):
                #image = get_im(stack, all_proj[index_of_proj])
                if (current_window > 0):
                    imgdata[index_of_proj] = fdecimate(
                        window2d(get_im(stack, all_proj[index_of_proj]),
                                 current_window, current_window), nx, ny)
                else:
                    imgdata[index_of_proj] = fdecimate(
                        get_im(stack, all_proj[index_of_proj]), nx, ny)

                if (current_decimate > 0.0 and options.CTF):
                    ctf = imgdata[index_of_proj].get_attr("ctf")
                    ctf.apix = ctf.apix / current_decimate
                    imgdata[index_of_proj].set_attr("ctf", ctf)

                if myid == heavy_load_myid and index_of_proj % 100 == 0:
                    log_main.add(" ...... %6.2f%% " %
                                 (index_of_proj / float(len(all_proj)) * 100.))
            mpi_barrier(MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add("All_proj preprocessing cost %7.2f m" %
                             ((time() - ttt) / 60.))
                log_main.add("Wait untill reading on all CPUs done...")
            '''	
			imgdata2 = EMData.read_images(stack, range(img_begin, img_end))
			if options.fl > 0.0:
				for k in xrange(len(imgdata2)):
					imgdata2[k] = filt_tanl(imgdata2[k], options.fl, options.aa)
			if options.CTF:
				vol = recons3d_4nn_ctf_MPI(myid, imgdata2, 1.0, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			else:
				vol = recons3d_4nn_MPI(myid, imgdata2, symmetry=options.sym, npad=options.npad, xysize=-1, zsize=-1)
			if myid == main_node:
				vol.write_image("vol_ctf.hdf")
				print_msg("Writing to the disk volume reconstructed from averages as		:  %s\n"%("vol_ctf.hdf"))
			del vol, imgdata2
			mpi_barrier(MPI_COMM_WORLD)
			'''
            from sp_applications import prepare_2d_forPCA
            from sp_utilities import model_blank
            from EMAN2 import Transform
            if not options.no_norm:
                mask = model_circle(nx / 2 - 2, nx, nx)
            if options.CTF:
                from sp_utilities import pad
                from sp_filter import filt_ctf
            from sp_filter import filt_tanl
            if myid == heavy_load_myid:
                log_main.add("Start computing 2D aveList and varList. Wait...")
                ttt = time()
            inner = nx // 2 - 4
            outer = inner + 2
            xform_proj_for_2D = [None for i in range(len(proj_list))]
            for i in range(len(proj_list)):
                ki = proj_angles[proj_list[i][0]][3]
                if ki >= symbaselen: continue
                mi = index[ki]
                dpar = Util.get_transform_params(imgdata[mi],
                                                 "xform.projection", "spider")
                phiM, thetaM, psiM, s2xM, s2yM = dpar["phi"], dpar[
                    "theta"], dpar[
                        "psi"], -dpar["tx"] * current_decimate, -dpar[
                            "ty"] * current_decimate
                grp_imgdata = []
                for j in range(img_per_grp):
                    mj = index[proj_angles[proj_list[i][j]][3]]
                    cpar = Util.get_transform_params(imgdata[mj],
                                                     "xform.projection",
                                                     "spider")
                    alpha, sx, sy, mirror = params_3D_2D_NEW(
                        cpar["phi"], cpar["theta"], cpar["psi"],
                        -cpar["tx"] * current_decimate,
                        -cpar["ty"] * current_decimate, mirror_list[i][j])
                    if thetaM <= 90:
                        if mirror == 0:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, phiM - cpar["phi"], 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, 180 - (phiM - cpar["phi"]),
                                0.0, 0.0, 1.0)
                    else:
                        if mirror == 0:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0, -(phiM - cpar["phi"]), 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = compose_transform2(
                                alpha, sx, sy, 1.0,
                                -(180 - (phiM - cpar["phi"])), 0.0, 0.0, 1.0)
                    imgdata[mj].set_attr(
                        "xform.align2d",
                        Transform({
                            "type": "2D",
                            "alpha": alpha,
                            "tx": sx,
                            "ty": sy,
                            "mirror": mirror,
                            "scale": 1.0
                        }))
                    grp_imgdata.append(imgdata[mj])
                if not options.no_norm:
                    for k in range(img_per_grp):
                        ave, std, minn, maxx = Util.infomask(
                            grp_imgdata[k], mask, False)
                        grp_imgdata[k] -= ave
                        grp_imgdata[k] /= std
                if options.fl > 0.0:
                    for k in range(img_per_grp):
                        grp_imgdata[k] = filt_tanl(grp_imgdata[k], options.fl,
                                                   options.aa)

                #  Because of background issues, only linear option works.
                if options.CTF:
                    ave, var = aves_wiener(grp_imgdata,
                                           SNR=1.0e5,
                                           interpolation_method="linear")
                else:
                    ave, var = ave_var(grp_imgdata)
                # Switch to std dev
                # threshold is not really needed,it is just in case due to numerical accuracy something turns out negative.
                var = square_root(threshold(var))

                set_params_proj(ave, [phiM, thetaM, 0.0, 0.0, 0.0])
                set_params_proj(var, [phiM, thetaM, 0.0, 0.0, 0.0])

                aveList.append(ave)
                varList.append(var)
                xform_proj_for_2D[i] = [phiM, thetaM, 0.0, 0.0, 0.0]
                '''
				if nvec > 0:
					eig = pca(input_stacks=grp_imgdata, subavg="", mask_radius=radiuspca, nvec=nvec, incore=True, shuffle=False, genbuf=True)
					for k in range(nvec):
						set_params_proj(eig[k], [phiM, thetaM, 0.0, 0.0, 0.0])
						eigList[k].append(eig[k])
					"""
					if myid == 0 and i == 0:
						for k in xrange(nvec):
							eig[k].write_image("eig.hdf", k)
					"""
				'''
                if (myid == heavy_load_myid) and (i % 100 == 0):
                    log_main.add(" ......%6.2f%%  " %
                                 (i / float(len(proj_list)) * 100.))
            del imgdata, grp_imgdata, cpar, dpar, all_proj, proj_angles, index
            if not options.no_norm: del mask
            if myid == main_node: del tab
            #  At this point, all averages and variances are computed
            mpi_barrier(MPI_COMM_WORLD)

            if (myid == heavy_load_myid):
                log_main.add("Computing aveList and varList took %12.1f [m]" %
                             ((time() - ttt) / 60.))

            xform_proj_for_2D = wrap_mpi_gatherv(xform_proj_for_2D, main_node,
                                                 MPI_COMM_WORLD)
            if (myid == main_node):
                write_text_row([str(entry) for entry in xform_proj_for_2D],
                               os.path.join(current_output_dir, "params.txt"))
            del xform_proj_for_2D
            mpi_barrier(MPI_COMM_WORLD)
            if options.ave2D:
                from sp_fundamentals import fpol
                from sp_applications import header
                if myid == main_node:
                    log_main.add("Compute ave2D ... ")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(aveList)):
                                aveList[im].write_image(
                                    os.path.join(current_output_dir,
                                                 options.ave2D), km)
                                km += 1
                        else:
                            nl = mpi_recv(1, MPI_INT, i,
                                          SPARX_MPI_TAG_UNIVERSAL,
                                          MPI_COMM_WORLD)
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = recv_EMData(i, im + i + 70000)
                                """
								nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								nm = int(nm[0])
								members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('members', map(int, members))
								members = mpi_recv(nm, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('pix_err', map(float, members))
								members = mpi_recv(3, MPI_FLOAT, i, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
								ave.set_attr('refprojdir', map(float, members))
								"""
                                tmpvol = fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    os.path.join(current_output_dir,
                                                 options.ave2D), km)
                                km += 1
                else:
                    mpi_send(len(aveList), 1, MPI_INT, main_node,
                             SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    for im in range(len(aveList)):
                        send_EMData(aveList[im], main_node, im + myid + 70000)
                        """
						members = aveList[im].get_attr('members')
						mpi_send(len(members), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						mpi_send(members, len(members), MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						members = aveList[im].get_attr('pix_err')
						mpi_send(members, len(members), MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						try:
							members = aveList[im].get_attr('refprojdir')
							mpi_send(members, 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						except:
							mpi_send([-999.0,-999.0,-999.0], 3, MPI_FLOAT, main_node, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
						"""
                if myid == main_node:
                    header(os.path.join(current_output_dir, options.ave2D),
                           params='xform.projection',
                           fimport=os.path.join(current_output_dir,
                                                "params.txt"))
                mpi_barrier(MPI_COMM_WORLD)
            if options.ave3D:
                from sp_fundamentals import fpol
                t5 = time()
                if myid == main_node: log_main.add("Reconstruct ave3D ... ")
                ave3D = recons3d_4nn_MPI(myid,
                                         aveList,
                                         symmetry=options.sym,
                                         npad=options.npad)
                bcast_EMData_to_all(ave3D, myid)
                if myid == main_node:
                    if current_decimate != 1.0:
                        ave3D = resample(ave3D, 1. / current_decimate)
                    ave3D = fpol(ave3D, nnxo, nnxo,
                                 nnxo)  # always to the orignal image size
                    set_pixel_size(ave3D, 1.0)
                    ave3D.write_image(
                        os.path.join(current_output_dir, options.ave3D))
                    log_main.add("Ave3D reconstruction took %12.1f [m]" %
                                 ((time() - t5) / 60.0))
                    log_main.add("%-70s:  %s\n" %
                                 ("The reconstructed ave3D is saved as ",
                                  options.ave3D))

            mpi_barrier(MPI_COMM_WORLD)
            del ave, var, proj_list, stack, alpha, sx, sy, mirror, aveList
            '''
			if nvec > 0:
				for k in range(nvec):
					if myid == main_node:log_main.add("Reconstruction eigenvolumes", k)
					cont = True
					ITER = 0
					mask2d = model_circle(radiuspca, nx, nx)
					while cont:
						#print "On node %d, iteration %d"%(myid, ITER)
						eig3D = recons3d_4nn_MPI(myid, eigList[k], symmetry=options.sym, npad=options.npad)
						bcast_EMData_to_all(eig3D, myid, main_node)
						if options.fl > 0.0:
							eig3D = filt_tanl(eig3D, options.fl, options.aa)
						if myid == main_node:
							eig3D.write_image(os.path.join(options.outpout_dir, "eig3d_%03d.hdf"%(k, ITER)))
						Util.mul_img( eig3D, model_circle(radiuspca, nx, nx, nx) )
						eig3Df, kb = prep_vol(eig3D)
						del eig3D
						cont = False
						icont = 0
						for l in range(len(eigList[k])):
							phi, theta, psi, s2x, s2y = get_params_proj(eigList[k][l])
							proj = prgs(eig3Df, kb, [phi, theta, psi, s2x, s2y])
							cl = ccc(proj, eigList[k][l], mask2d)
							if cl < 0.0:
								icont += 1
								cont = True
								eigList[k][l] *= -1.0
						u = int(cont)
						u = mpi_reduce([u], 1, MPI_INT, MPI_MAX, main_node, MPI_COMM_WORLD)
						icont = mpi_reduce([icont], 1, MPI_INT, MPI_SUM, main_node, MPI_COMM_WORLD)

						if myid == main_node:
							u = int(u[0])
							log_main.add(" Eigenvector: ",k," number changed ",int(icont[0]))
						else: u = 0
						u = bcast_number_to_all(u, main_node)
						cont = bool(u)
						ITER += 1

					del eig3Df, kb
					mpi_barrier(MPI_COMM_WORLD)
				del eigList, mask2d
			'''
            if options.ave3D: del ave3D
            if options.var2D:
                from sp_fundamentals import fpol
                from sp_applications import header
                if myid == main_node:
                    log_main.add("Compute var2D...")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(varList)):
                                tmpvol = fpol(varList[im], nx, nx, 1)
                                tmpvol.write_image(
                                    os.path.join(current_output_dir,
                                                 options.var2D), km)
                                km += 1
                        else:
                            nl = mpi_recv(1, MPI_INT, i,
                                          SPARX_MPI_TAG_UNIVERSAL,
                                          MPI_COMM_WORLD)
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = recv_EMData(i, im + i + 70000)
                                tmpvol = fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    os.path.join(current_output_dir,
                                                 options.var2D), km)
                                km += 1
                else:
                    mpi_send(len(varList), 1, MPI_INT, main_node,
                             SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    for im in range(len(varList)):
                        send_EMData(varList[im], main_node, im + myid +
                                    70000)  #  What with the attributes??
                mpi_barrier(MPI_COMM_WORLD)
                if myid == main_node:
                    from sp_applications import header
                    header(os.path.join(current_output_dir, options.var2D),
                           params='xform.projection',
                           fimport=os.path.join(current_output_dir,
                                                "params.txt"))
                mpi_barrier(MPI_COMM_WORLD)
        if options.var3D:
            if myid == main_node: log_main.add("Reconstruct var3D ...")
            t6 = time()
            # radiusvar = options.radius
            # if( radiusvar < 0 ):  radiusvar = nx//2 -3
            res = recons3d_4nn_MPI(myid,
                                   varList,
                                   symmetry=options.sym,
                                   npad=options.npad)
            #res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
            if myid == main_node:
                from sp_fundamentals import fpol
                if current_decimate != 1.0:
                    res = resample(res, 1. / current_decimate)
                res = fpol(res, nnxo, nnxo, nnxo)
                set_pixel_size(res, 1.0)
                res.write_image(os.path.join(current_output_dir,
                                             options.var3D))
                log_main.add(
                    "%-70s:  %s\n" %
                    ("The reconstructed var3D is saved as ", options.var3D))
                log_main.add("Var3D reconstruction took %f12.1 [m]" %
                             ((time() - t6) / 60.0))
                log_main.add("Total computation time %f12.1 [m]" %
                             ((time() - t0) / 60.0))
                log_main.add("sx3dvariability finishes")

        if RUNNING_UNDER_MPI:
            sp_global_def.MPI = False

        sp_global_def.BATCH = False
Ejemplo n.º 10
0
def main():
    from sp_logger import Logger, BaseLogger_Files
    arglist = []
    i = 0
    while (i < len(sys.argv)):
        if sys.argv[i] == '-p4pg':
            i = i + 2
        elif sys.argv[i] == '-p4wd':
            i = i + 2
        else:
            arglist.append(sys.argv[i])
            i = i + 1
    progname = os.path.basename(arglist[0])
    usage = progname + " stack  outdir  <mask> --focus=3Dmask --radius=outer_radius --delta=angular_step" +\
    "--an=angular_neighborhood --maxit=max_iter  --CTF --sym=c1 --function=user_function --independent=indenpendent_runs  --number_of_images_per_group=number_of_images_per_group  --low_pass_filter=.25  --seed=random_seed"
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option("--focus",
                      type="string",
                      default='',
                      help="bineary 3D mask for focused clustering ")
    parser.add_option(
        "--ir",
        type="int",
        default=1,
        help="inner radius for rotational correlation > 0 (set to 1)")
    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "particle radius in pixel for rotational correlation <nx-1 (set to the radius of the particle)"
    )
    parser.add_option("--maxit",
                      type="int",
                      default=25,
                      help="maximum number of iteration")
    parser.add_option(
        "--rs",
        type="int",
        default=1,
        help="step between rings in rotational correlation >0 (set to 1)")
    parser.add_option(
        "--xr",
        type="string",
        default='1',
        help="range for translation search in x direction, search is +/-xr ")
    parser.add_option(
        "--yr",
        type="string",
        default='-1',
        help=
        "range for translation search in y direction, search is +/-yr (default = same as xr)"
    )
    parser.add_option(
        "--ts",
        type="string",
        default='0.25',
        help=
        "step size of the translation search in both directions direction, search is -xr, -xr+ts, 0, xr-ts, xr "
    )
    parser.add_option("--delta",
                      type="string",
                      default='2',
                      help="angular step of reference projections")
    parser.add_option("--an",
                      type="string",
                      default='-1',
                      help="angular neighborhood for local searches")
    parser.add_option(
        "--center",
        type="int",
        default=0,
        help=
        "0 - if you do not want the volume to be centered, 1 - center the volume using cog (default=0)"
    )
    parser.add_option(
        "--nassign",
        type="int",
        default=1,
        help=
        "number of reassignment iterations performed for each angular step (set to 3) "
    )
    parser.add_option(
        "--nrefine",
        type="int",
        default=0,
        help=
        "number of alignment iterations performed for each angular step (set to 0)"
    )
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="do CTF correction during clustring")
    parser.add_option(
        "--stoprnct",
        type="float",
        default=3.0,
        help="Minimum percentage of assignment change to stop the program")
    parser.add_option("--sym",
                      type="string",
                      default='c1',
                      help="symmetry of the structure ")
    parser.add_option("--function",
                      type="string",
                      default='do_volume_mrk05',
                      help="name of the reference preparation function")
    parser.add_option("--independent",
                      type="int",
                      default=3,
                      help="number of independent run")
    parser.add_option("--number_of_images_per_group",
                      type="int",
                      default=1000,
                      help="number of groups")
    parser.add_option(
        "--low_pass_filter",
        type="float",
        default=-1.0,
        help=
        "absolute frequency of low-pass filter for 3d sorting on the original image size"
    )
    parser.add_option("--nxinit",
                      type="int",
                      default=64,
                      help="initial image size for sorting")
    parser.add_option("--unaccounted",
                      action="store_true",
                      default=False,
                      help="reconstruct the unaccounted images")
    parser.add_option(
        "--seed",
        type="int",
        default=-1,
        help="random seed for create initial random assignment for EQ Kmeans")
    parser.add_option("--smallest_group",
                      type="int",
                      default=500,
                      help="minimum members for identified group")
    parser.add_option("--sausage",
                      action="store_true",
                      default=False,
                      help="way of filter volume")
    parser.add_option("--chunk0",
                      type="string",
                      default='',
                      help="chunk0 for computing margin of error")
    parser.add_option("--chunk1",
                      type="string",
                      default='',
                      help="chunk1 for computing margin of error")
    parser.add_option(
        "--PWadjustment",
        type="string",
        default='',
        help=
        "1-D power spectrum of PDB file used for EM volume power spectrum correction"
    )
    parser.add_option(
        "--protein_shape",
        type="string",
        default='g',
        help=
        "protein shape. It defines protein preferred orientation angles. Currently it has g and f two types "
    )
    parser.add_option(
        "--upscale",
        type="float",
        default=0.5,
        help=" scaling parameter to adjust the power spectrum of EM volumes")
    parser.add_option("--wn",
                      type="int",
                      default=0,
                      help="optimal window size for data processing")
    parser.add_option(
        "--interpolation",
        type="string",
        default="4nn",
        help="3-d reconstruction interpolation method, two options trl and 4nn"
    )

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 1 or len(args) > 4:
        sxprint("Usage: " + usage)
        sxprint("Please run \'" + progname + " -h\' for detailed options")
        ERROR(
            "Invalid number of parameters used. Please see usage information above."
        )
        return

    else:

        if len(args) > 2:
            mask_file = args[2]
        else:
            mask_file = None

        orgstack = args[0]
        masterdir = args[1]
        sp_global_def.BATCH = True
        #---initialize MPI related variables
        nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        mpi_comm = mpi.MPI_COMM_WORLD
        main_node = 0
        # import some utilities
        from sp_utilities import get_im, bcast_number_to_all, cmdexecute, write_text_file, read_text_file, wrap_mpi_bcast, get_params_proj, write_text_row
        from sp_applications import recons3d_n_MPI, mref_ali3d_MPI, Kmref_ali3d_MPI
        from sp_statistics import k_means_match_clusters_asg_new, k_means_stab_bbenum
        from sp_applications import mref_ali3d_EQ_Kmeans, ali3d_mref_Kmeans_MPI
        # Create the main log file
        from sp_logger import Logger, BaseLogger_Files
        if myid == main_node:
            log_main = Logger(BaseLogger_Files())
            log_main.prefix = masterdir + "/"
        else:
            log_main = None
        #--- fill input parameters into dictionary named after Constants
        Constants = {}
        Constants["stack"] = args[0]
        Constants["masterdir"] = masterdir
        Constants["mask3D"] = mask_file
        Constants["focus3Dmask"] = options.focus
        Constants["indep_runs"] = options.independent
        Constants["stoprnct"] = options.stoprnct
        Constants[
            "number_of_images_per_group"] = options.number_of_images_per_group
        Constants["CTF"] = options.CTF
        Constants["maxit"] = options.maxit
        Constants["ir"] = options.ir
        Constants["radius"] = options.radius
        Constants["nassign"] = options.nassign
        Constants["rs"] = options.rs
        Constants["xr"] = options.xr
        Constants["yr"] = options.yr
        Constants["ts"] = options.ts
        Constants["delta"] = options.delta
        Constants["an"] = options.an
        Constants["sym"] = options.sym
        Constants["center"] = options.center
        Constants["nrefine"] = options.nrefine
        #Constants["fourvar"]            		 = options.fourvar
        Constants["user_func"] = options.function
        Constants[
            "low_pass_filter"] = options.low_pass_filter  # enforced low_pass_filter
        #Constants["debug"]              		 = options.debug
        Constants["main_log_prefix"] = args[1]
        #Constants["importali3d"]        		 = options.importali3d
        Constants["myid"] = myid
        Constants["main_node"] = main_node
        Constants["nproc"] = nproc
        Constants["log_main"] = log_main
        Constants["nxinit"] = options.nxinit
        Constants["unaccounted"] = options.unaccounted
        Constants["seed"] = options.seed
        Constants["smallest_group"] = options.smallest_group
        Constants["sausage"] = options.sausage
        Constants["chunk0"] = options.chunk0
        Constants["chunk1"] = options.chunk1
        Constants["PWadjustment"] = options.PWadjustment
        Constants["upscale"] = options.upscale
        Constants["wn"] = options.wn
        Constants["3d-interpolation"] = options.interpolation
        Constants["protein_shape"] = options.protein_shape
        # -----------------------------------------------------
        #
        # Create and initialize Tracker dictionary with input options
        Tracker = {}
        Tracker["constants"] = Constants
        Tracker["maxit"] = Tracker["constants"]["maxit"]
        Tracker["radius"] = Tracker["constants"]["radius"]
        #Tracker["xr"]             = ""
        #Tracker["yr"]             = "-1"  # Do not change!
        #Tracker["ts"]             = 1
        #Tracker["an"]             = "-1"
        #Tracker["delta"]          = "2.0"
        #Tracker["zoom"]           = True
        #Tracker["nsoft"]          = 0
        #Tracker["local"]          = False
        #Tracker["PWadjustment"]   = Tracker["constants"]["PWadjustment"]
        Tracker["upscale"] = Tracker["constants"]["upscale"]
        #Tracker["upscale"]        = 0.5
        Tracker[
            "applyctf"] = False  #  Should the data be premultiplied by the CTF.  Set to False for local continuous.
        #Tracker["refvol"]         = None
        Tracker["nxinit"] = Tracker["constants"]["nxinit"]
        #Tracker["nxstep"]         = 32
        Tracker["icurrentres"] = -1
        #Tracker["ireachedres"]    = -1
        #Tracker["lowpass"]        = 0.4
        #Tracker["falloff"]        = 0.2
        #Tracker["inires"]         = options.inires  # Now in A, convert to absolute before using
        Tracker["fuse_freq"] = 50  # Now in A, convert to absolute before using
        #Tracker["delpreviousmax"] = False
        #Tracker["anger"]          = -1.0
        #Tracker["shifter"]        = -1.0
        #Tracker["saturatecrit"]   = 0.95
        #Tracker["pixercutoff"]    = 2.0
        #Tracker["directory"]      = ""
        #Tracker["previousoutputdir"] = ""
        #Tracker["eliminated-outliers"] = False
        #Tracker["mainiteration"]  = 0
        #Tracker["movedback"]      = False
        #Tracker["state"]          = Tracker["constants"]["states"][0]
        #Tracker["global_resolution"] =0.0
        Tracker["orgstack"] = orgstack
        #--------------------------------------------------------------------
        # import from utilities
        from sp_utilities import sample_down_1D_curve, get_initial_ID, remove_small_groups, print_upper_triangular_matrix, print_a_line_with_timestamp
        from sp_utilities import print_dict, get_resolution_mrk01, partition_to_groups, partition_independent_runs, get_outliers
        from sp_utilities import merge_groups, save_alist, margin_of_error, get_margin_of_error, do_two_way_comparison, select_two_runs, get_ali3d_params
        from sp_utilities import counting_projections, unload_dict, load_dict, get_stat_proj, create_random_list, get_number_of_groups, recons_mref
        from sp_utilities import apply_low_pass_filter, get_groups_from_partition, get_number_of_groups, get_complementary_elements_total, update_full_dict
        from sp_utilities import count_chunk_members, set_filter_parameters_from_adjusted_fsc, get_two_chunks_from_stack
        ####------------------------------------------------------------------
        #
        # Get the pixel size; if none, set to 1.0, and the original image size
        from sp_utilities import get_shrink_data_huang
        if (myid == main_node):
            line = strftime("%Y-%m-%d_%H:%M:%S", localtime()) + " =>"
            sxprint((line + "Initialization of 3-D sorting"))
            a = get_im(orgstack)
            nnxo = a.get_xsize()
            if (Tracker["nxinit"] > nnxo):
                sp_global_def.ERROR(
                    "Image size less than minimum permitted $d" %
                    Tracker["nxinit"])
                nnxo = -1
            else:
                if Tracker["constants"]["CTF"]:
                    i = a.get_attr('ctf')
                    pixel_size = i.apix
                    fq = pixel_size / Tracker["fuse_freq"]
                else:
                    pixel_size = 1.0
                    #  No pixel size, fusing computed as 5 Fourier pixels
                    fq = 5.0 / nnxo
                    del a
        else:
            nnxo = 0
            fq = 0.0
            pixel_size = 1.0
        nnxo = bcast_number_to_all(nnxo, source_node=main_node)
        if (nnxo < 0):
            return
        pixel_size = bcast_number_to_all(pixel_size, source_node=main_node)
        fq = bcast_number_to_all(fq, source_node=main_node)
        if Tracker["constants"]["wn"] == 0:
            Tracker["constants"]["nnxo"] = nnxo
        else:
            Tracker["constants"]["nnxo"] = Tracker["constants"]["wn"]
            nnxo = Tracker["constants"]["nnxo"]
        Tracker["constants"]["pixel_size"] = pixel_size
        Tracker["fuse_freq"] = fq
        del fq, nnxo, pixel_size
        if (Tracker["constants"]["radius"] < 1):
            Tracker["constants"][
                "radius"] = Tracker["constants"]["nnxo"] // 2 - 2
        elif ((2 * Tracker["constants"]["radius"] + 2) >
              Tracker["constants"]["nnxo"]):
            sp_global_def.ERROR("Particle radius set too large!", myid=myid)


####-----------------------------------------------------------------------------------------
# Master directory
        if myid == main_node:
            if masterdir == "":
                timestring = strftime("_%d_%b_%Y_%H_%M_%S", localtime())
                masterdir = "master_sort3d" + timestring
            li = len(masterdir)
            cmd = "{} {}".format("mkdir -p", masterdir)
            os.system(cmd)
        else:
            li = 0
        li = mpi.mpi_bcast(li, 1, mpi.MPI_INT, main_node,
                           mpi.MPI_COMM_WORLD)[0]
        if li > 0:
            masterdir = mpi.mpi_bcast(masterdir, li, mpi.MPI_CHAR, main_node,
                                      mpi.MPI_COMM_WORLD)
            import string
            masterdir = string.join(masterdir, "")
        if myid == main_node:
            print_dict(Tracker["constants"],
                       "Permanent settings of 3-D sorting program")
        ######### create a vstack from input stack to the local stack in masterdir
        # stack name set to default
        Tracker["constants"]["stack"] = "bdb:" + masterdir + "/rdata"
        Tracker["constants"]["ali3d"] = os.path.join(masterdir,
                                                     "ali3d_init.txt")
        Tracker["constants"]["ctf_params"] = os.path.join(
            masterdir, "ctf_params.txt")
        Tracker["constants"]["partstack"] = Tracker["constants"][
            "ali3d"]  # also serves for refinement
        if myid == main_node:
            total_stack = EMUtil.get_image_count(Tracker["orgstack"])
        else:
            total_stack = 0
        total_stack = bcast_number_to_all(total_stack, source_node=main_node)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        from time import sleep
        while not os.path.exists(masterdir):
            sxprint("Node ", myid, "  waiting...")
            sleep(5)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        if myid == main_node:
            log_main.add("Sphire sort3d ")
            log_main.add("the sort3d master directory is " + masterdir)
        #####
        ###----------------------------------------------------------------------------------
        # Initial data analysis and handle two chunk files
        from random import shuffle
        # Compute the resolution
        #### make chunkdir dictionary for computing margin of error
        import sp_user_functions
        user_func = sp_user_functions.factory[Tracker["constants"]
                                              ["user_func"]]
        chunk_dict = {}
        chunk_list = []
        if myid == main_node:
            chunk_one = read_text_file(Tracker["constants"]["chunk0"])
            chunk_two = read_text_file(Tracker["constants"]["chunk1"])
        else:
            chunk_one = 0
            chunk_two = 0
        chunk_one = wrap_mpi_bcast(chunk_one, main_node)
        chunk_two = wrap_mpi_bcast(chunk_two, main_node)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        ######################## Read/write bdb: data on main node ############################
        if myid == main_node:
            if (orgstack[:4] == "bdb:"):
                cmd = "{} {} {}".format(
                    "e2bdb.py", orgstack,
                    "--makevstack=" + Tracker["constants"]["stack"])
            else:
                cmd = "{} {} {}".format("sp_cpy.py", orgstack,
                                        Tracker["constants"]["stack"])
            junk = cmdexecute(cmd)
            cmd = "{} {} {}".format(
                "sp_header.py  --params=xform.projection",
                "--export=" + Tracker["constants"]["ali3d"], orgstack)
            junk = cmdexecute(cmd)
            cmd = "{} {} {}".format(
                "sp_header.py  --params=ctf",
                "--export=" + Tracker["constants"]["ctf_params"], orgstack)
            junk = cmdexecute(cmd)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        ########-----------------------------------------------------------------------------
        Tracker["total_stack"] = total_stack
        Tracker["constants"]["total_stack"] = total_stack
        Tracker["shrinkage"] = float(
            Tracker["nxinit"]) / Tracker["constants"]["nnxo"]
        Tracker[
            "radius"] = Tracker["constants"]["radius"] * Tracker["shrinkage"]
        if Tracker["constants"]["mask3D"]:
            Tracker["mask3D"] = os.path.join(masterdir, "smask.hdf")
        else:
            Tracker["mask3D"] = None
        if Tracker["constants"]["focus3Dmask"]:
            Tracker["focus3D"] = os.path.join(masterdir, "sfocus.hdf")
        else:
            Tracker["focus3D"] = None
        if myid == main_node:
            if Tracker["constants"]["mask3D"]:
                mask_3D = get_shrink_3dmask(Tracker["nxinit"],
                                            Tracker["constants"]["mask3D"])
                mask_3D.write_image(Tracker["mask3D"])
            if Tracker["constants"]["focus3Dmask"]:
                mask_3D = get_shrink_3dmask(
                    Tracker["nxinit"], Tracker["constants"]["focus3Dmask"])
                st = Util.infomask(mask_3D, None, True)
                if (st[0] == 0.0):
                    ERROR(
                        "Incorrect focused mask, after binarize all values zero"
                    )
                mask_3D.write_image(Tracker["focus3D"])
                del mask_3D
        if Tracker["constants"]["PWadjustment"] != '':
            PW_dict = {}
            nxinit_pwsp = sample_down_1D_curve(
                Tracker["constants"]["nxinit"], Tracker["constants"]["nnxo"],
                Tracker["constants"]["PWadjustment"])
            Tracker["nxinit_PW"] = os.path.join(masterdir, "spwp.txt")
            if myid == main_node:
                write_text_file(nxinit_pwsp, Tracker["nxinit_PW"])
            PW_dict[Tracker["constants"]
                    ["nnxo"]] = Tracker["constants"]["PWadjustment"]
            PW_dict[Tracker["constants"]["nxinit"]] = Tracker["nxinit_PW"]
            Tracker["PW_dict"] = PW_dict
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        #-----------------------From two chunks to FSC, and low pass filter-----------------------------------------###
        for element in chunk_one:
            chunk_dict[element] = 0
        for element in chunk_two:
            chunk_dict[element] = 1
        chunk_list = [chunk_one, chunk_two]
        Tracker["chunk_dict"] = chunk_dict
        Tracker["P_chunk0"] = len(chunk_one) / float(total_stack)
        Tracker["P_chunk1"] = len(chunk_two) / float(total_stack)
        ### create two volumes to estimate resolution
        if myid == main_node:
            for index in range(2):
                write_text_file(
                    chunk_list[index],
                    os.path.join(masterdir, "chunk%01d.txt" % index))
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        vols = []
        for index in range(2):
            data, old_shifts = get_shrink_data_huang(
                Tracker,
                Tracker["constants"]["nxinit"],
                os.path.join(masterdir, "chunk%01d.txt" % index),
                Tracker["constants"]["partstack"],
                myid,
                main_node,
                nproc,
                preshift=True)
            vol = recons3d_4nn_ctf_MPI(myid=myid,
                                       prjlist=data,
                                       symmetry=Tracker["constants"]["sym"],
                                       finfo=None)
            if myid == main_node:
                vol.write_image(os.path.join(masterdir, "vol%d.hdf" % index))
            vols.append(vol)
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        if myid == main_node:
            low_pass, falloff, currentres = get_resolution_mrk01(
                vols, Tracker["constants"]["radius"],
                Tracker["constants"]["nxinit"], masterdir, Tracker["mask3D"])
            if low_pass > Tracker["constants"]["low_pass_filter"]:
                low_pass = Tracker["constants"]["low_pass_filter"]
        else:
            low_pass = 0.0
            falloff = 0.0
            currentres = 0.0
        bcast_number_to_all(currentres, source_node=main_node)
        bcast_number_to_all(low_pass, source_node=main_node)
        bcast_number_to_all(falloff, source_node=main_node)
        Tracker["currentres"] = currentres
        Tracker["falloff"] = falloff
        if Tracker["constants"]["low_pass_filter"] == -1.0:
            Tracker["low_pass_filter"] = min(
                .45, low_pass / Tracker["shrinkage"])  # no better than .45
        else:
            Tracker["low_pass_filter"] = min(
                .45,
                Tracker["constants"]["low_pass_filter"] / Tracker["shrinkage"])
        Tracker["lowpass"] = Tracker["low_pass_filter"]
        Tracker["falloff"] = .1
        Tracker["global_fsc"] = os.path.join(masterdir, "fsc.txt")
        ############################################################################################
        if myid == main_node:
            log_main.add("The command-line inputs are as following:")
            log_main.add(
                "**********************************************************")
        for a in sys.argv:
            if myid == main_node: log_main.add(a)
        if myid == main_node:
            log_main.add("number of cpus used in this run is %d" %
                         Tracker["constants"]["nproc"])
            log_main.add(
                "**********************************************************")
        from sp_filter import filt_tanl
        ### START 3-D sorting
        if myid == main_node:
            log_main.add("----------3-D sorting  program------- ")
            log_main.add(
                "current resolution %6.3f for images of original size in terms of absolute frequency"
                % Tracker["currentres"])
            log_main.add("equivalent to %f Angstrom resolution" %
                         (Tracker["constants"]["pixel_size"] /
                          Tracker["currentres"] / Tracker["shrinkage"]))
            log_main.add("the user provided enforced low_pass_filter is %f" %
                         Tracker["constants"]["low_pass_filter"])
            #log_main.add("equivalent to %f Angstrom resolution"%(Tracker["constants"]["pixel_size"]/Tracker["constants"]["low_pass_filter"]))
            for index in range(2):
                filt_tanl(
                    get_im(os.path.join(masterdir, "vol%01d.hdf" % index)),
                    Tracker["low_pass_filter"],
                    Tracker["falloff"]).write_image(
                        os.path.join(masterdir, "volf%01d.hdf" % index))
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        from sp_utilities import get_input_from_string
        delta = get_input_from_string(Tracker["constants"]["delta"])
        delta = delta[0]
        from sp_utilities import even_angles
        n_angles = even_angles(delta, 0, 180)
        this_ali3d = Tracker["constants"]["ali3d"]
        sampled = get_stat_proj(Tracker, delta, this_ali3d)
        if myid == main_node:
            nc = 0
            for a in sampled:
                if len(sampled[a]) > 0:
                    nc += 1
            log_main.add("total sampled direction %10d  at angle step %6.3f" %
                         (len(n_angles), delta))
            log_main.add(
                "captured sampled directions %10d percentage covered by data  %6.3f"
                % (nc, float(nc) / len(n_angles) * 100))
        number_of_images_per_group = Tracker["constants"][
            "number_of_images_per_group"]
        if myid == main_node:
            log_main.add("user provided number_of_images_per_group %d" %
                         number_of_images_per_group)
        Tracker["number_of_images_per_group"] = number_of_images_per_group
        number_of_groups = get_number_of_groups(total_stack,
                                                number_of_images_per_group)
        Tracker["number_of_groups"] = number_of_groups
        generation = 0
        partition_dict = {}
        full_dict = {}
        workdir = os.path.join(masterdir, "generation%03d" % generation)
        Tracker["this_dir"] = workdir
        if myid == main_node:
            log_main.add("---- generation         %5d" % generation)
            log_main.add("number of images per group is set as %d" %
                         number_of_images_per_group)
            log_main.add("the initial number of groups is  %10d " %
                         number_of_groups)
            cmd = "{} {}".format("mkdir", workdir)
            os.system(cmd)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        list_to_be_processed = list(range(Tracker["constants"]["total_stack"]))
        Tracker["this_data_list"] = list_to_be_processed
        create_random_list(Tracker)
        #################################
        full_dict = {}
        for iptl in range(Tracker["constants"]["total_stack"]):
            full_dict[iptl] = iptl
        Tracker["full_ID_dict"] = full_dict
        #################################
        for indep_run in range(Tracker["constants"]["indep_runs"]):
            Tracker["this_particle_list"] = Tracker["this_indep_list"][
                indep_run]
            ref_vol = recons_mref(Tracker)
            if myid == main_node:
                log_main.add("independent run  %10d" % indep_run)
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            Tracker["this_data_list"] = list_to_be_processed
            Tracker["total_stack"] = len(Tracker["this_data_list"])
            Tracker["this_particle_text_file"] = os.path.join(
                workdir,
                "independent_list_%03d.txt" % indep_run)  # for get_shrink_data
            if myid == main_node:
                write_text_file(Tracker["this_data_list"],
                                Tracker["this_particle_text_file"])
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            outdir = os.path.join(workdir, "EQ_Kmeans%03d" % indep_run)
            ref_vol = apply_low_pass_filter(ref_vol, Tracker)
            mref_ali3d_EQ_Kmeans(ref_vol, outdir,
                                 Tracker["this_particle_text_file"], Tracker)
            partition_dict[indep_run] = Tracker["this_partition"]
        Tracker["partition_dict"] = partition_dict
        Tracker["total_stack"] = len(Tracker["this_data_list"])
        Tracker["this_total_stack"] = Tracker["total_stack"]
        ###############################
        do_two_way_comparison(Tracker)
        ###############################
        ref_vol_list = []
        from time import sleep
        number_of_ref_class = []
        for igrp in range(len(Tracker["two_way_stable_member"])):
            Tracker["this_data_list"] = Tracker["two_way_stable_member"][igrp]
            Tracker["this_data_list_file"] = os.path.join(
                workdir, "stable_class%d.txt" % igrp)
            if myid == main_node:
                write_text_file(Tracker["this_data_list"],
                                Tracker["this_data_list_file"])
            data, old_shifts = get_shrink_data_huang(
                Tracker,
                Tracker["nxinit"],
                Tracker["this_data_list_file"],
                Tracker["constants"]["partstack"],
                myid,
                main_node,
                nproc,
                preshift=True)
            volref = recons3d_4nn_ctf_MPI(myid=myid,
                                          prjlist=data,
                                          symmetry=Tracker["constants"]["sym"],
                                          finfo=None)
            ref_vol_list.append(volref)
            number_of_ref_class.append(len(Tracker["this_data_list"]))
            if myid == main_node:
                log_main.add("group  %d  members %d " %
                             (igrp, len(Tracker["this_data_list"])))
        Tracker["number_of_ref_class"] = number_of_ref_class
        nx_of_image = ref_vol_list[0].get_xsize()
        if Tracker["constants"]["PWadjustment"]:
            Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
        else:
            Tracker["PWadjustment"] = Tracker["constants"][
                "PWadjustment"]  # no PW adjustment
        if myid == main_node:
            for iref in range(len(ref_vol_list)):
                refdata = [None] * 4
                refdata[0] = ref_vol_list[iref]
                refdata[1] = Tracker
                refdata[2] = Tracker["constants"]["myid"]
                refdata[3] = Tracker["constants"]["nproc"]
                volref = user_func(refdata)
                volref.write_image(os.path.join(workdir, "volf_stable.hdf"),
                                   iref)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        Tracker["this_data_list"] = Tracker["this_accounted_list"]
        outdir = os.path.join(workdir, "Kmref")
        empty_group, res_groups, final_list = ali3d_mref_Kmeans_MPI(
            ref_vol_list, outdir, Tracker["this_accounted_text"], Tracker)
        Tracker["this_unaccounted_list"] = get_complementary_elements(
            list_to_be_processed, final_list)
        if myid == main_node:
            log_main.add("the number of particles not processed is %d" %
                         len(Tracker["this_unaccounted_list"]))
            write_text_file(Tracker["this_unaccounted_list"],
                            Tracker["this_unaccounted_text"])
        update_full_dict(Tracker["this_unaccounted_list"], Tracker)
        #######################################
        number_of_groups = len(res_groups)
        vol_list = []
        number_of_ref_class = []
        for igrp in range(number_of_groups):
            data, old_shifts = get_shrink_data_huang(
                Tracker,
                Tracker["constants"]["nnxo"],
                os.path.join(outdir, "Class%d.txt" % igrp),
                Tracker["constants"]["partstack"],
                myid,
                main_node,
                nproc,
                preshift=True)
            volref = recons3d_4nn_ctf_MPI(myid=myid,
                                          prjlist=data,
                                          symmetry=Tracker["constants"]["sym"],
                                          finfo=None)
            vol_list.append(volref)

            if (myid == main_node):
                npergroup = len(
                    read_text_file(os.path.join(outdir, "Class%d.txt" % igrp)))
            else:
                npergroup = 0
            npergroup = bcast_number_to_all(npergroup, main_node)
            number_of_ref_class.append(npergroup)

        Tracker["number_of_ref_class"] = number_of_ref_class

        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        nx_of_image = vol_list[0].get_xsize()
        if Tracker["constants"]["PWadjustment"]:
            Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
        else:
            Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]

        if myid == main_node:
            for ivol in range(len(vol_list)):
                refdata = [None] * 4
                refdata[0] = vol_list[ivol]
                refdata[1] = Tracker
                refdata[2] = Tracker["constants"]["myid"]
                refdata[3] = Tracker["constants"]["nproc"]
                volref = user_func(refdata)
                volref.write_image(
                    os.path.join(workdir, "volf_of_Classes.hdf"), ivol)
                log_main.add("number of unaccounted particles  %10d" %
                             len(Tracker["this_unaccounted_list"]))
                log_main.add("number of accounted particles  %10d" %
                             len(Tracker["this_accounted_list"]))

        Tracker["this_data_list"] = Tracker[
            "this_unaccounted_list"]  # reset parameters for the next round calculation
        Tracker["total_stack"] = len(Tracker["this_unaccounted_list"])
        Tracker["this_total_stack"] = Tracker["total_stack"]
        number_of_groups = get_number_of_groups(
            len(Tracker["this_unaccounted_list"]), number_of_images_per_group)
        Tracker["number_of_groups"] = number_of_groups
        while number_of_groups >= 2:
            generation += 1
            partition_dict = {}
            workdir = os.path.join(masterdir, "generation%03d" % generation)
            Tracker["this_dir"] = workdir
            if myid == main_node:
                log_main.add("*********************************************")
                log_main.add("-----    generation             %5d    " %
                             generation)
                log_main.add("number of images per group is set as %10d " %
                             number_of_images_per_group)
                log_main.add("the number of groups is  %10d " %
                             number_of_groups)
                log_main.add(" number of particles for clustering is %10d" %
                             Tracker["total_stack"])
                cmd = "{} {}".format("mkdir", workdir)
                os.system(cmd)
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            create_random_list(Tracker)
            for indep_run in range(Tracker["constants"]["indep_runs"]):
                Tracker["this_particle_list"] = Tracker["this_indep_list"][
                    indep_run]
                ref_vol = recons_mref(Tracker)
                if myid == main_node:
                    log_main.add("independent run  %10d" % indep_run)
                    outdir = os.path.join(workdir, "EQ_Kmeans%03d" % indep_run)
                Tracker["this_data_list"] = Tracker["this_unaccounted_list"]
                #ref_vol=apply_low_pass_filter(ref_vol,Tracker)
                mref_ali3d_EQ_Kmeans(ref_vol, outdir,
                                     Tracker["this_unaccounted_text"], Tracker)
                partition_dict[indep_run] = Tracker["this_partition"]
                Tracker["this_data_list"] = Tracker["this_unaccounted_list"]
                Tracker["total_stack"] = len(Tracker["this_unaccounted_list"])
                Tracker["partition_dict"] = partition_dict
                Tracker["this_total_stack"] = Tracker["total_stack"]
            total_list_of_this_run = Tracker["this_unaccounted_list"]
            ###############################
            do_two_way_comparison(Tracker)
            ###############################
            ref_vol_list = []
            number_of_ref_class = []
            for igrp in range(len(Tracker["two_way_stable_member"])):
                Tracker["this_data_list"] = Tracker["two_way_stable_member"][
                    igrp]
                Tracker["this_data_list_file"] = os.path.join(
                    workdir, "stable_class%d.txt" % igrp)
                if myid == main_node:
                    write_text_file(Tracker["this_data_list"],
                                    Tracker["this_data_list_file"])
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
                data, old_shifts = get_shrink_data_huang(
                    Tracker,
                    Tracker["constants"]["nxinit"],
                    Tracker["this_data_list_file"],
                    Tracker["constants"]["partstack"],
                    myid,
                    main_node,
                    nproc,
                    preshift=True)
                volref = recons3d_4nn_ctf_MPI(
                    myid=myid,
                    prjlist=data,
                    symmetry=Tracker["constants"]["sym"],
                    finfo=None)
                #volref = filt_tanl(volref, Tracker["constants"]["low_pass_filter"],.1)
                if myid == main_node:
                    volref.write_image(os.path.join(workdir, "vol_stable.hdf"),
                                       iref)
                #volref = resample(volref,Tracker["shrinkage"])
                ref_vol_list.append(volref)
                number_of_ref_class.append(len(Tracker["this_data_list"]))
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            Tracker["number_of_ref_class"] = number_of_ref_class
            Tracker["this_data_list"] = Tracker["this_accounted_list"]
            outdir = os.path.join(workdir, "Kmref")
            empty_group, res_groups, final_list = ali3d_mref_Kmeans_MPI(
                ref_vol_list, outdir, Tracker["this_accounted_text"], Tracker)
            # calculate the 3-D structure of original image size for each group
            number_of_groups = len(res_groups)
            Tracker["this_unaccounted_list"] = get_complementary_elements(
                total_list_of_this_run, final_list)
            if myid == main_node:
                log_main.add("the number of particles not processed is %d" %
                             len(Tracker["this_unaccounted_list"]))
                write_text_file(Tracker["this_unaccounted_list"],
                                Tracker["this_unaccounted_text"])
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            update_full_dict(Tracker["this_unaccounted_list"], Tracker)
            vol_list = []
            for igrp in range(number_of_groups):
                data, old_shifts = get_shrink_data_huang(
                    Tracker,
                    Tracker["constants"]["nnxo"],
                    os.path.join(outdir, "Class%d.txt" % igrp),
                    Tracker["constants"]["partstack"],
                    myid,
                    main_node,
                    nproc,
                    preshift=True)
                volref = recons3d_4nn_ctf_MPI(
                    myid=myid,
                    prjlist=data,
                    symmetry=Tracker["constants"]["sym"],
                    finfo=None)
                vol_list.append(volref)

            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            nx_of_image = ref_vol_list[0].get_xsize()
            if Tracker["constants"]["PWadjustment"]:
                Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
            else:
                Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]

            if myid == main_node:
                for ivol in range(len(vol_list)):
                    refdata = [None] * 4
                    refdata[0] = vol_list[ivol]
                    refdata[1] = Tracker
                    refdata[2] = Tracker["constants"]["myid"]
                    refdata[3] = Tracker["constants"]["nproc"]
                    volref = user_func(refdata)
                    volref.write_image(
                        os.path.join(workdir, "volf_of_Classes.hdf"), ivol)
                log_main.add("number of unaccounted particles  %10d" %
                             len(Tracker["this_unaccounted_list"]))
                log_main.add("number of accounted particles  %10d" %
                             len(Tracker["this_accounted_list"]))
            del vol_list
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            number_of_groups = get_number_of_groups(
                len(Tracker["this_unaccounted_list"]),
                number_of_images_per_group)
            Tracker["number_of_groups"] = number_of_groups
            Tracker["this_data_list"] = Tracker["this_unaccounted_list"]
            Tracker["total_stack"] = len(Tracker["this_unaccounted_list"])
        if Tracker["constants"]["unaccounted"]:
            data, old_shifts = get_shrink_data_huang(
                Tracker,
                Tracker["constants"]["nnxo"],
                Tracker["this_unaccounted_text"],
                Tracker["constants"]["partstack"],
                myid,
                main_node,
                nproc,
                preshift=True)
            volref = recons3d_4nn_ctf_MPI(myid=myid,
                                          prjlist=data,
                                          symmetry=Tracker["constants"]["sym"],
                                          finfo=None)
            nx_of_image = volref.get_xsize()
            if Tracker["constants"]["PWadjustment"]:
                Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
            else:
                Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]
            if (myid == main_node):
                refdata = [None] * 4
                refdata[0] = volref
                refdata[1] = Tracker
                refdata[2] = Tracker["constants"]["myid"]
                refdata[3] = Tracker["constants"]["nproc"]
                volref = user_func(refdata)
                #volref    = filt_tanl(volref, Tracker["constants"]["low_pass_filter"],.1)
                volref.write_image(
                    os.path.join(workdir, "volf_unaccounted.hdf"))
        # Finish program
        if myid == main_node: log_main.add("sxsort3d finishes")
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        return
Ejemplo n.º 11
0
def main():
    def params_3D_2D_NEW(phi, theta, psi, s2x, s2y, mirror):
        """Convert 3D projection parameters into 2D alignment parameters.

        The resulting ali2d parameters already encode "shift first, then
        rotate" for parameters that originate from a 3D projection.

        phi, theta: accepted for call compatibility; not used here.
        psi: in-plane angle. Presumably in degrees, given the 360/540
            offsets below (TODO confirm).
        s2x, s2y: shifts handed to compose_transform2 as the first transform.
        mirror: truthy selects the mirrored convention (540 - psi instead of
            360 - psi).

        Returns (alpha, sx, sy, mirror_flag). mirror_flag is 1 when `mirror`
        is truthy, otherwise 0. The scale returned by compose_transform2 is
        discarded.
        """
        if mirror:
            mirror_flag = 1
            in_plane = 540.0 - psi
        else:
            mirror_flag = 0
            in_plane = 360.0 - psi
        # Compose (shift only) followed by (rotation only), both at unit scale.
        alpha, sx, sy, _unused_scale = sp_utilities.compose_transform2(
            0, s2x, s2y, 1.0, in_plane, 0, 0, 1.0)
        return alpha, sx, sy, mirror_flag

    progname = optparse.os.path.basename(sys.argv[0])
    usage = (
        progname +
        " prj_stack  --ave2D= --var2D=  --ave3D= --var3D= --img_per_grp= --fl=  --aa=   --sym=symmetry --CTF"
    )
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option("--output_dir",
                      type="string",
                      default="./",
                      help="Output directory")
    parser.add_option(
        "--ave2D",
        type="string",
        default=False,
        help="Write to the disk a stack of 2D averages",
    )
    parser.add_option(
        "--var2D",
        type="string",
        default=False,
        help="Write to the disk a stack of 2D variances",
    )
    parser.add_option(
        "--ave3D",
        type="string",
        default=False,
        help="Write to the disk reconstructed 3D average",
    )
    parser.add_option(
        "--var3D",
        type="string",
        default=False,
        help="Compute 3D variability (time consuming!)",
    )
    parser.add_option(
        "--img_per_grp",
        type="int",
        default=100,
        help="Number of neighbouring projections.(Default is 100)",
    )
    parser.add_option(
        "--no_norm",
        action="store_true",
        default=False,
        help="Do not use normalization.(Default is to apply normalization)",
    )
    # parser.add_option("--radius", 	    type="int"         ,	default=-1   ,				help="radius for 3D variability" )
    parser.add_option(
        "--npad",
        type="int",
        default=2,
        help=
        "Number of time to pad the original images.(Default is 2 times padding)",
    )
    parser.add_option("--sym",
                      type="string",
                      default="c1",
                      help="Symmetry. (Default is no symmetry)")
    parser.add_option(
        "--fl",
        type="float",
        default=0.0,
        help=
        "Low pass filter cutoff in absolute frequency (0.0 - 0.5) and is applied to decimated images. (Default - no filtration)",
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.02,
        help=
        "Fall off of the filter. Use default value if user has no clue about falloff (Default value is 0.02)",
    )
    parser.add_option(
        "--CTF",
        action="store_true",
        default=False,
        help="Use CFT correction.(Default is no CTF correction)",
    )
    # parser.add_option("--MPI" , 		action="store_true",	default=False,				help="use MPI version")
    # parser.add_option("--radiuspca", 	type="int"         ,	default=-1   ,				help="radius for PCA" )
    # parser.add_option("--iter", 		type="int"         ,	default=40   ,				help="maximum number of iterations (stop criterion of reconstruction process)" )
    # parser.add_option("--abs", 		type="float"   ,        default=0.0  ,				help="minimum average absolute change of voxels' values (stop criterion of reconstruction process)" )
    # parser.add_option("--squ", 		type="float"   ,	    default=0.0  ,				help="minimum average squared change of voxels' values (stop criterion of reconstruction process)" )
    parser.add_option(
        "--VAR",
        action="store_true",
        default=False,
        help="Stack of input consists of 2D variances (Default False)",
    )
    parser.add_option(
        "--decimate",
        type="float",
        default=0.25,
        help="Image decimate rate, a number less than 1. (Default is 0.25)",
    )
    parser.add_option(
        "--window",
        type="int",
        default=0,
        help=
        "Target image size relative to original image size. (Default value is zero.)",
    )
    # parser.add_option("--SND",			action="store_true",	default=False,				help="compute squared normalized differences (Default False)")
    # parser.add_option("--nvec",			type="int"         ,	default=0    ,				help="Number of eigenvectors, (Default = 0 meaning no PCA calculated)")
    parser.add_option(
        "--symmetrize",
        action="store_true",
        default=False,
        help="Prepare input stack for handling symmetry (Default False)",
    )
    parser.add_option("--overhead",
                      type="float",
                      default=0.5,
                      help="python overhead per CPU.")

    (options, args) = parser.parse_args()
    #####
    # from mpi import *

    #  This is code for handling symmetries by the above program.  To be incorporated. PAP 01/27/2015

    # Set up global variables related to bdb cache
    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    # Set up global variables related to ERROR function
    sp_global_def.BATCH = True

    # detect if program is running under MPI
    RUNNING_UNDER_MPI = "OMPI_COMM_WORLD_SIZE" in optparse.os.environ
    if RUNNING_UNDER_MPI:
        sp_global_def.MPI = True
    if options.output_dir == "./":
        current_output_dir = optparse.os.path.abspath(options.output_dir)
    else:
        current_output_dir = options.output_dir
    if options.symmetrize:

        if mpi.mpi_comm_size(mpi.MPI_COMM_WORLD) > 1:
            sp_global_def.ERROR(
                "Cannot use more than one CPU for symmetry preparation")

        if not optparse.os.path.exists(current_output_dir):
            optparse.os.makedirs(current_output_dir)
            sp_global_def.write_command(current_output_dir)

        if optparse.os.path.exists(
                optparse.os.path.join(current_output_dir, "log.txt")):
            optparse.os.remove(
                optparse.os.path.join(current_output_dir, "log.txt"))
        log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
        log_main.prefix = optparse.os.path.join(current_output_dir, "./")

        instack = args[0]
        sym = options.sym.lower()
        if sym == "c1":
            sp_global_def.ERROR(
                "There is no need to symmetrize stack for C1 symmetry")

        line = ""
        for a in sys.argv:
            line += " " + a
        log_main.add(line)

        if instack[:4] != "bdb:":
            # if output_dir =="./": stack = "bdb:data"
            stack = "bdb:" + current_output_dir + "/data"
            sp_utilities.delete_bdb(stack)
            junk = sp_utilities.cmdexecute("sp_cpy.py  " + instack + "  " +
                                           stack)
        else:
            stack = instack

        qt = EMAN2_cppwrap.EMUtil.get_all_attributes(stack, "xform.projection")

        na = len(qt)
        ts = sp_utilities.get_symt(sym)
        ks = len(ts)
        angsa = [None] * na

        for k in range(ks):
            # Qfile = "Q%1d"%k
            # if options.output_dir!="./": Qfile = os.path.join(options.output_dir,"Q%1d"%k)
            Qfile = optparse.os.path.join(current_output_dir, "Q%1d" % k)
            # delete_bdb("bdb:Q%1d"%k)
            sp_utilities.delete_bdb("bdb:" + Qfile)
            # junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            junk = sp_utilities.cmdexecute("e2bdb.py  " + stack +
                                           "  --makevstack=bdb:" + Qfile)
            # DB = db_open_dict("bdb:Q%1d"%k)
            DB = EMAN2db.db_open_dict("bdb:" + Qfile)
            for i in range(na):
                ut = qt[i] * ts[k]
                DB.set_attr(i, "xform.projection", ut)
                # bt = ut.get_params("spider")
                # angsa[i] = [round(bt["phi"],3)%360.0, round(bt["theta"],3)%360.0, bt["psi"], -bt["tx"], -bt["ty"]]
            # write_text_row(angsa, 'ptsma%1d.txt'%k)
            # junk = cmdexecute("e2bdb.py  "+stack+"  --makevstack=bdb:Q%1d"%k)
            # junk = cmdexecute("sxheader.py  bdb:Q%1d  --params=xform.projection  --import=ptsma%1d.txt"%(k,k))
            DB.close()
        # if options.output_dir =="./": delete_bdb("bdb:sdata")
        sp_utilities.delete_bdb("bdb:" + current_output_dir + "/" + "sdata")
        # junk = cmdexecute("e2bdb.py . --makevstack=bdb:sdata --filt=Q")
        sdata = "bdb:" + current_output_dir + "/" + "sdata"
        sp_global_def.sxprint(sdata)
        junk = sp_utilities.cmdexecute("e2bdb.py   " + current_output_dir +
                                       "  --makevstack=" + sdata + " --filt=Q")
        # junk = cmdexecute("ls  EMAN2DB/sdata*")
        # a = get_im("bdb:sdata")
        a = sp_utilities.get_im(sdata)
        a.set_attr("variabilitysymmetry", sym)
        # a.write_image("bdb:sdata")
        a.write_image(sdata)

    else:

        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        main_node = 0
        shared_comm = mpi.mpi_comm_split_type(mpi.MPI_COMM_WORLD,
                                              mpi.MPI_COMM_TYPE_SHARED, 0,
                                              mpi.MPI_INFO_NULL)
        myid_on_node = mpi.mpi_comm_rank(shared_comm)
        no_of_processes_per_group = mpi.mpi_comm_size(shared_comm)
        masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
            mpi.MPI_COMM_WORLD, main_node == myid_on_node, myid_on_node)
        color, no_of_groups, balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            main_node,
            mpi.MPI_COMM_WORLD,
            myid,
            shared_comm,
            myid_on_node,
            masters_from_groups_vs_everything_else_comm,
        )
        overhead_loading = options.overhead * number_of_proc
        # memory_per_node  = options.memory_per_node
        # if memory_per_node == -1.: memory_per_node = 2.*no_of_processes_per_group
        keepgoing = 1

        current_window = options.window
        current_decimate = options.decimate

        if len(args) == 1:
            stack = args[0]
        else:
            sp_global_def.sxprint("Usage: " + usage)
            sp_global_def.sxprint("Please run '" + progname +
                                  " -h' for detailed options")
            sp_global_def.ERROR(
                "Invalid number of parameters used. Please see usage information above."
            )
            return

        t0 = time.time()
        # obsolete flags
        options.MPI = True
        # options.nvec = 0
        options.radiuspca = -1
        options.iter = 40
        options.abs = 0.0
        options.squ = 0.0

        if options.fl > 0.0 and options.aa == 0.0:
            sp_global_def.ERROR(
                "Fall off has to be given for the low-pass filter", myid=myid)

        # if options.VAR and options.SND:
        # 	ERROR( "Only one of var and SND can be set!",myid=myid )

        if options.VAR and (options.ave2D or options.ave3D or options.var2D):
            sp_global_def.ERROR(
                "When VAR is set, the program cannot output ave2D, ave3D or var2D",
                myid=myid,
            )

        # if options.SND and (options.ave2D or options.ave3D):
        # 	ERROR( "When SND is set, the program cannot output ave2D or ave3D", myid=myid )

        # if options.nvec > 0 :
        # 	ERROR( "PCA option not implemented", myid=myid )

        # if options.nvec > 0 and options.ave3D == None:
        # 	ERROR( "When doing PCA analysis, one must set ave3D", myid=myid )

        if current_decimate > 1.0 or current_decimate < 0.0:
            sp_global_def.ERROR(
                "Decimate rate should be a value between 0.0 and 1.0",
                myid=myid)

        if current_window < 0.0:
            sp_global_def.ERROR(
                "Target window size should be always larger than zero",
                myid=myid)

        if myid == main_node:
            img = sp_utilities.get_image(stack, 0)
            nx = img.get_xsize()
            ny = img.get_ysize()
            if min(nx, ny) < current_window:
                keepgoing = 0
        keepgoing = sp_utilities.bcast_number_to_all(keepgoing, main_node,
                                                     mpi.MPI_COMM_WORLD)
        if keepgoing == 0:
            sp_global_def.ERROR(
                "The target window size cannot be larger than the size of decimated image",
                myid=myid,
            )

        options.sym = options.sym.lower()
        # if global_def.CACHE_DISABLE:
        # 	from utilities import disable_bdb_cache
        # 	disable_bdb_cache()
        # global_def.BATCH = True

        if myid == main_node:
            if not optparse.os.path.exists(current_output_dir):
                optparse.os.makedirs(
                    current_output_dir
                )  # Never delete output_dir in the program!

        img_per_grp = options.img_per_grp
        # nvec        = options.nvec
        radiuspca = options.radiuspca
        # if os.path.exists(os.path.join(options.output_dir, "log.txt")): os.remove(os.path.join(options.output_dir, "log.txt"))
        log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
        log_main.prefix = optparse.os.path.join(current_output_dir, "./")

        if myid == main_node:
            line = ""
            for a in sys.argv:
                line += " " + a
            log_main.add(line)
            log_main.add("-------->>>Settings given by all options<<<-------")
            log_main.add("Symmetry             : %s" % options.sym)
            log_main.add("Input stack          : %s" % stack)
            log_main.add("Output_dir           : %s" % current_output_dir)

            if options.ave3D:
                log_main.add("Ave3d                : %s" % options.ave3D)
            if options.var3D:
                log_main.add("Var3d                : %s" % options.var3D)
            if options.ave2D:
                log_main.add("Ave2D                : %s" % options.ave2D)
            if options.var2D:
                log_main.add("Var2D                : %s" % options.var2D)
            if options.VAR:
                log_main.add("VAR                  : True")
            else:
                log_main.add("VAR                  : False")
            if options.CTF:
                log_main.add("CTF correction       : True  ")
            else:
                log_main.add("CTF correction       : False ")

            log_main.add("Image per group      : %5d" % options.img_per_grp)
            log_main.add("Image decimate rate  : %4.3f" % current_decimate)
            log_main.add("Low pass filter      : %4.3f" % options.fl)
            current_fl = options.fl
            if current_fl == 0.0:
                current_fl = 0.5
            log_main.add(
                "Current low pass filter is equivalent to cutoff frequency %4.3f for original image size"
                % round((current_fl * current_decimate), 3))
            log_main.add("Window size          : %5d " % current_window)
            log_main.add("sx3dvariability begins")

        symbaselen = 0
        if myid == main_node:
            nima = EMAN2_cppwrap.EMUtil.get_image_count(stack)
            img = sp_utilities.get_image(stack)
            nx = img.get_xsize()
            ny = img.get_ysize()
            nnxo = nx
            nnyo = ny
            if options.sym != "c1":
                imgdata = sp_utilities.get_im(stack)
                try:
                    i = imgdata.get_attr("variabilitysymmetry").lower()
                    if i != options.sym:
                        sp_global_def.ERROR(
                            "The symmetry provided does not agree with the symmetry of the input stack",
                            myid=myid,
                        )
                except:
                    sp_global_def.ERROR(
                        "Input stack is not prepared for symmetry, please follow instructions",
                        myid=myid,
                    )
                i = len(sp_utilities.get_symt(options.sym))
                if (old_div(nima, i)) * i != nima:
                    sp_global_def.ERROR(
                        "The length of the input stack is incorrect for symmetry processing",
                        myid=myid,
                    )
                symbaselen = old_div(nima, i)
            else:
                symbaselen = nima
        else:
            nima = 0
            nx = 0
            ny = 0
            nnxo = 0
            nnyo = 0
        nima = sp_utilities.bcast_number_to_all(nima)
        nx = sp_utilities.bcast_number_to_all(nx)
        ny = sp_utilities.bcast_number_to_all(ny)
        nnxo = sp_utilities.bcast_number_to_all(nnxo)
        nnyo = sp_utilities.bcast_number_to_all(nnyo)
        if current_window > max(nx, ny):
            sp_global_def.ERROR(
                "Window size is larger than the original image size")

        if current_decimate == 1.0:
            if current_window != 0:
                nx = current_window
                ny = current_window
        else:
            if current_window == 0:
                nx = int(nx * current_decimate + 0.5)
                ny = int(ny * current_decimate + 0.5)
            else:
                nx = int(current_window * current_decimate + 0.5)
                ny = nx
        symbaselen = sp_utilities.bcast_number_to_all(symbaselen)

        # check FFT prime number
        is_fft_friendly = nx == sp_fundamentals.smallprime(nx)

        if not is_fft_friendly:
            if myid == main_node:
                log_main.add(
                    "The target image size is not a product of small prime numbers"
                )
                log_main.add("Program adjusts the input settings!")
            ### two cases
            if current_decimate == 1.0:
                nx = sp_fundamentals.smallprime(nx)
                ny = nx
                current_window = nx  # update
                if myid == main_node:
                    log_main.add("The window size is updated to %d." %
                                 current_window)
            else:
                if current_window == 0:
                    nx = sp_fundamentals.smallprime(
                        int(nx * current_decimate + 0.5))
                    current_decimate = old_div(float(nx), nnxo)
                    ny = nx
                    if myid == main_node:
                        log_main.add("The decimate rate is updated to %f." %
                                     current_decimate)
                else:
                    nx = sp_fundamentals.smallprime(
                        int(current_window * current_decimate + 0.5))
                    ny = nx
                    current_window = int(old_div(nx, current_decimate) + 0.5)
                    if myid == main_node:
                        log_main.add("The window size is updated to %d." %
                                     current_window)

        if myid == main_node:
            log_main.add("The target image size is %d" % nx)

        if radiuspca == -1:
            radiuspca = old_div(nx, 2) - 2
        if myid == main_node:
            log_main.add("%-70s:  %d\n" % ("Number of projection", nima))
        img_begin, img_end = sp_applications.MPI_start_end(
            nima, number_of_proc, myid)
        """Multiline Comment0"""
        """
        Comments from adnan, replace index_of_proj to index_of_particle, index_of_proj was not defined
        also varList is not defined not made an empty list there
        """

        if options.VAR:  # 2D variance images have no shifts
            varList = []
            # varList   = EMData.read_images(stack, range(img_begin, img_end))
            for index_of_particle in range(img_begin, img_end):
                image = sp_utilities.get_im(stack, index_of_particle)
                if current_window > 0:
                    varList.append(
                        sp_fundamentals.fdecimate(
                            sp_fundamentals.window2d(image, current_window,
                                                     current_window),
                            nx,
                            ny,
                        ))
                else:
                    varList.append(sp_fundamentals.fdecimate(image, nx, ny))

        else:
            if myid == main_node:
                t1 = time.time()
                proj_angles = []
                aveList = []
                tab = EMAN2_cppwrap.EMUtil.get_all_attributes(
                    stack, "xform.projection")
                for i in range(nima):
                    t = tab[i].get_params("spider")
                    phi = t["phi"]
                    theta = t["theta"]
                    psi = t["psi"]
                    x = theta
                    if x > 90.0:
                        x = 180.0 - x
                    x = x * 10000 + psi
                    proj_angles.append([x, t["phi"], t["theta"], t["psi"], i])
                t2 = time.time()
                log_main.add(
                    "%-70s:  %d\n" %
                    ("Number of neighboring projections", img_per_grp))
                log_main.add("...... Finding neighboring projections\n")
                log_main.add("Number of images per group: %d" % img_per_grp)
                log_main.add("Now grouping projections")
                proj_angles.sort()
                proj_angles_list = numpy.full((nima, 4),
                                              0.0,
                                              dtype=numpy.float32)
                for i in range(nima):
                    proj_angles_list[i][0] = proj_angles[i][1]
                    proj_angles_list[i][1] = proj_angles[i][2]
                    proj_angles_list[i][2] = proj_angles[i][3]
                    proj_angles_list[i][3] = proj_angles[i][4]
            else:
                proj_angles_list = 0
            proj_angles_list = sp_utilities.wrap_mpi_bcast(
                proj_angles_list, main_node, mpi.MPI_COMM_WORLD)
            proj_angles = []
            for i in range(nima):
                proj_angles.append([
                    proj_angles_list[i][0],
                    proj_angles_list[i][1],
                    proj_angles_list[i][2],
                    int(proj_angles_list[i][3]),
                ])
            del proj_angles_list
            proj_list, mirror_list = sp_utilities.nearest_proj(
                proj_angles, img_per_grp, range(img_begin, img_end))
            all_proj = []
            for im in proj_list:
                for jm in im:
                    all_proj.append(proj_angles[jm][3])
            all_proj = list(set(all_proj))
            index = {}
            for i in range(len(all_proj)):
                index[all_proj[i]] = i
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == main_node:
                log_main.add("%-70s:  %.2f\n" %
                             ("Finding neighboring projections lasted [s]",
                              time.time() - t2))
                log_main.add("%-70s:  %d\n" %
                             ("Number of groups processed on the main node",
                              len(proj_list)))
                log_main.add("Grouping projections took:  %12.1f [m]" %
                             (old_div((time.time() - t2), 60.0)))
                log_main.add("Number of groups on main node: ", len(proj_list))
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

            if myid == main_node:
                log_main.add("...... Calculating the stack of 2D variances \n")
            # Memory estimation. There are two memory consumption peaks
            # peak 1. Compute ave, var;
            # peak 2. Var volume reconstruction;
            # proj_params = [0.0]*(nima*5)
            aveList = []
            varList = []
            # if nvec > 0: eigList = [[] for i in range(nvec)]
            dnumber = len(
                all_proj)  # all neighborhood set for assigned to myid
            pnumber = len(proj_list) * 2.0 + img_per_grp  # aveList and varList
            tnumber = dnumber + pnumber
            vol_size2 = old_div(nx**3 * 4.0 * 8, 1.0e9)
            vol_size1 = old_div(2.0 * nnxo**3 * 4.0 * 8, 1.0e9)
            proj_size = old_div(nnxo * nnyo * len(proj_list) * 4.0 * 2.0,
                                1.0e9)  # both aveList and varList
            orig_data_size = old_div(nnxo * nnyo * 4.0 * tnumber, 1.0e9)
            reduced_data_size = old_div(nx * nx * 4.0 * tnumber, 1.0e9)
            full_data = numpy.full((number_of_proc, 2),
                                   -1.0,
                                   dtype=numpy.float16)
            full_data[myid] = orig_data_size, reduced_data_size
            if myid != main_node:
                sp_utilities.wrap_mpi_send(full_data, main_node,
                                           mpi.MPI_COMM_WORLD)
            if myid == main_node:
                for iproc in range(number_of_proc):
                    if iproc != main_node:
                        dummy = sp_utilities.wrap_mpi_recv(
                            iproc, mpi.MPI_COMM_WORLD)
                        full_data[numpy.where(dummy > -1)] = dummy[numpy.where(
                            dummy > -1)]
                del dummy
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            full_data = sp_utilities.wrap_mpi_bcast(full_data, main_node,
                                                    mpi.MPI_COMM_WORLD)
            # find the CPU with heaviest load
            minindx = numpy.argsort(full_data, 0)
            heavy_load_myid = minindx[-1][1]
            total_mem = sum(full_data)
            if myid == main_node:
                if current_window == 0:
                    log_main.add(
                        "Nx:   current image size = %d. Decimated by %f from %d"
                        % (nx, current_decimate, nnxo))
                else:
                    log_main.add(
                        "Nx:   current image size = %d. Windowed to %d, and decimated by %f from %d"
                        % (nx, current_window, current_decimate, nnxo))
                log_main.add("Nproj:       number of particle images.")
                log_main.add("Navg:        number of 2D average images.")
                log_main.add("Nvar:        number of 2D variance images.")
                log_main.add(
                    "Img_per_grp: user defined image per group for averaging = %d"
                    % img_per_grp)
                log_main.add(
                    "Overhead:    total python overhead memory consumption   = %f"
                    % overhead_loading)
                log_main.add(
                    "Total memory) = 4.0*nx^2*(nproj + navg +nvar+ img_per_grp)/1.0e9 + overhead: %12.3f [GB]"
                    % (total_mem[1] + overhead_loading))
            del full_data
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add(
                    "Begin reading and preprocessing images on processor. Wait... "
                )
                ttt = time.time()
            # imgdata = EMData.read_images(stack, all_proj)
            imgdata = [None for im in range(len(all_proj))]
            for index_of_proj in range(len(all_proj)):
                # image = get_im(stack, all_proj[index_of_proj])
                if current_window > 0:
                    imgdata[index_of_proj] = sp_fundamentals.fdecimate(
                        sp_fundamentals.window2d(
                            sp_utilities.get_im(stack,
                                                all_proj[index_of_proj]),
                            current_window,
                            current_window,
                        ),
                        nx,
                        ny,
                    )
                else:
                    imgdata[index_of_proj] = sp_fundamentals.fdecimate(
                        sp_utilities.get_im(stack, all_proj[index_of_proj]),
                        nx, ny)

                if current_decimate > 0.0 and options.CTF:
                    ctf = imgdata[index_of_proj].get_attr("ctf")
                    ctf.apix = old_div(ctf.apix, current_decimate)
                    imgdata[index_of_proj].set_attr("ctf", ctf)

                if myid == heavy_load_myid and index_of_proj % 100 == 0:
                    log_main.add(
                        " ...... %6.2f%% " %
                        (old_div(index_of_proj, float(len(all_proj))) * 100.0))
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if myid == heavy_load_myid:
                log_main.add("All_proj preprocessing cost %7.2f m" % (old_div(
                    (time.time() - ttt), 60.0)))
                log_main.add("Wait untill reading on all CPUs done...")
            """Multiline Comment1"""
            if not options.no_norm:
                mask = sp_utilities.model_circle(old_div(nx, 2) - 2, nx, nx)
            if myid == heavy_load_myid:
                log_main.add("Start computing 2D aveList and varList. Wait...")
                ttt = time.time()
            inner = old_div(nx, 2) - 4
            outer = inner + 2
            xform_proj_for_2D = [None for i in range(len(proj_list))]
            for i in range(len(proj_list)):
                ki = proj_angles[proj_list[i][0]][3]
                if ki >= symbaselen:
                    continue
                mi = index[ki]
                dpar = EMAN2_cppwrap.Util.get_transform_params(
                    imgdata[mi], "xform.projection", "spider")
                phiM, thetaM, psiM, s2xM, s2yM = (
                    dpar["phi"],
                    dpar["theta"],
                    dpar["psi"],
                    -dpar["tx"] * current_decimate,
                    -dpar["ty"] * current_decimate,
                )
                grp_imgdata = []
                for j in range(img_per_grp):
                    mj = index[proj_angles[proj_list[i][j]][3]]
                    cpar = EMAN2_cppwrap.Util.get_transform_params(
                        imgdata[mj], "xform.projection", "spider")
                    alpha, sx, sy, mirror = params_3D_2D_NEW(
                        cpar["phi"],
                        cpar["theta"],
                        cpar["psi"],
                        -cpar["tx"] * current_decimate,
                        -cpar["ty"] * current_decimate,
                        mirror_list[i][j],
                    )
                    if thetaM <= 90:
                        if mirror == 0:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha, sx, sy, 1.0, phiM - cpar["phi"], 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha,
                                sx,
                                sy,
                                1.0,
                                180 - (phiM - cpar["phi"]),
                                0.0,
                                0.0,
                                1.0,
                            )
                    else:
                        if mirror == 0:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha, sx, sy, 1.0, -(phiM - cpar["phi"]), 0.0,
                                0.0, 1.0)
                        else:
                            alpha, sx, sy, scale = sp_utilities.compose_transform2(
                                alpha,
                                sx,
                                sy,
                                1.0,
                                -(180 - (phiM - cpar["phi"])),
                                0.0,
                                0.0,
                                1.0,
                            )
                    imgdata[mj].set_attr(
                        "xform.align2d",
                        EMAN2_cppwrap.Transform({
                            "type": "2D",
                            "alpha": alpha,
                            "tx": sx,
                            "ty": sy,
                            "mirror": mirror,
                            "scale": 1.0,
                        }),
                    )
                    grp_imgdata.append(imgdata[mj])
                if not options.no_norm:
                    for k in range(img_per_grp):
                        ave, std, minn, maxx = EMAN2_cppwrap.Util.infomask(
                            grp_imgdata[k], mask, False)
                        grp_imgdata[k] -= ave
                        grp_imgdata[k] = old_div(grp_imgdata[k], std)
                if options.fl > 0.0:
                    for k in range(img_per_grp):
                        grp_imgdata[k] = sp_filter.filt_tanl(
                            grp_imgdata[k], options.fl, options.aa)

                #  Because of background issues, only linear option works.
                if options.CTF:
                    ave, var = sp_statistics.aves_wiener(
                        grp_imgdata, SNR=1.0e5, interpolation_method="linear")
                else:
                    ave, var = sp_statistics.ave_var(grp_imgdata)
                # Switch to std dev
                # threshold is not really needed,it is just in case due to numerical accuracy something turns out negative.
                var = sp_morphology.square_root(sp_morphology.threshold(var))

                sp_utilities.set_params_proj(ave,
                                             [phiM, thetaM, 0.0, 0.0, 0.0])
                sp_utilities.set_params_proj(var,
                                             [phiM, thetaM, 0.0, 0.0, 0.0])

                aveList.append(ave)
                varList.append(var)
                xform_proj_for_2D[i] = [phiM, thetaM, 0.0, 0.0, 0.0]
                """Multiline Comment2"""
                if (myid == heavy_load_myid) and (i % 100 == 0):
                    log_main.add(" ......%6.2f%%  " %
                                 (old_div(i, float(len(proj_list))) * 100.0))
            del imgdata, grp_imgdata, cpar, dpar, all_proj, proj_angles, index
            if not options.no_norm:
                del mask
            if myid == main_node:
                del tab
            #  At this point, all averages and variances are computed
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

            if myid == heavy_load_myid:
                log_main.add("Computing aveList and varList took %12.1f [m]" %
                             (old_div((time.time() - ttt), 60.0)))

            xform_proj_for_2D = sp_utilities.wrap_mpi_gatherv(
                xform_proj_for_2D, main_node, mpi.MPI_COMM_WORLD)
            if myid == main_node:
                sp_utilities.write_text_row(
                    [str(entry) for entry in xform_proj_for_2D],
                    optparse.os.path.join(current_output_dir, "params.txt"),
                )
            del xform_proj_for_2D
            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if options.ave2D:
                if myid == main_node:
                    log_main.add("Compute ave2D ... ")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(aveList)):
                                aveList[im].write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.ave2D),
                                    km,
                                )
                                km += 1
                        else:
                            nl = mpi.mpi_recv(
                                1,
                                mpi.MPI_INT,
                                i,
                                sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                mpi.MPI_COMM_WORLD,
                            )
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = sp_utilities.recv_EMData(
                                    i, im + i + 70000)
                                """Multiline Comment3"""
                                tmpvol = sp_fundamentals.fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.ave2D),
                                    km,
                                )
                                km += 1
                else:
                    mpi.mpi_send(
                        len(aveList),
                        1,
                        mpi.MPI_INT,
                        main_node,
                        sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                        mpi.MPI_COMM_WORLD,
                    )
                    for im in range(len(aveList)):
                        sp_utilities.send_EMData(aveList[im], main_node,
                                                 im + myid + 70000)
                        """Multiline Comment4"""
                if myid == main_node:
                    sp_applications.header(
                        optparse.os.path.join(current_output_dir,
                                              options.ave2D),
                        params="xform.projection",
                        fimport=optparse.os.path.join(current_output_dir,
                                                      "params.txt"),
                    )
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            if options.ave3D:
                t5 = time.time()
                if myid == main_node:
                    log_main.add("Reconstruct ave3D ... ")
                ave3D = sp_reconstruction.recons3d_4nn_MPI(
                    myid, aveList, symmetry=options.sym, npad=options.npad)
                sp_utilities.bcast_EMData_to_all(ave3D, myid)
                if myid == main_node:
                    if current_decimate != 1.0:
                        ave3D = sp_fundamentals.resample(
                            ave3D, old_div(1.0, current_decimate))
                    ave3D = sp_fundamentals.fpol(
                        ave3D, nnxo, nnxo,
                        nnxo)  # always to the orignal image size
                    sp_utilities.set_pixel_size(ave3D, 1.0)
                    ave3D.write_image(
                        optparse.os.path.join(current_output_dir,
                                              options.ave3D))
                    log_main.add("Ave3D reconstruction took %12.1f [m]" %
                                 (old_div((time.time() - t5), 60.0)))
                    log_main.add("%-70s:  %s\n" %
                                 ("The reconstructed ave3D is saved as ",
                                  options.ave3D))

            mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
            del ave, var, proj_list, stack, alpha, sx, sy, mirror, aveList
            """Multiline Comment5"""

            if options.ave3D:
                del ave3D
            if options.var2D:
                if myid == main_node:
                    log_main.add("Compute var2D...")
                    km = 0
                    for i in range(number_of_proc):
                        if i == main_node:
                            for im in range(len(varList)):
                                tmpvol = sp_fundamentals.fpol(
                                    varList[im], nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.var2D),
                                    km,
                                )
                                km += 1
                        else:
                            nl = mpi.mpi_recv(
                                1,
                                mpi.MPI_INT,
                                i,
                                sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                mpi.MPI_COMM_WORLD,
                            )
                            nl = int(nl[0])
                            for im in range(nl):
                                ave = sp_utilities.recv_EMData(
                                    i, im + i + 70000)
                                tmpvol = sp_fundamentals.fpol(ave, nx, nx, 1)
                                tmpvol.write_image(
                                    optparse.os.path.join(
                                        current_output_dir, options.var2D),
                                    km,
                                )
                                km += 1
                else:
                    mpi.mpi_send(
                        len(varList),
                        1,
                        mpi.MPI_INT,
                        main_node,
                        sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                        mpi.MPI_COMM_WORLD,
                    )
                    for im in range(len(varList)):
                        sp_utilities.send_EMData(
                            varList[im], main_node,
                            im + myid + 70000)  # What with the attributes??
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
                if myid == main_node:
                    sp_applications.header(
                        optparse.os.path.join(current_output_dir,
                                              options.var2D),
                        params="xform.projection",
                        fimport=optparse.os.path.join(current_output_dir,
                                                      "params.txt"),
                    )
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
        if options.var3D:
            if myid == main_node:
                log_main.add("Reconstruct var3D ...")
            t6 = time.time()
            # radiusvar = options.radius
            # if( radiusvar < 0 ):  radiusvar = nx//2 -3
            res = sp_reconstruction.recons3d_4nn_MPI(myid,
                                                     varList,
                                                     symmetry=options.sym,
                                                     npad=options.npad)
            # res = recons3d_em_MPI(varList, vol_stack, options.iter, radiusvar, options.abs, True, options.sym, options.squ)
            if myid == main_node:
                if current_decimate != 1.0:
                    res = sp_fundamentals.resample(
                        res, old_div(1.0, current_decimate))
                res = sp_fundamentals.fpol(res, nnxo, nnxo, nnxo)
                sp_utilities.set_pixel_size(res, 1.0)
                res.write_image(os.path.join(current_output_dir,
                                             options.var3D))
                log_main.add(
                    "%-70s:  %s\n" %
                    ("The reconstructed var3D is saved as ", options.var3D))
                log_main.add("Var3D reconstruction took %f12.1 [m]" % (old_div(
                    (time.time() - t6), 60.0)))
                log_main.add("Total computation time %f12.1 [m]" % (old_div(
                    (time.time() - t0), 60.0)))
                log_main.add("sx3dvariability finishes")

        if RUNNING_UNDER_MPI:
            sp_global_def.MPI = False

        sp_global_def.BATCH = False
Ejemplo n.º 12
0
def main():
    global Tracker, Blockdata
    progname = os.path.basename(sys.argv[0])
    usage = progname + " --output_dir=output_dir  --isac_dir=output_dir_of_isac "
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--pw_adjustment",
        type="string",
        default="analytical_model",
        help=
        "adjust power spectrum of 2-D averages to an analytic model. Other opions: no_adjustment; bfactor; a text file of 1D rotationally averaged PW",
    )
    #### Four options for --pw_adjustment:
    # 1> analytical_model(default);
    # 2> no_adjustment;
    # 3> bfactor;
    # 4> adjust_to_given_pw2(user has to provide a text file that contains 1D rotationally averaged PW)

    # options in common
    parser.add_option(
        "--isac_dir",
        type="string",
        default="",
        help="ISAC run output directory, input directory for this command",
    )
    parser.add_option(
        "--output_dir",
        type="string",
        default="",
        help="output directory where computed averages are saved",
    )
    parser.add_option(
        "--pixel_size",
        type="float",
        default=-1.0,
        help=
        "pixel_size of raw images. one can put 1.0 in case of negative stain data",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=-1.0,
        help=
        "low pass filter, = -1.0, not applied; =0.0, using FH1 (initial resolution), = 1.0 using FH2 (resolution after local alignment), or user provided value in absolute freqency [0.0:0.5]",
    )
    parser.add_option("--stack",
                      type="string",
                      default="",
                      help="data stack used in ISAC")
    parser.add_option("--radius", type="int", default=-1, help="radius")
    parser.add_option("--xr",
                      type="float",
                      default=-1.0,
                      help="local alignment search range")
    # parser.add_option("--ts",                    type   ="float",          default =1.0,    help= "local alignment search step")
    parser.add_option(
        "--fh",
        type="float",
        default=-1.0,
        help="local alignment high frequencies limit",
    )
    # parser.add_option("--maxit",                 type   ="int",            default =5,      help= "local alignment iterations")
    parser.add_option("--navg",
                      type="int",
                      default=1000000,
                      help="number of aveages")
    parser.add_option(
        "--local_alignment",
        action="store_true",
        default=False,
        help="do local alignment",
    )
    parser.add_option(
        "--noctf",
        action="store_true",
        default=False,
        help=
        "no ctf correction, useful for negative stained data. always ctf for cryo data",
    )
    parser.add_option(
        "--B_start",
        type="float",
        default=45.0,
        help=
        "start frequency (Angstrom) of power spectrum for B_factor estimation",
    )
    parser.add_option(
        "--Bfactor",
        type="float",
        default=-1.0,
        help=
        "User defined bactors (e.g. 25.0[A^2]). By default, the program automatically estimates B-factor. ",
    )

    (options, args) = parser.parse_args(sys.argv[1:])

    adjust_to_analytic_model = (True if options.pw_adjustment
                                == "analytical_model" else False)
    no_adjustment = True if options.pw_adjustment == "no_adjustment" else False
    B_enhance = True if options.pw_adjustment == "bfactor" else False
    adjust_to_given_pw2 = (
        True if not (adjust_to_analytic_model or no_adjustment or B_enhance)
        else False)

    # mpi
    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)

    Blockdata = {}
    Blockdata["nproc"] = nproc
    Blockdata["myid"] = myid
    Blockdata["main_node"] = 0
    Blockdata["shared_comm"] = mpi.mpi_comm_split_type(
        mpi.MPI_COMM_WORLD, mpi.MPI_COMM_TYPE_SHARED, 0, mpi.MPI_INFO_NULL)
    Blockdata["myid_on_node"] = mpi.mpi_comm_rank(Blockdata["shared_comm"])
    Blockdata["no_of_processes_per_group"] = mpi.mpi_comm_size(
        Blockdata["shared_comm"])
    masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
        mpi.MPI_COMM_WORLD,
        Blockdata["main_node"] == Blockdata["myid_on_node"],
        Blockdata["myid_on_node"],
    )
    Blockdata["color"], Blockdata[
        "no_of_groups"], balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            Blockdata["main_node"],
            mpi.MPI_COMM_WORLD,
            Blockdata["myid"],
            Blockdata["shared_comm"],
            Blockdata["myid_on_node"],
            masters_from_groups_vs_everything_else_comm,
        )
    #  We need two nodes for processing of volumes
    Blockdata["node_volume"] = [
        Blockdata["no_of_groups"] - 3,
        Blockdata["no_of_groups"] - 2,
        Blockdata["no_of_groups"] - 1,
    ]  # For 3D stuff take three last nodes
    #  We need two CPUs for processing of volumes, they are taken to be main CPUs on each volume
    #  We have to send the two myids to all nodes so we can identify main nodes on two selected groups.
    Blockdata["nodes"] = [
        Blockdata["node_volume"][0] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][1] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][2] * Blockdata["no_of_processes_per_group"],
    ]
    # End of Blockdata: sorting requires at least three nodes, and the used number of nodes be integer times of three
    sp_global_def.BATCH = True
    sp_global_def.MPI = True

    if adjust_to_given_pw2:
        checking_flag = 0
        if Blockdata["myid"] == Blockdata["main_node"]:
            if not os.path.exists(options.pw_adjustment):
                checking_flag = 1
        checking_flag = sp_utilities.bcast_number_to_all(
            checking_flag, Blockdata["main_node"], mpi.MPI_COMM_WORLD)

        if checking_flag == 1:
            sp_global_def.ERROR("User provided power spectrum does not exist",
                                myid=Blockdata["myid"])

    Tracker = {}
    Constants = {}
    Constants["isac_dir"] = options.isac_dir
    Constants["masterdir"] = options.output_dir
    Constants["pixel_size"] = options.pixel_size
    Constants["orgstack"] = options.stack
    Constants["radius"] = options.radius
    Constants["xrange"] = options.xr
    Constants["FH"] = options.fh
    Constants["low_pass_filter"] = options.fl
    # Constants["maxit"]                        = options.maxit
    Constants["navg"] = options.navg
    Constants["B_start"] = options.B_start
    Constants["Bfactor"] = options.Bfactor

    if adjust_to_given_pw2:
        Constants["modelpw"] = options.pw_adjustment
    Tracker["constants"] = Constants
    # -------------------------------------------------------------
    #
    # Create and initialize Tracker dictionary with input options  # State Variables

    # <<<---------------------->>>imported functions<<<---------------------------------------------

    # x_range = max(Tracker["constants"]["xrange"], int(1./Tracker["ini_shrink"])+1)
    # y_range =  x_range

    ####-----------------------------------------------------------
    # Create Master directory and associated subdirectories
    line = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>"
    if Tracker["constants"]["masterdir"] == Tracker["constants"]["isac_dir"]:
        masterdir = os.path.join(Tracker["constants"]["isac_dir"], "sharpen")
    else:
        masterdir = Tracker["constants"]["masterdir"]

    if Blockdata["myid"] == Blockdata["main_node"]:
        msg = "Postprocessing ISAC 2D averages starts"
        sp_global_def.sxprint(line, "Postprocessing ISAC 2D averages starts")
        if not masterdir:
            timestring = time.strftime("_%d_%b_%Y_%H_%M_%S", time.localtime())
            masterdir = "sharpen_" + Tracker["constants"]["isac_dir"]
            os.makedirs(masterdir)
        else:
            if os.path.exists(masterdir):
                sp_global_def.sxprint("%s already exists" % masterdir)
            else:
                os.makedirs(masterdir)
        sp_global_def.write_command(masterdir)
        subdir_path = os.path.join(masterdir, "ali2d_local_params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        subdir_path = os.path.join(masterdir, "params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        li = len(masterdir)
    else:
        li = 0
    li = mpi.mpi_bcast(li, 1, mpi.MPI_INT, Blockdata["main_node"],
                       mpi.MPI_COMM_WORLD)[0]
    masterdir = mpi.mpi_bcast(masterdir, li, mpi.MPI_CHAR,
                              Blockdata["main_node"], mpi.MPI_COMM_WORLD)
    masterdir = b"".join(masterdir).decode('latin1')
    Tracker["constants"]["masterdir"] = masterdir
    log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
    log_main.prefix = Tracker["constants"]["masterdir"] + "/"

    while not os.path.exists(Tracker["constants"]["masterdir"]):
        sp_global_def.sxprint(
            "Node ",
            Blockdata["myid"],
            "  waiting...",
            Tracker["constants"]["masterdir"],
        )
        time.sleep(1)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if Blockdata["myid"] == Blockdata["main_node"]:
        init_dict = {}
        sp_global_def.sxprint(Tracker["constants"]["isac_dir"])
        Tracker["directory"] = os.path.join(Tracker["constants"]["isac_dir"],
                                            "2dalignment")
        core = sp_utilities.read_text_row(
            os.path.join(Tracker["directory"], "initial2Dparams.txt"))
        for im in range(len(core)):
            init_dict[im] = core[im]
        del core
    else:
        init_dict = 0
    init_dict = sp_utilities.wrap_mpi_bcast(init_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    ###
    do_ctf = True
    if options.noctf:
        do_ctf = False
    if Blockdata["myid"] == Blockdata["main_node"]:
        if do_ctf:
            sp_global_def.sxprint("CTF correction is on")
        else:
            sp_global_def.sxprint("CTF correction is off")
        if options.local_alignment:
            sp_global_def.sxprint("local refinement is on")
        else:
            sp_global_def.sxprint("local refinement is off")
        if B_enhance:
            sp_global_def.sxprint("Bfactor is to be applied on averages")
        elif adjust_to_given_pw2:
            sp_global_def.sxprint(
                "PW of averages is adjusted to a given 1D PW curve")
        elif adjust_to_analytic_model:
            sp_global_def.sxprint(
                "PW of averages is adjusted to analytical model")
        else:
            sp_global_def.sxprint("PW of averages is not adjusted")
        # Tracker["constants"]["orgstack"] = "bdb:"+ os.path.join(Tracker["constants"]["isac_dir"],"../","sparx_stack")
        image = sp_utilities.get_im(Tracker["constants"]["orgstack"], 0)
        Tracker["constants"]["nnxo"] = image.get_xsize()
        if Tracker["constants"]["pixel_size"] == -1.0:
            sp_global_def.sxprint(
                "Pixel size value is not provided by user. extracting it from ctf header entry of the original stack."
            )
            try:
                ctf_params = image.get_attr("ctf")
                Tracker["constants"]["pixel_size"] = ctf_params.apix
            except:
                sp_global_def.ERROR(
                    "Pixel size could not be extracted from the original stack.",
                    myid=Blockdata["myid"],
                )
        ## Now fill in low-pass filter

        isac_shrink_path = os.path.join(Tracker["constants"]["isac_dir"],
                                        "README_shrink_ratio.txt")

        if not os.path.exists(isac_shrink_path):
            sp_global_def.ERROR(
                "%s does not exist in the specified ISAC run output directory"
                % (isac_shrink_path),
                myid=Blockdata["myid"],
            )

        isac_shrink_file = open(isac_shrink_path, "r")
        isac_shrink_lines = isac_shrink_file.readlines()
        isac_shrink_ratio = float(
            isac_shrink_lines[5]
        )  # 6th line: shrink ratio (= [target particle radius]/[particle radius]) used in the ISAC run
        isac_radius = float(
            isac_shrink_lines[6]
        )  # 7th line: particle radius at original pixel size used in the ISAC run
        isac_shrink_file.close()
        print("Extracted parameter values")
        print("ISAC shrink ratio    : {0}".format(isac_shrink_ratio))
        print("ISAC particle radius : {0}".format(isac_radius))
        Tracker["ini_shrink"] = isac_shrink_ratio
    else:
        Tracker["ini_shrink"] = 0.0
    Tracker = sp_utilities.wrap_mpi_bcast(Tracker,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)

    # print(Tracker["constants"]["pixel_size"], "pixel_size")
    x_range = max(
        Tracker["constants"]["xrange"],
        int(old_div(1.0, Tracker["ini_shrink"]) + 0.99999),
    )
    a_range = y_range = x_range

    if Blockdata["myid"] == Blockdata["main_node"]:
        parameters = sp_utilities.read_text_row(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "all_parameters.txt"))
    else:
        parameters = 0
    parameters = sp_utilities.wrap_mpi_bcast(parameters,
                                             Blockdata["main_node"],
                                             communicator=mpi.MPI_COMM_WORLD)
    params_dict = {}
    list_dict = {}
    # parepare params_dict

    # navg = min(Tracker["constants"]["navg"]*Blockdata["nproc"], EMUtil.get_image_count(os.path.join(Tracker["constants"]["isac_dir"], "class_averages.hdf")))
    navg = min(
        Tracker["constants"]["navg"],
        EMAN2_cppwrap.EMUtil.get_image_count(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "class_averages.hdf")),
    )
    global_dict = {}
    ptl_list = []
    memlist = []
    if Blockdata["myid"] == Blockdata["main_node"]:
        sp_global_def.sxprint("Number of averages computed in this run is %d" %
                              navg)
        for iavg in range(navg):
            params_of_this_average = []
            image = sp_utilities.get_im(
                os.path.join(Tracker["constants"]["isac_dir"],
                             "class_averages.hdf"),
                iavg,
            )
            members = sorted(image.get_attr("members"))
            memlist.append(members)
            for im in range(len(members)):
                abs_id = members[im]
                global_dict[abs_id] = [iavg, im]
                P = sp_utilities.combine_params2(
                    init_dict[abs_id][0],
                    init_dict[abs_id][1],
                    init_dict[abs_id][2],
                    init_dict[abs_id][3],
                    parameters[abs_id][0],
                    old_div(parameters[abs_id][1], Tracker["ini_shrink"]),
                    old_div(parameters[abs_id][2], Tracker["ini_shrink"]),
                    parameters[abs_id][3],
                )
                if parameters[abs_id][3] == -1:
                    sp_global_def.sxprint(
                        "WARNING: Image #{0} is an unaccounted particle with invalid 2D alignment parameters and should not be the member of any classes. Please check the consitency of input dataset."
                        .format(abs_id)
                    )  # How to check what is wrong about mirror = -1 (Toshio 2018/01/11)
                params_of_this_average.append([P[0], P[1], P[2], P[3], 1.0])
                ptl_list.append(abs_id)
            params_dict[iavg] = params_of_this_average
            list_dict[iavg] = members
            sp_utilities.write_text_row(
                params_of_this_average,
                os.path.join(
                    Tracker["constants"]["masterdir"],
                    "params_avg",
                    "params_avg_%03d.txt" % iavg,
                ),
            )
        ptl_list.sort()
        init_params = [None for im in range(len(ptl_list))]
        for im in range(len(ptl_list)):
            init_params[im] = [ptl_list[im]] + params_dict[global_dict[
                ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
        sp_utilities.write_text_row(
            init_params,
            os.path.join(Tracker["constants"]["masterdir"],
                         "init_isac_params.txt"),
        )
    else:
        params_dict = 0
        list_dict = 0
        memlist = 0
    params_dict = sp_utilities.wrap_mpi_bcast(params_dict,
                                              Blockdata["main_node"],
                                              communicator=mpi.MPI_COMM_WORLD)
    list_dict = sp_utilities.wrap_mpi_bcast(list_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    memlist = sp_utilities.wrap_mpi_bcast(memlist,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)
    # Now computing!
    del init_dict
    tag_sharpen_avg = 1000
    ## always apply low pass filter to B_enhanced images to suppress noise in high frequencies
    enforced_to_H1 = False
    if B_enhance:
        if Tracker["constants"]["low_pass_filter"] == -1.0:
            enforced_to_H1 = True

    # distribute workload among mpi processes
    image_start, image_end = sp_applications.MPI_start_end(
        navg, Blockdata["nproc"], Blockdata["myid"])

    if Blockdata["myid"] == Blockdata["main_node"]:
        cpu_dict = {}
        for iproc in range(Blockdata["nproc"]):
            local_image_start, local_image_end = sp_applications.MPI_start_end(
                navg, Blockdata["nproc"], iproc)
            for im in range(local_image_start, local_image_end):
                cpu_dict[im] = iproc
    else:
        cpu_dict = 0

    cpu_dict = sp_utilities.wrap_mpi_bcast(cpu_dict,
                                           Blockdata["main_node"],
                                           communicator=mpi.MPI_COMM_WORLD)

    slist = [None for im in range(navg)]
    ini_list = [None for im in range(navg)]
    avg1_list = [None for im in range(navg)]
    avg2_list = [None for im in range(navg)]
    data_list = [None for im in range(navg)]
    plist_dict = {}

    if Blockdata["myid"] == Blockdata["main_node"]:
        if B_enhance:
            sp_global_def.sxprint(
                "Avg ID   B-factor  FH1(Res before ali) FH2(Res after ali)")
        else:
            sp_global_def.sxprint(
                "Avg ID   FH1(Res before ali)  FH2(Res after ali)")

    FH_list = [[0, 0.0, 0.0] for im in range(navg)]
    for iavg in range(image_start, image_end):

        mlist = EMAN2_cppwrap.EMData.read_images(
            Tracker["constants"]["orgstack"], list_dict[iavg])

        for im in range(len(mlist)):
            sp_utilities.set_params2D(mlist[im],
                                      params_dict[iavg][im],
                                      xform="xform.align2d")

        if options.local_alignment:
            new_avg, plist, FH2 = sp_applications.refinement_2d_local(
                mlist,
                Tracker["constants"]["radius"],
                a_range,
                x_range,
                y_range,
                CTF=do_ctf,
                SNR=1.0e10,
            )
            plist_dict[iavg] = plist
            FH1 = -1.0

        else:
            new_avg, frc, plist = compute_average(
                mlist, Tracker["constants"]["radius"], do_ctf)
            FH1 = get_optimistic_res(frc)
            FH2 = -1.0

        FH_list[iavg] = [iavg, FH1, FH2]

        if B_enhance:
            new_avg, gb = apply_enhancement(
                new_avg,
                Tracker["constants"]["B_start"],
                Tracker["constants"]["pixel_size"],
                Tracker["constants"]["Bfactor"],
            )
            sp_global_def.sxprint("  %6d      %6.3f  %4.3f  %4.3f" %
                                  (iavg, gb, FH1, FH2))

        elif adjust_to_given_pw2:
            roo = sp_utilities.read_text_file(Tracker["constants"]["modelpw"],
                                              -1)
            roo = roo[0]  # always on the first column
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         roo)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f  " %
                                  (iavg, FH1, FH2))

        elif adjust_to_analytic_model:
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         None)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f   " %
                                  (iavg, FH1, FH2))

        elif no_adjustment:
            pass

        if Tracker["constants"]["low_pass_filter"] != -1.0:
            if Tracker["constants"]["low_pass_filter"] == 0.0:
                low_pass_filter = FH1
            elif Tracker["constants"]["low_pass_filter"] == 1.0:
                low_pass_filter = FH2
                if not options.local_alignment:
                    low_pass_filter = FH1
            else:
                low_pass_filter = Tracker["constants"]["low_pass_filter"]
                if low_pass_filter >= 0.45:
                    low_pass_filter = 0.45
            new_avg = sp_filter.filt_tanl(new_avg, low_pass_filter, 0.02)
        else:  # No low pass filter but if enforced
            if enforced_to_H1:
                new_avg = sp_filter.filt_tanl(new_avg, FH1, 0.02)
        if B_enhance:
            new_avg = sp_fundamentals.fft(new_avg)

        new_avg.set_attr("members", list_dict[iavg])
        new_avg.set_attr("n_objects", len(list_dict[iavg]))
        slist[iavg] = new_avg
        sp_global_def.sxprint(
            time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>",
            "Refined average %7d" % iavg,
        )

    ## send to main node to write
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(navg):
        # avg
        if (cpu_dict[im] == Blockdata["myid"]
                and Blockdata["myid"] != Blockdata["main_node"]):
            sp_utilities.send_EMData(slist[im], Blockdata["main_node"],
                                     tag_sharpen_avg)

        elif (cpu_dict[im] == Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            slist[im].set_attr("members", memlist[im])
            slist[im].set_attr("n_objects", len(memlist[im]))
            slist[im].write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        elif (cpu_dict[im] != Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            new_avg_other_cpu = sp_utilities.recv_EMData(
                cpu_dict[im], tag_sharpen_avg)
            new_avg_other_cpu.set_attr("members", memlist[im])
            new_avg_other_cpu.set_attr("n_objects", len(memlist[im]))
            new_avg_other_cpu.write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        if options.local_alignment:
            if cpu_dict[im] == Blockdata["myid"]:
                sp_utilities.write_text_row(
                    plist_dict[im],
                    os.path.join(
                        Tracker["constants"]["masterdir"],
                        "ali2d_local_params_avg",
                        "ali2d_local_params_avg_%03d.txt" % im,
                    ),
                )

            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(plist_dict[im],
                                           Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                plist_dict[im] = dummy
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]
        else:
            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]

        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if options.local_alignment:
        if Blockdata["myid"] == Blockdata["main_node"]:
            ali3d_local_params = [None for im in range(len(ptl_list))]
            for im in range(len(ptl_list)):
                ali3d_local_params[im] = [ptl_list[im]] + plist_dict[
                    global_dict[ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
            sp_utilities.write_text_row(
                ali3d_local_params,
                os.path.join(Tracker["constants"]["masterdir"],
                             "ali2d_local_params.txt"),
            )
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))
    else:
        if Blockdata["myid"] == Blockdata["main_node"]:
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    target_xr = 3
    target_yr = 3
    if Blockdata["myid"] == 0:
        cmd = "{} {} {} {} {} {} {} {} {} {}".format(
            "sp_chains.py",
            os.path.join(Tracker["constants"]["masterdir"],
                         "class_averages.hdf"),
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"),
            os.path.join(Tracker["constants"]["masterdir"],
                         "ordered_class_averages.hdf"),
            "--circular",
            "--radius=%d" % Tracker["constants"]["radius"],
            "--xr=%d" % (target_xr + 1),
            "--yr=%d" % (target_yr + 1),
            "--align",
            ">/dev/null",
        )
        junk = sp_utilities.cmdexecute(cmd)
        cmd = "{} {}".format(
            "rm -rf",
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"))
        junk = sp_utilities.cmdexecute(cmd)

    return
Ejemplo n.º 13
0
def mref_ali2d_MPI(stack,
                   refim,
                   outdir,
                   maskfile=None,
                   ir=1,
                   ou=-1,
                   rs=1,
                   xrng=0,
                   yrng=0,
                   step=1,
                   center=1,
                   maxit=10,
                   CTF=False,
                   snr=1.0,
                   user_func_name="ref_ali2d",
                   rand_seed=1000):
    """2D multi-reference alignment (MPI), using rotational ccf in polar
    coordinates and quadratic interpolation.

    Each rank reads its share of ``stack`` (split with ``MPI_start_end``).
    It aligns every image against all references and accumulates the aligned
    images into even/odd partial sums of the best-matching reference. The
    main node then reduces the sums, computes per-reference FSCs, and applies
    the user function. It writes the new references to
    ``outdir/aqm%03d.hdf`` and broadcasts them back for the next iteration.

    Parameters:
        stack: input image stack; the alignment parameters
            (``xform.align2d``), ``assign`` and ``ID`` are written back to
            its headers at the end.
        refim: stack of initial reference images. It is re-pointed to the
            current ``aqm%03d.hdf`` file after each iteration.
        outdir: output directory. It must not already exist and is created
            by the main node.
        maskfile: ``None`` (use a circle of radius ``ou``), a file path
            (read with ``get_image``), or a mask image used as-is.
        ir, ou, rs: first ring, last ring (``-1`` -> ``nx // 2 - 2``), and
            ring step.
        xrng, yrng, step: translational search ranges and step.
        center: centring option passed to the user function via ``ref_data``.
        maxit: maximum number of iterations (``0`` -> 10).
        CTF: when True, the ``ctf_applied`` header flag of the output is
            reset to 0 if the input data did not have CTF applied. No CTF
            correction is performed in this function.
        snr: only reported in the log.
        user_func_name: key into ``sp_user_functions.factory``.
        rand_seed: seed for ``random`` on the main node. It is used to pick
            a replacement image for a reference that attracted fewer than
            4 images.
    """
    from sp_utilities import model_circle, combine_params2, inverse_transform2, drop_image, get_image, get_im
    from sp_utilities import reduce_EMData_to_root, bcast_EMData_to_all, bcast_number_to_all
    from sp_utilities import send_attr_dict
    from sp_utilities import center_2D
    from sp_statistics import fsc_mask
    from sp_alignment import Numrinit, ringwe, search_range
    from sp_fundamentals import rot_shift2D, fshift
    from sp_utilities import get_params2D, set_params2D
    from random import seed, randint
    from sp_morphology import ctf_2
    from sp_filter import filt_btwl, filt_params
    from numpy import reshape, shape
    from sp_utilities import print_msg, print_begin_msg, print_end_msg
    import os
    import sys
    import shutil
    from sp_applications import MPI_start_end
    from mpi import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi import mpi_reduce, mpi_bcast, mpi_barrier, mpi_recv, mpi_send
    from mpi import MPI_SUM, MPI_FLOAT, MPI_INT

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    # Refuse to reuse an existing output directory (checked on every rank).
    if os.path.exists(outdir):
        ERROR(
            'Output directory exists, please change the name and restart the program',
            "mref_ali2d_MPI ", 1, myid)
    mpi_barrier(MPI_COMM_WORLD)

    import sp_global_def
    if myid == main_node:
        os.mkdir(outdir)
        sp_global_def.LOGFILE = os.path.join(outdir, sp_global_def.LOGFILE)
        print_begin_msg("mref_ali2d_MPI")

    nima = EMUtil.get_image_count(stack)

    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)

    ima = EMData()
    ima.read_image(stack, image_start)
    if CTF:
        # Remember whether the input data already had CTF applied, so the
        # header flag can be restored when the headers are written back.
        # Previously this name was never assigned, so CTF=True crashed at the
        # end with a NameError.
        data_had_ctf = ima.get_attr_default("ctf_applied", 0)

    first_ring = int(ir)
    last_ring = int(ou)
    rstep = int(rs)
    max_iter = int(maxit)

    # maxit == 0 selects the default of 10 iterations.
    if max_iter == 0:
        max_iter = 10

    if myid == main_node:
        print_msg("Input stack                 : %s\n" % (stack))
        print_msg("Reference stack             : %s\n" % (refim))
        print_msg("Output directory            : %s\n" % (outdir))
        print_msg("Maskfile                    : %s\n" % (maskfile))
        print_msg("Inner radius                : %i\n" % (first_ring))

    nx = ima.get_xsize()
    # Default value for the last ring. Integer division keeps last_ring an
    # int, as in the original Python 2 code; a float radius would break the
    # ring set-up below.
    if last_ring == -1:
        last_ring = nx // 2 - 2

    if myid == main_node:
        print_msg("Outer radius                : %i\n" % (last_ring))
        print_msg("Ring step                   : %i\n" % (rstep))
        print_msg("X search range              : %f\n" % (xrng))
        print_msg("Y search range              : %f\n" % (yrng))
        print_msg("Translational step          : %f\n" % (step))
        print_msg("Center type                 : %i\n" % (center))
        print_msg("Maximum iteration           : %i\n" % (max_iter))
        print_msg("CTF correction              : %s\n" % (CTF))
        print_msg("Signal-to-Noise Ratio       : %f\n" % (snr))
        print_msg("Random seed                 : %i\n\n" % (rand_seed))
        print_msg("User function               : %s\n" % (user_func_name))
    import sp_user_functions
    user_func = sp_user_functions.factory[user_func_name]

    if maskfile:
        # A path (str in Python 3; bytes accepted for compatibility) is read
        # from disk, and anything else is taken to be a mask image already.
        if isinstance(maskfile, (str, bytes)):
            mask = get_image(maskfile)
        else:
            mask = maskfile
    else:
        mask = model_circle(last_ring, nx, nx)
    #  references, do them on all processors...
    refi = []
    numref = EMUtil.get_image_count(refim)

    # IMAGES ARE SQUARES! The center is in SPIDER convention (1-based, integer).
    cnx = nx // 2 + 1
    cny = cnx

    mode = "F"
    # precalculate rings
    numr = Numrinit(first_ring, last_ring, rstep, mode)
    wr = ringwe(numr, mode)

    # prepare reference images on all nodes
    ima.to_zero()
    for j in range(numref):
        #  [even sum, odd sum, number of assigned images]
        refi.append([get_im(refim, j), ima.copy(), 0])
    #  for each node read its share of data
    data = EMData.read_images(stack, list(range(image_start, image_end)))
    for im in range(image_start, image_end):
        data[im - image_start].set_attr('ID', im)

    if myid == main_node: seed(rand_seed)

    again = True
    Iter = 0

    ref_data = [mask, center, None, None]

    while Iter < max_iter and again:
        ringref = []
        mashi = cnx - last_ring - 2
        for j in range(numref):
            refi[j][0].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 1
            })  # normalize reference images to N(0,1)
            cimage = Util.Polar2Dm(refi[j][0], cnx, cny, numr, mode)
            Util.Frngs(cimage, numr)
            Util.Applyws(cimage, numr, wr)
            ringref.append(cimage)
            # zero refi
            refi[j][0].to_zero()
            refi[j][1].to_zero()
            refi[j][2] = 0

        assign = [[] for i in range(numref)]
        # begin MPI section
        for im in range(image_start, image_end):
            alpha, sx, sy, mirror, scale = get_params2D(data[im - image_start])
            #  Why inverse?  07/11/2015 PAP
            alphai, sxi, syi, scalei = inverse_transform2(alpha, sx, sy)
            # normalize
            data[im - image_start].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 0
            })  # subtract average under the mask
            # If shifts are outside of the permissible range, reset them
            if (abs(sxi) > mashi or abs(syi) > mashi):
                sxi = 0.0
                syi = 0.0
                set_params2D(data[im - image_start], [0.0, 0.0, 0.0, 0, 1.0])
            ny = nx
            txrng = search_range(nx, last_ring, sxi, xrng, "mref_ali2d_MPI")
            txrng = [txrng[1], txrng[0]]
            tyrng = search_range(ny, last_ring, syi, yrng, "mref_ali2d_MPI")
            tyrng = [tyrng[1], tyrng[0]]
            # align current image to the reference
            [angt, sxst, syst, mirrort, xiref,
             peakt] = Util.multiref_polar_ali_2d(data[im - image_start],
                                                 ringref, txrng, tyrng, step,
                                                 mode, numr, cnx + sxi,
                                                 cny + syi)

            iref = int(xiref)
            # combine parameters and set them to the header, ignore previous angle and mirror
            [alphan, sxn, syn,
             mn] = combine_params2(0.0, -sxi, -syi, 0, angt, sxst, syst,
                                   (int)(mirrort))
            set_params2D(data[im - image_start],
                         [alphan, sxn, syn, int(mn), scale])
            data[im - image_start].set_attr('assign', iref)
            # apply current parameters and add to the average (even/odd split by global index)
            temp = rot_shift2D(data[im - image_start], alphan, sxn, syn, mn)
            it = im % 2
            Util.add_img(refi[iref][it], temp)
            assign[iref].append(im)
            refi[iref][2] += 1.0
        del ringref
        # end MPI section, bring partial things together, calculate new reference images, broadcast them back

        for j in range(numref):
            reduce_EMData_to_root(refi[j][0], myid, main_node)
            reduce_EMData_to_root(refi[j][1], myid, main_node)
            refi[j][2] = mpi_reduce(refi[j][2], 1, MPI_FLOAT, MPI_SUM,
                                    main_node, MPI_COMM_WORLD)
            if (myid == main_node): refi[j][2] = int(refi[j][2][0])
        # gather assignments (sp_global_def is already imported at function scope)
        for j in range(numref):
            if myid == main_node:
                for n in range(number_of_proc):
                    if n != main_node:
                        ln = mpi_recv(1, MPI_INT, n,
                                      sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                      MPI_COMM_WORLD)
                        lis = mpi_recv(ln[0], MPI_INT, n,
                                       sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                       MPI_COMM_WORLD)
                        for l in range(ln[0]):
                            assign[j].append(int(lis[l]))
            else:
                mpi_send(len(assign[j]), 1, MPI_INT, main_node,
                         sp_global_def.SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                mpi_send(assign[j], len(assign[j]), MPI_INT, main_node,
                         sp_global_def.SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)

        if myid == main_node:
            # replace the name of the stack with reference with the current one
            refim = os.path.join(outdir, "aqm%03d.hdf" % Iter)
            ave_fsc = []
            for j in range(numref):
                if refi[j][2] < 4:
                    #  if vanished, put a random image (only from main node!) there
                    assign[j] = []
                    assign[j].append(
                        randint(image_start, image_end - 1) - image_start)
                    refi[j][0] = data[assign[j][0]].copy()
                else:
                    from sp_statistics import fsc
                    frsc = fsc(
                        refi[j][0], refi[j][1], 1.0,
                        os.path.join(outdir, "drm%03d%04d.txt" % (Iter, j)))
                    Util.add_img(refi[j][0], refi[j][1])
                    Util.mul_scalar(refi[j][0], 1.0 / float(refi[j][2]))

                    if ave_fsc == []:
                        for i in range(len(frsc[1])):
                            ave_fsc.append(frsc[1][i])
                        c_fsc = 1
                    else:
                        for i in range(len(frsc[1])):
                            ave_fsc[i] += frsc[1][i]
                        c_fsc += 1

            # Replace the FSC curve handed to the user function by the average over references.
            if sum(ave_fsc) != 0:
                for i in range(len(ave_fsc)):
                    ave_fsc[i] /= float(c_fsc)
                    frsc[1][i] = ave_fsc[i]

            for j in range(numref):
                ref_data[2] = refi[j][0]
                ref_data[3] = frsc
                refi[j][0], cs = user_func(ref_data)

                # write the current average
                TMP = []
                for i_tmp in range(len(assign[j])):
                    TMP.append(float(assign[j][i_tmp]))
                TMP.sort()
                refi[j][0].set_attr_dict({'ave_n': refi[j][2], 'members': TMP})
                del TMP
                refi[j][0].process_inplace("normalize.mask", {
                    "mask": mask,
                    "no_sigma": 1
                })
                refi[j][0].write_image(refim, j)

            Iter += 1
            msg = "ITERATION #%3d        %d\n\n" % (Iter, again)
            print_msg(msg)
            for j in range(numref):
                msg = "   group #%3d   number of particles = %7d\n" % (
                    j, refi[j][2])
                print_msg(msg)
        Iter = bcast_number_to_all(Iter, main_node)  # need to tell all
        if again:
            for j in range(numref):
                bcast_EMData_to_all(refi[j][0], myid, main_node)

    #  clean up
    del assign
    # write out headers  and STOP, under MPI writing has to be done sequentially (time-consumming)
    mpi_barrier(MPI_COMM_WORLD)
    if CTF and data_had_ctf == 0:
        for im in range(len(data)):
            data[im].set_attr('ctf_applied', 0)
    par_str = ['xform.align2d', 'assign', 'ID']
    if myid == main_node:
        from sp_utilities import file_type
        if (file_type(stack) == "bdb"):
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)
    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("mref_ali2d_MPI")
Ejemplo n.º 14
0
def main():
	from time import sleep
	from sp_logger import Logger, BaseLogger_Files
	arglist = []
	i = 0
	while( i < len(sys.argv) ):
		if sys.argv[i]=='-p4pg':
			i = i+2
		elif sys.argv[i]=='-p4wd':
			i = i+2
		else:
			arglist.append( sys.argv[i] )
			i = i+1
	progname = os.path.basename(arglist[0])
	usage = progname + " stack  outdir  <mask> --focus=3Dmask --radius=outer_radius --delta=angular_step" +\
	"--an=angular_neighborhood --maxit=max_iter  --CTF --sym=c1 --function=user_function --independent=indenpendent_runs  --number_of_images_per_group=number_of_images_per_group  --low_pass_filter=.25  --seed=random_seed"
	parser = OptionParser(usage,version=SPARXVERSION)
	parser.add_option("--focus",                         type="string",               default=None,              help="3D mask for focused clustering ")
	parser.add_option("--ir",                            type= "int",                 default= 1, 	             help="inner radius for rotational correlation > 0 (set to 1)")
	parser.add_option("--radius",                        type= "int",                 default=-1,	             help="outer radius for rotational correlation <nx-1 (set to the radius of the particle)")
	parser.add_option("--maxit",	                     type= "int",                 default=50, 	             help="maximum number of iteration")
	parser.add_option("--rs",                            type= "int",                 default=1,	             help="step between rings in rotational correlation >0 (set to 1)" ) 
	parser.add_option("--xr",                            type="string",               default='1',               help="range for translation search in x direction, search is +/-xr ")
	parser.add_option("--yr",                            type="string",               default='-1',	             help="range for translation search in y direction, search is +/-yr (default = same as xr)")
	parser.add_option("--ts",                            type="string",               default='0.25',            help="step size of the translation search in both directions direction, search is -xr, -xr+ts, 0, xr-ts, xr ")
	parser.add_option("--delta",                         type="string",               default='2',               help="angular step of reference projections")
	parser.add_option("--an",                            type="string",               default='-1',	             help="angular neighborhood for local searches")
	parser.add_option("--center",                        type="int",                  default=0,	             help="0 - if you do not want the volume to be centered, 1 - center the volume using cog (default=0)")
	parser.add_option("--nassign",                       type="int",                  default=1, 	             help="number of reassignment iterations performed for each angular step (set to 3) ")
	parser.add_option("--nrefine",                       type="int",                  default=0, 	             help="number of alignment iterations performed for each angular step (set to 1) ")
	parser.add_option("--CTF",                           action="store_true",         default=False,             help="Consider CTF correction during the alignment ")
	parser.add_option("--stoprnct",                      type="float",                default=3.0,               help="Minimum percentage of assignment change to stop the program")
	parser.add_option("--sym",                           type="string",               default='c1',              help="symmetry of the structure ")
	parser.add_option("--function",                      type="string",               default='do_volume_mrk05', help="name of the reference preparation function")
	parser.add_option("--independent",                   type="int",                  default= 3,                help="number of independent run")
	parser.add_option("--number_of_images_per_group",    type='int',                  default=1000,              help="number of images per groups")
	parser.add_option("--low_pass_filter",               type="float",                default=-1.0,              help="absolute frequency of low-pass filter for 3d sorting on the original image size" )
	parser.add_option("--nxinit",                        type="int",                  default=64,                help="initial image size for sorting" )
	parser.add_option("--unaccounted",                   action="store_true",         default=False,             help="reconstruct the unaccounted images")
	parser.add_option("--seed",                          type="int",                  default=-1,                help="random seed for create initial random assignment for EQ Kmeans")
	parser.add_option("--smallest_group",                type="int",                  default=500,               help="minimum members for identified group" )
	parser.add_option("--previous_run1",                 type="string",               default='',                help="two previous runs" )
	parser.add_option("--previous_run2",                 type="string",               default='',                help="two previous runs" )
	parser.add_option("--group_size_for_unaccounted",    type="int",                  default=500,               help="size for unaccounted particles" )
	parser.add_option("--chunkdir",                      type="string",               default='',                help="chunkdir for computing margin of error")
	parser.add_option("--sausage",                       action="store_true",         default=False,             help="way of filter volume")
	parser.add_option("--PWadjustment",                  type="string",               default='',                help="1-D power spectrum of PDB file used for EM volume power spectrum correction")
	parser.add_option("--protein_shape",                 type="string",               default='g',               help="protein shape. It defines protein preferred orientation angles. Currently it has g and f two types ")
	parser.add_option("--upscale",                       type="float",                default=0.5,               help=" scaling parameter to adjust the power spectrum of EM volumes")
	parser.add_option("--wn",                            type="int",                  default=0,                 help="optimal window size for data processing")
	parser.add_option("--interpolation",                 type="string",               default="4nn",             help="3-d reconstruction interpolation method, two options, trl and 4nn")

	(options, args) = parser.parse_args(arglist[1:])
	if len(args) < 1  or len(args) > 4:
    		sxprint("Usage: " + usage)
    		sxprint("Please run \'" + progname + " -h\' for detailed options")
    		ERROR( "Invalid number of parameters used. Please see usage information above." )
    		return
	else:

		if len(args)>2:
			mask_file = args[2]
		else:
			mask_file = None

		orgstack                        =args[0]
		masterdir                       =args[1]
		sp_global_def.BATCH = True

		#---initialize MPI related variables
		nproc     = mpi.mpi_comm_size( mpi.MPI_COMM_WORLD )
		myid      = mpi.mpi_comm_rank( mpi.MPI_COMM_WORLD )
		mpi_comm  = mpi.MPI_COMM_WORLD
		main_node = 0

		# Create the main log file
		from sp_logger import Logger,BaseLogger_Files
		if myid ==main_node:
			log_main=Logger(BaseLogger_Files())
			log_main.prefix=masterdir+"/"
		else:
			log_main = None
		#--- fill input parameters into dictionary named after Constants
		Constants		                         ={}
		Constants["stack"]                       =args[0]
		Constants["masterdir"]                   =masterdir
		Constants["mask3D"]                      =mask_file
		Constants["focus3Dmask"]                 =options.focus
		Constants["indep_runs"]                  =options.independent
		Constants["stoprnct"]                    =options.stoprnct
		Constants["number_of_images_per_group"]  =options.number_of_images_per_group
		Constants["CTF"]                 		 =options.CTF
		Constants["maxit"]               		 =options.maxit
		Constants["ir"]                  		 =options.ir 
		Constants["radius"]              		 =options.radius 
		Constants["nassign"]             		 =options.nassign
		Constants["rs"]                  		 =options.rs 
		Constants["xr"]                  		 =options.xr
		Constants["yr"]                  		 =options.yr
		Constants["ts"]                  		 =options.ts
		Constants["delta"]               		 =options.delta
		Constants["an"]                  		 =options.an
		Constants["sym"]                 		 =options.sym
		Constants["center"]              		 =options.center
		Constants["nrefine"]             		 =options.nrefine
		Constants["user_func"]           		 =options.function
		Constants["low_pass_filter"]     		 =options.low_pass_filter # enforced low_pass_filter
		Constants["main_log_prefix"]     		 =args[1]
		#Constants["importali3d"]        		 =options.importali3d
		Constants["myid"]	             		 =myid
		Constants["main_node"]           		 =main_node
		Constants["nproc"]               		 =nproc
		Constants["log_main"]            		 =log_main
		Constants["nxinit"]              		 =options.nxinit
		Constants["unaccounted"]         		 =options.unaccounted
		Constants["seed"]                		 =options.seed
		Constants["smallest_group"]      		 =options.smallest_group
		Constants["previous_runs"]       		 =options.previous_run1+" "+options.previous_run2
		Constants["sausage"]             		 =options.sausage
		Constants["chunkdir"]            		 =options.chunkdir 
		Constants["PWadjustment"]        		 =options.PWadjustment
		Constants["upscale"]             		 =options.upscale
		Constants["wn"]                  		 =options.wn
		Constants["3d-interpolation"]    		 =options.interpolation 
		Constants["protein_shape"]       		 =options.protein_shape  
		#Constants["frequency_stop_search"] 	 =options.frequency_stop_search
		#Constants["scale_of_number"]    	     =options.scale_of_number
		# -------------------------------------------------------------
		#
		# Create and initialize Tracker dictionary with input options
		Tracker                   = {}
		Tracker["constants"]      =	Constants
		Tracker["maxit"]          = Tracker["constants"]["maxit"]
		Tracker["radius"]         = Tracker["constants"]["radius"]
		#Tracker["xr"]            = ""
		#Tracker["yr"]            = "-1"  # Do not change!
		#Tracker["ts"]            = 1
		#Tracker["an"]            = "-1"
		#Tracker["delta"]         = "2.0"
		#Tracker["zoom"]          = True
		#Tracker["nsoft"]         = 0
		#Tracker["local"]         = False
		Tracker["PWadjustment"]   = Tracker["constants"]["PWadjustment"]
		Tracker["upscale"]        = Tracker["constants"]["upscale"]
		Tracker["applyctf"]       = False  #  Should the data be premultiplied by the CTF.  Set to False for local continuous.
		#Tracker["refvol"]        = None
		Tracker["nxinit"]         = Tracker["constants"]["nxinit"]
		#Tracker["nxstep"]        = 32
		Tracker["icurrentres"]    = -1
		#Tracker["ireachedres"]   = -1
		Tracker["lowpass"]        = Tracker["constants"]["low_pass_filter"]
		Tracker["falloff"]        = 0.1
		#Tracker["inires"]        = options.inires  # Now in A, convert to absolute before using
		Tracker["fuse_freq"]      = 50  # Now in A, convert to absolute before using
		#Tracker["delpreviousmax"]= False
		#Tracker["anger"]         = -1.0
		#Tracker["shifter"]       = -1.0
		#Tracker["saturatecrit"]  = 0.95
		#Tracker["pixercutoff"]   = 2.0
		#Tracker["directory"]     = ""
		#Tracker["previousoutputdir"] = ""
		#Tracker["eliminated-outliers"] = False
		#Tracker["mainiteration"]       = 0
		#Tracker["movedback"]           = False
		#Tracker["state"]               = Tracker["constants"]["states"][0] 
		#Tracker["global_resolution"]   = 0.0
		Tracker["orgstack"]             = orgstack
		#--------------------------------------------------------------------
		
		# import from utilities
		from sp_utilities import sample_down_1D_curve,get_initial_ID,remove_small_groups,print_upper_triangular_matrix,print_a_line_with_timestamp
		from sp_utilities import convertasi,prepare_ptp,print_dict,get_resolution_mrk01,partition_to_groups,partition_independent_runs,get_outliers
		from sp_utilities import merge_groups, save_alist, margin_of_error, get_margin_of_error, do_two_way_comparison, select_two_runs, get_ali3d_params
		from sp_utilities import counting_projections, unload_dict, load_dict, get_stat_proj, create_random_list, get_number_of_groups, recons_mref
		from sp_utilities import apply_low_pass_filter, get_groups_from_partition, get_number_of_groups, get_complementary_elements_total, update_full_dict
		from sp_utilities import count_chunk_members, set_filter_parameters_from_adjusted_fsc, get_two_chunks_from_stack
		####------------------------------------------------------------------	
		
		# another part
		from sp_utilities import get_class_members, remove_small_groups, get_number_of_groups, get_stable_members_from_two_runs
		from sp_utilities import two_way_comparison_single, get_leftover_from_stable, get_initial_ID, Kmeans_exhaustive_run
		from sp_utilities import print_a_line_with_timestamp, split_a_group
		
		#
		# Get the pixel size; if none, set to 1.0, and the original image size
		from sp_utilities import get_shrink_data_huang
		from time import sleep
		import sp_user_functions
		user_func = sp_user_functions.factory[Tracker["constants"]["user_func"]]
		if(myid == main_node):
			line = ''
			sxprint((line+"Initialization of 3-D sorting"))
			a = get_im(Tracker["orgstack"])
			nnxo = a.get_xsize()
			if( Tracker["nxinit"] > nnxo ):
				ERROR( "Image size less than minimum permitted $d"%Tracker["nxinit"] )
				nnxo = -1 # we break here, so not sure what this is supposed to accomplish
				return
			else:
				if Tracker["constants"]["CTF"]:
					i = a.get_attr('ctf')
					pixel_size = i.apix
					fq = pixel_size/Tracker["fuse_freq"]
				else:
					pixel_size = 1.0
					#  No pixel size, fusing computed as 5 Fourier pixels
					fq = 5.0/nnxo
					del a
		else:
			nnxo = 0
			fq   = 0.0
			pixel_size = 1.0
		nnxo = bcast_number_to_all(nnxo, source_node = main_node)
		if( nnxo < 0 ):
			return
		pixel_size                           = bcast_number_to_all(pixel_size, source_node = main_node)
		fq                                   = bcast_number_to_all(fq, source_node = main_node)
		if Tracker["constants"]["wn"]==0:
			Tracker["constants"]["nnxo"] = nnxo
		else:
			Tracker["constants"]["nnxo"] = Tracker["constants"]["wn"]
			nnxo= Tracker["constants"]["wn"]
		Tracker["constants"]["pixel_size"]   = pixel_size
		Tracker["fuse_freq"]                 = fq
		
		del fq, nnxo, pixel_size

		if(Tracker["constants"]["radius"]  < 1):
			Tracker["constants"]["radius"]  = Tracker["constants"]["nnxo"]//2-2

		elif((2*Tracker["constants"]["radius"] +2) > Tracker["constants"]["nnxo"]):
			ERROR( "Particle radius set too large!", myid=myid )
			return
####-----------------------------------------------------------------------------------------
		# create the master directory
		if myid == main_node:
			if masterdir =="":
				timestring = strftime("_%d_%b_%Y_%H_%M_%S", localtime())
				masterdir ="master_sort3d"+timestring
				li =len(masterdir)
			else:
				li = 0
			cmd="{} {}".format("mkdir -p", masterdir)
			os.system(cmd)			
			sp_global_def.write_command(masterdir)
		else:
			li=0
		li = mpi.mpi_bcast( li, 1, mpi.MPI_INT, main_node, mpi.MPI_COMM_WORLD )[0]
		if li>0:
			masterdir = mpi.mpi_bcast( masterdir, li,MPI_CHAR, main_node, mpi.MPI_COMM_WORLD )
			masterdir = string.join(masterdir,"")
		####--- masterdir done!
		if myid == main_node:
			print_dict(Tracker["constants"],"Permanent settings of 3-D sorting program")
		from time import sleep
		while not os.path.exists(masterdir):  # Be sure each proc is able to access the created dir
				sxprint("Node ",myid,"  waiting...")
				sleep(5)
		mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
		######### create a vstack from input stack to the local stack in masterdir
		# stack name set to default
		Tracker["constants"]["stack"]          = "bdb:"+masterdir+"/rdata"
		Tracker["constants"]["ali3d"]          = os.path.join(masterdir, "ali3d_init.txt")
		Tracker["constants"]["partstack"]      = Tracker["constants"]["ali3d"]
		Tracker["constants"]["ctf_params"]     = os.path.join(masterdir, "ctf_params.txt")
		######
		if myid == main_node:
			if(Tracker["orgstack"][:4] == "bdb:"):     cmd = "{} {} {}".format("e2bdb.py", Tracker["orgstack"],"--makevstack="+Tracker["constants"]["stack"])
			else:  cmd = "{} {} {}".format("sp_cpy.py", orgstack, Tracker["constants"]["stack"])
			cmdexecute(cmd)
			cmd = "{} {} {} {} ".format("sp_header.py", Tracker["constants"]["stack"],"--params=xform.projection","--export="+Tracker["constants"]["ali3d"])
			cmdexecute(cmd)
			cmd = "{} {} {} {} ".format("sp_header.py", Tracker["constants"]["stack"],"--params=ctf","--export="+Tracker["constants"]["ctf_params"])
			cmdexecute(cmd)
			#keepchecking = False
			total_stack = EMUtil.get_image_count(Tracker["orgstack"])
		else:
			total_stack =0
		total_stack = bcast_number_to_all(total_stack, source_node = main_node)
		"""
		if myid==main_node:
	   		from EMAN2db import db_open_dict	
	   		OB = db_open_dict(orgstack)
	   		DB = db_open_dict(Tracker["constants"]["stack"]) 
			for i in xrange(total_stack):
				DB[i] = OB[i]
			OB.close()
			DB.close()
	   	mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
	   	if myid==main_node:
			params= []
			for i in xrange(total_stack):
				e=get_im(orgstack,i)
				phi,theta,psi,s2x,s2y = get_params_proj(e)
				params.append([phi,theta,psi,s2x,s2y])
			write_text_row(params,Tracker["constants"]["ali3d"])
		mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
		"""
		#Tracker["total_stack"]             = total_stack
		Tracker["constants"]["total_stack"] = total_stack
		Tracker["shrinkage"]                = float(Tracker["nxinit"])/Tracker["constants"]["nnxo"]
		#####------------------------------------------------------------------------------
		if Tracker["constants"]["mask3D"]:
			Tracker["mask3D"] = os.path.join(masterdir,"smask.hdf")
		else:Tracker["mask3D"] = None
		if Tracker["constants"]["focus3Dmask"]:  Tracker["focus3D"]=os.path.join(masterdir,"sfocus.hdf")
		else:                                    Tracker["focus3D"] = None
		if myid ==main_node:
			if Tracker["constants"]["mask3D"]:
				get_shrink_3dmask(Tracker["nxinit"],Tracker["constants"]["mask3D"]).write_image(Tracker["mask3D"])
			if Tracker["constants"]["focus3Dmask"]:
				mask_3D = get_shrink_3dmask(Tracker["nxinit"],Tracker["constants"]["focus3Dmask"])
				st = Util.infomask(mask_3D, None, True)
				
				if( st[0] == 0.0 ):  
					ERROR( "sxrsort3d","incorrect focused mask, after binarize all values zero" )

				mask_3D.write_image(Tracker["focus3D"])
				del mask_3D
		if Tracker["constants"]["PWadjustment"]:
			PW_dict={}
			nxinit_pwsp=sample_down_1D_curve(Tracker["constants"]["nxinit"],Tracker["constants"]["nnxo"],Tracker["constants"]["PWadjustment"])
			Tracker["nxinit_PW"] = os.path.join(masterdir,"spwp.txt")
			if myid ==main_node:
				write_text_file(nxinit_pwsp,Tracker["nxinit_PW"])
			PW_dict[Tracker["constants"]["nnxo"]]   =Tracker["constants"]["PWadjustment"]
			PW_dict[Tracker["constants"]["nxinit"]] =Tracker["nxinit_PW"]
			Tracker["PW_dict"] = PW_dict 
		###----------------------------------------------------------------------------------		
		####---------------------------  Extract the previous results   #####################################################
		from random import shuffle
		if myid ==main_node:
			log_main.add(" Sphire rsort3d ")
			log_main.add("extract stable groups from two previous runs")
			stable_member_list                              = get_stable_members_from_two_runs(Tracker["constants"]["previous_runs"], Tracker["constants"]["total_stack"], log_main)
			Tracker["this_unaccounted_list"], new_stable_P1 = get_leftover_from_stable(stable_member_list, Tracker["constants"]["total_stack"], Tracker["constants"]["smallest_group"])
			Tracker["this_unaccounted_list"].sort()
			Tracker["total_stack"] = len(Tracker["this_unaccounted_list"])
			log_main.add("new stable is %d"%len(new_stable_P1))
		else:
			Tracker["total_stack"]           = 0
			Tracker["this_unaccounted_list"] = 0
			stable_member_list =0
		stable_member_list               = wrap_mpi_bcast(stable_member_list, main_node)
		Tracker["total_stack"]           = bcast_number_to_all(Tracker["total_stack"], source_node = main_node)
		left_one_from_old_two_runs       = wrap_mpi_bcast(Tracker["this_unaccounted_list"], main_node)
		if myid ==main_node:  
			write_text_file(left_one_from_old_two_runs, os.path.join(masterdir,"unaccounted_from_two_previous_runs.txt"))
			sxprint(" Extracting results of two previous runs is done!")
		#################################### Estimate resolution----------------------############# 

		#### make chunkdir dictionary for computing margin of error
		chunk_list = []
		if Tracker["constants"]["chunkdir"] !="": ##inhere previous random assignment of odd and even
			if myid == main_node:
				chunk_one = read_text_file(os.path.join(Tracker["constants"]["chunkdir"],"chunk0.txt"))
				chunk_two = read_text_file(os.path.join(Tracker["constants"]["chunkdir"],"chunk1.txt"))
			else:
				chunk_one = 0
				chunk_two = 0
			chunk_one = wrap_mpi_bcast(chunk_one, main_node)
			chunk_two = wrap_mpi_bcast(chunk_two, main_node)
		else:  ## if traces are lost, then creating new random assignment of odd, even particles
			chunks = list(range(Tracker["constants"]["total_stack"]))
			shuffle(chunks)
			chunk_one =chunks[0:Tracker["constants"]["total_stack"]//2]
			chunk_two =chunks[Tracker["constants"]["total_stack"]//2:Tracker["constants"]["total_stack"]]
			chunk_one = wrap_mpi_bcast(chunk_one, main_node)	
			chunk_two = wrap_mpi_bcast(chunk_two, main_node)	
				
		###### Fill chunk ID into headers when calling get_shrink_data_huang
		if myid ==main_node:
			sxprint(" random odd and even assignment done  !")
		mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
		#------------------------------------------------------------------------------
		Tracker["chunk_dict"] = {}
		for element in chunk_one: Tracker["chunk_dict"][element] = 0
		for element in chunk_two: Tracker["chunk_dict"][element] = 1
		Tracker["P_chunk0"]   = len(chunk_one)/float(Tracker["constants"]["total_stack"])
		Tracker["P_chunk1"]   = len(chunk_two)/float(Tracker["constants"]["total_stack"])
		### create two volumes to estimate resolution
		if myid == main_node:
			write_text_file(chunk_one, os.path.join(masterdir,"chunk0.txt"))
			write_text_file(chunk_two, os.path.join(masterdir,"chunk1.txt"))
		mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
		vols = []
		for index in range(2):
			data1,old_shifts1 = get_shrink_data_huang(Tracker,Tracker["constants"]["nxinit"], os.path.join(masterdir,"chunk%d.txt"%index), Tracker["constants"]["partstack"], myid, main_node, nproc, preshift = True)
			vol1 = recons3d_4nn_ctf_MPI(myid=myid, prjlist=data1, symmetry=Tracker["constants"]["sym"], finfo=None)
			if myid ==main_node:
				vol1_file_name = os.path.join(masterdir, "vol%d.hdf"%index)
				vol1.write_image(vol1_file_name)
			
			vols.append(vol1)
			mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
		if myid ==main_node:
			low_pass, falloff, currentres = get_resolution_mrk01(vols, Tracker["constants"]["radius"]*Tracker["shrinkage"], Tracker["constants"]["nxinit"], masterdir,Tracker["mask3D"])
			if low_pass   > Tracker["constants"]["low_pass_filter"]: low_pass  = Tracker["constants"]["low_pass_filter"]
		else:
			low_pass   = 0.0
			falloff    = 0.0
			currentres = 0.0
		currentres                    = bcast_number_to_all(currentres,source_node = main_node)
		low_pass                      = bcast_number_to_all(low_pass,source_node   = main_node)
		falloff                       = bcast_number_to_all(falloff,source_node    = main_node)
		Tracker["currentres"]         = currentres
		
		####################################################################
		
		Tracker["falloff"] = falloff
		if Tracker["constants"]["low_pass_filter"] == -1.0:
			Tracker["low_pass_filter"] = low_pass*Tracker["shrinkage"]
		else:
			Tracker["low_pass_filter"] = Tracker["constants"]["low_pass_filter"]/Tracker["shrinkage"]
		Tracker["lowpass"]             = Tracker["low_pass_filter"]
		Tracker["falloff"]             = 0.1
		Tracker["global_fsc"]          = os.path.join(masterdir,"fsc.txt")
		##################################################################
		if myid ==main_node:
			log_main.add("The command-line inputs are :")
			log_main.add("**********************************************************")
			for a in sys.argv: 
					log_main.add(a)
			log_main.add("**********************************************************")
		from sp_filter import filt_tanl
		##################### START 3-D sorting ##########################
		if myid ==main_node:
			log_main.add("----------3-D sorting  program------- ")
			log_main.add("current resolution %6.3f for images of original size in terms of absolute frequency"%Tracker["currentres"])
			log_main.add("equivalent to %f Angstrom resolution"%(round((Tracker["constants"]["pixel_size"]/Tracker["currentres"]/Tracker["shrinkage"]),4)))
			filt_tanl(get_im(os.path.join(masterdir, "vol0.hdf")), Tracker["low_pass_filter"], 0.1).write_image(os.path.join(masterdir, "volf0.hdf"))			
			filt_tanl(get_im(os.path.join(masterdir, "vol1.hdf")), Tracker["low_pass_filter"], 0.1).write_image(os.path.join(masterdir, "volf1.hdf"))
			sxprint(" random odd and even assignment done  !")
		mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
		## ---------------------------------------------------------------------------------------------########
		## Stop program and output results when the leftover from two sort3d runs is not sufficient for a new run    		########
		## ---------------------------------------------------  ---------------------------------------  ######
		Tracker["number_of_groups"] = get_number_of_groups(len(left_one_from_old_two_runs), Tracker["constants"]["number_of_images_per_group"])
		if Tracker["number_of_groups"] <=1 : # programs finishes 
			if myid == main_node:
				log_main.add("the unaccounted ones are no sufficient for a simple two-group run, output results!")
				log_main.add("this implies your two sort3d runs already achieved high reproducibale ratio. ")
				log_main.add("Or your number_of_images_per_group is too large ")
				log_main.add("the final reproducibility is  %f"%((Tracker["constants"]["total_stack"]-len(Tracker["this_unaccounted_list"]))/float(Tracker["constants"]["total_stack"])))
				for i in range(len(stable_member_list)): write_text_file(stable_member_list[i], os.path.join(masterdir,"P2_final_class%d.txt"%i))
				mask3d = get_im(Tracker["constants"]["mask3D"])
			else:
				mask3d = model_blank(Tracker["constants"]["nnxo"],Tracker["constants"]["nnxo"],Tracker["constants"]["nnxo"])
			bcast_EMData_to_all(mask3d, myid, main_node)
			for igrp in range(len(stable_member_list)):
				#name_of_class_file = os.path.join(masterdir, "P2_final_class%d.txt"%igrp)
				data, old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nnxo"], os.path.join(masterdir, "P2_final_class%d.txt"%igrp), Tracker["constants"]["partstack"], myid, main_node, nproc,preshift = True)
				if Tracker["constants"]["CTF"]:  
					volref, fscc = rec3D_two_chunks_MPI(data, 1.0, Tracker["constants"]["sym"], mask3d,os.path.join(masterdir,"resolution_%02d.txt"%igrp), myid, main_node, index =-1, npad=2)
				else: 
					sxprint("Missing CTF flag!")
					return
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )

				#nx_of_image=volref.get_xsize()
				if Tracker["constants"]["PWadjustment"] :		Tracker["PWadjustment"] = Tracker["PW_dict"][Tracker["constants"]["nnxo"]]
				else:											Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]	
				if myid ==main_node:
					try: 
						lowpass = search_lowpass(fscc)
						falloff = 0.1
					except:
						lowpass = 0.4
						falloff = 0.1
						log_main.add(" lowpass and falloff from fsc are %f %f"%(lowpass, falloff))
					lowpass = round(lowpass,4)
					falloff = round(min(0.1,falloff),4)
					Tracker["lowpass"] = lowpass
					Tracker["falloff"] = falloff
					refdata            = [None]*4
					refdata[0]         = volref
					refdata[1]         = Tracker
					refdata[2]         = Tracker["constants"]["myid"]
					refdata[3]         = Tracker["constants"]["nproc"]
					volref             = user_func(refdata)
					cutoff = Tracker["constants"]["pixel_size"]/lowpass
					log_main.add("%d vol low pass filer %f   %f  cut to  %f Angstrom"%(igrp,Tracker["lowpass"],Tracker["falloff"],cutoff))
					volref.write_image(os.path.join(masterdir,"volf_final%d.hdf"%igrp))
			mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
			return
		else: # Continue clustering on unaccounted ones that produced by two_way comparison of two previous runs
			#########################################################################################################################
			#if Tracker["constants"]["number_of_images_per_group"] ==-1: # Estimate number of images per group from delta, and scale up 
			#    or down by scale_of_number
			#	number_of_images_per_group = int(Tracker["constants"]["scale_of_number"]*len(n_angles))
			#
			#########################################################################################################################P2
			if myid ==main_node:
				sxprint(" Now continue clustering on accounted ones because they can make at least two groups!")
			P2_partitions        = []
			number_of_P2_runs    = 2  # Notice P2 start from two P1 runs
			### input list_to_be_processed
			import copy
			mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
			for iter_P2_run in range(number_of_P2_runs): # two runs such that one can obtain reproducibility
				list_to_be_processed = left_one_from_old_two_runs[:]#Tracker["this_unaccounted_list"][:]
				Tracker["this_unaccounted_list"] = left_one_from_old_two_runs[:]
				if myid == main_node :    new_stable1 =  new_stable_P1[:]
				total_stack   = len(list_to_be_processed) # This is the input from two P1 runs
				#number_of_images_per_group = Tracker["constants"]["number_of_images_per_group"]
				P2_run_dir = os.path.join(masterdir, "P2_run%d"%iter_P2_run)
				Tracker["number_of_groups"] = get_number_of_groups(total_stack, Tracker["constants"]["number_of_images_per_group"])
				if myid == main_node:
					cmd="{} {}".format("mkdir", P2_run_dir)
					os.system(cmd)
					log_main.add("----------------P2 independent run %d--------------"%iter_P2_run)
					log_main.add("user provided number_of_images_per_group %d"%Tracker["constants"]["number_of_images_per_group"])
					sxprint("----------------P2 independent run %d--------------"%iter_P2_run)
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
				#
				#Tracker["number_of_groups"] = get_number_of_groups(total_stack,Tracker["constants"]["number_of_images_per_group"])
				generation                  = 0
				
				if myid == main_node:
					log_main.add("number of groups is %d"%Tracker["number_of_groups"])
					log_main.add("total stack %d"%total_stack)
				while( Tracker["number_of_groups"]>=2 ):
					partition_dict      = {}
					full_dict           = {}
					workdir             = os.path.join(P2_run_dir,"generation%03d"%generation)
					Tracker["this_dir"] = workdir
					
					if myid ==main_node:
						cmd="{} {}".format("mkdir", workdir)
						os.system(cmd)
						log_main.add("---- generation         %5d"%generation)
						log_main.add("number of images per group is set as %d"%Tracker["constants"]["number_of_images_per_group"])
						log_main.add("the initial number of groups is  %d "%Tracker["number_of_groups"])
						log_main.add(" the number to be processed in this generation is %d"%len(list_to_be_processed))
						sxprint("---- generation         %5d"%generation)
						#core=read_text_row(Tracker["constants"]["ali3d"],-1)
						#write_text_row(core, os.path.join(workdir,"node%d.txt"%myid))
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					Tracker["this_data_list"]         = list_to_be_processed # leftover of P1 runs
					Tracker["total_stack"]            = len(list_to_be_processed)
					create_random_list(Tracker)
					
					###------ For super computer    ##############
					update_full_dict(list_to_be_processed, Tracker)
					###----
					##### ----------------Independent runs for EQ-Kmeans  ------------------------------------
					for indep_run in range(Tracker["constants"]["indep_runs"]):
						Tracker["this_particle_list"] = Tracker["this_indep_list"][indep_run]
						ref_vol = recons_mref(Tracker)
						if myid ==main_node:   log_main.add("independent run  %10d"%indep_run)
						mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
						#this_particle_text_file =  # for get_shrink_data
						if myid ==main_node:
							write_text_file(list_to_be_processed, os.path.join(workdir, "independent_list_%03d.txt"%indep_run))
						mref_ali3d_EQ_Kmeans(ref_vol, os.path.join(workdir, "EQ_Kmeans%03d"%indep_run), os.path.join(workdir, "independent_list_%03d.txt"%indep_run), Tracker)
						partition_dict[indep_run] = Tracker["this_partition"]
						del ref_vol
					Tracker["partition_dict"]    = partition_dict
					Tracker["this_total_stack"]  = Tracker["total_stack"]
					do_two_way_comparison(Tracker)
					##############################
					
					if myid ==main_node: log_main.add("Now calculate stable volumes")
					if myid ==main_node:
						for igrp in range(len(Tracker["two_way_stable_member"])):
							Tracker["this_data_list"]      = Tracker["two_way_stable_member"][igrp]
							write_text_file(Tracker["this_data_list"], os.path.join(workdir,"stable_class%d.txt"%igrp))
					Tracker["this_data_list_file"] = -1
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					###
					number_of_ref_class = []
					ref_vol_list = []
					for igrp in range(len(Tracker["two_way_stable_member"])):
						data, old_shifts = get_shrink_data_huang(Tracker,Tracker["nxinit"], os.path.join(workdir, "stable_class%d.txt"%igrp), Tracker["constants"]["partstack"], myid, main_node, nproc, preshift = True)
						volref           = recons3d_4nn_ctf_MPI(myid=myid,prjlist=data,symmetry=Tracker["constants"]["sym"],finfo = None)
						ref_vol_list.append(volref)
						number_of_ref_class.append(len(Tracker["this_data_list"]))
					if myid ==main_node:
						log_main.add("group  %d  members %d "%(igrp,len(Tracker["this_data_list"])))	
						#ref_vol_list=apply_low_pass_filter(ref_vol_list,Tracker)
						for iref in range(len(ref_vol_list)): ref_vol_list[iref].write_image(os.path.join(workdir,"vol_stable.hdf"),iref)
						
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					################################
					
					Tracker["number_of_ref_class"]       = number_of_ref_class
					Tracker["this_data_list"]            = Tracker["this_accounted_list"]
					outdir = os.path.join(workdir, "Kmref")
					empty_groups,res_classes,final_list  = ali3d_mref_Kmeans_MPI(ref_vol_list, outdir, os.path.join(workdir,"Accounted.txt"), Tracker)
					Tracker["this_unaccounted_list"]     = get_complementary_elements(list_to_be_processed,final_list)
					if myid == main_node:  log_main.add("the number of particles not processed is %d"%len(Tracker["this_unaccounted_list"]))
					update_full_dict(Tracker["this_unaccounted_list"], Tracker)
					if myid == main_node: write_text_file(Tracker["this_unaccounted_list"], Tracker["this_unaccounted_text"])
					Tracker["number_of_groups"]          = len(res_classes)
					### Update data
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					if myid == main_node:
						number_of_ref_class=[]
						log_main.add(" Compute volumes of original size")
						for igrp in range(Tracker["number_of_groups"]):
							if os.path.exists( os.path.join( outdir,"Class%d.txt"%igrp ) ):
								new_stable1.append( read_text_file( os.path.join( outdir, "Class%d.txt"%igrp ) ) )
								log_main.add(" read Class file %d"%igrp)
								number_of_ref_class.append(len(new_stable1))
					else:  number_of_ref_class = 0
					number_of_ref_class = wrap_mpi_bcast(number_of_ref_class,main_node)
					
					################################
					
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					if myid ==main_node:
						vol_list = []
					for igrp in range(Tracker["number_of_groups"]):
						if myid ==main_node: log_main.add("start vol   %d"%igrp)
						data,old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nnxo"], os.path.join(outdir,"Class%d.txt"%igrp), Tracker["constants"]["partstack"],myid, main_node, nproc, preshift = True)
						volref          = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry = Tracker["constants"]["sym"],finfo= None)
						if myid == main_node: 
							vol_list.append(volref)
							log_main.add(" vol   %d is done"%igrp)
					Tracker["number_of_ref_class"] = number_of_ref_class
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					generation +=1
					#################################
					
					if myid ==main_node:
						for ivol in range(len(vol_list)): vol_list[ivol].write_image(os.path.join(workdir, "vol_of_Classes.hdf"),ivol)
						filt_tanl(vol_list[ivol],Tracker["constants"]["low_pass_filter"],.1).write_image(os.path.join(workdir, "volf_of_Classes.hdf"),ivol)
						log_main.add("number of unaccounted particles  %10d"%len(Tracker["this_unaccounted_list"]))
						log_main.add("number of accounted particles  %10d"%len(Tracker["this_accounted_list"]))
						del vol_list
					Tracker["this_data_list"]        = Tracker["this_unaccounted_list"]
					Tracker["total_stack"]           = len(Tracker["this_unaccounted_list"])
					Tracker["this_total_stack"]      = Tracker["total_stack"]
					#update_full_dict(complementary)
					#number_of_groups = int(float(len(Tracker["this_unaccounted_list"]))/number_of_images_per_group)
					del list_to_be_processed
					list_to_be_processed             = copy.deepcopy(Tracker["this_unaccounted_list"]) 
					Tracker["number_of_groups"]      = get_number_of_groups(len(list_to_be_processed),Tracker["constants"]["number_of_images_per_group"])
					mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
					
	#############################################################################################################################
				### Reconstruct the unaccounted is only done once
			
				if (Tracker["constants"]["unaccounted"] and (len(Tracker["this_unaccounted_list"]) != 0)):
					data,old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nnxo"],Tracker["this_unaccounted_text"],Tracker["constants"]["partstack"],myid,main_node,nproc,preshift = True)
					volref = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"],finfo=None)
					volref = filt_tanl(volref, Tracker["constants"]["low_pass_filter"],.1)
					if myid ==main_node: volref.write_image(os.path.join(workdir, "volf_unaccounted.hdf"))
				
					######## Exhaustive Kmeans #############################################
				if myid ==main_node:
					if len(Tracker["this_unaccounted_list"])>=Tracker["constants"]["smallest_group"]:
						new_stable1.append(Tracker["this_unaccounted_list"])
					unaccounted                 = get_complementary_elements_total(Tracker["constants"]["total_stack"], final_list)
					Tracker["number_of_groups"] = len(new_stable1)
					log_main.add("----------------Exhaustive Kmeans------------------")
					log_main.add("number_of_groups is %d"%Tracker["number_of_groups"])
				else:    Tracker["number_of_groups"] = 0
				### prepare references for final K-means
				if myid == main_node:
					final_list =[]
					for alist in new_stable1:
						for element in alist:final_list.append(int(element))
					unaccounted = get_complementary_elements_total(Tracker["constants"]["total_stack"],final_list)
					if len(unaccounted) > Tracker["constants"]["smallest_group"]:  # treat unaccounted ones also as a group if it is not too small.
						new_stable1.append(unaccounted)
						Tracker["number_of_groups"] = len(new_stable1)
						for any in unaccounted:final_list.append(any)
					log_main.add("total number %d"%len(final_list))
				else:  final_list = 0
				Tracker["number_of_groups"] = bcast_number_to_all(Tracker["number_of_groups"],source_node = main_node)
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
				final_list = wrap_mpi_bcast(final_list, main_node)
				workdir = os.path.join(P2_run_dir,"Exhaustive_Kmeans") # new workdir 
				if myid==main_node:
					os.mkdir(workdir)
					write_text_file(final_list, os.path.join(workdir,"final_list.txt"))
				else: new_stable1 = 0
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
				## Create reference volumes
		
				if myid == main_node:
					number_of_ref_class = []
					for igrp in range(Tracker["number_of_groups"]):
						class_file = os.path.join(workdir,"final_class%d.txt"%igrp)
						write_text_file(new_stable1[igrp],class_file)
						log_main.add(" group %d   number of particles %d"%(igrp,len(new_stable1[igrp])))
						number_of_ref_class.append(len(new_stable1[igrp]))
				else:  number_of_ref_class= 0
				number_of_ref_class = wrap_mpi_bcast(number_of_ref_class,main_node)
		
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
				ref_vol_list = []
				for igrp in range(Tracker["number_of_groups"]):
					if myid ==main_node : sxprint(" prepare reference %d"%igrp)
					#Tracker["this_data_list_file"] = os.path.join(workdir,"final_class%d.txt"%igrp)
					data,old_shifts                = get_shrink_data_huang(Tracker, Tracker["nxinit"],os.path.join(workdir,"final_class%d.txt"%igrp), Tracker["constants"]["partstack"], myid,main_node,nproc,preshift = True)
					volref                         = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"], finfo = None)
					#volref = filt_tanl(volref, Tracker["low_pass_filter"],.1)
					#if myid == main_node:
					#	volref.write_image(os.path.join(masterdir,"volf_stable.hdf"),iref)
					#volref = resample(volref,Tracker["shrinkage"])
					bcast_EMData_to_all(volref, myid, main_node)
					ref_vol_list.append(volref)
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
				### -------variables used in Kmeans_exhaustive_run-----
				Tracker["number_of_ref_class"] = number_of_ref_class
				Tracker["this_data_list"]      = final_list
				Tracker["total_stack"]         = len(final_list)
				Tracker["this_dir"]            = workdir
				Tracker["this_data_list_file"] = os.path.join(workdir,"final_list.txt")
				KE_group                       = Kmeans_exhaustive_run(ref_vol_list,Tracker) # 
				P2_partitions.append(KE_group[:][:])
				if myid ==main_node:
					log_main.add(" the number of groups after exhaustive Kmeans is %d"%len(KE_group))
					for ike in range(len(KE_group)):log_main.add(" group   %d   number of objects %d"%(ike,len(KE_group[ike])))
					del new_stable1
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
			if myid == main_node:   log_main.add("P2 runs are done, now start two-way comparision to exclude those that are not reproduced ")
			reproduced_groups = two_way_comparison_single(P2_partitions[0],P2_partitions[1],Tracker)# Here partition IDs are original indexes.
			###### ----------------Reconstruct reproduced groups------------------------#######
			######
			if myid == main_node:
				for index_of_reproduced_groups in range(len(reproduced_groups)):
					name_of_class_file = os.path.join(masterdir, "P2_final_class%d.txt"%index_of_reproduced_groups)
					write_text_file(reproduced_groups[index_of_reproduced_groups],name_of_class_file)
				log_main.add("-------start to reconstruct reproduced volumes individully to orignal size-----------")
				
			mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
			if Tracker["constants"]["mask3D"]: mask_3d = get_shrink_3dmask(Tracker["constants"]["nnxo"],Tracker["constants"]["mask3D"])
			else:                              mask_3d = None
			
			for igrp in range(len(reproduced_groups)):
				data,old_shifts = get_shrink_data_huang(Tracker,Tracker["constants"]["nnxo"],os.path.join(masterdir, "P2_final_class%d.txt"%igrp),Tracker["constants"]["partstack"],myid,main_node,nproc,preshift = True)
				#volref = recons3d_4nn_ctf_MPI(myid=myid, prjlist = data, symmetry=Tracker["constants"]["sym"], finfo=None)
				if Tracker["constants"]["CTF"]: 
					volref, fscc = rec3D_two_chunks_MPI(data,1.0,Tracker["constants"]["sym"],mask_3d, \
										os.path.join(masterdir,"resolution_%02d.txt"%igrp),myid,main_node,index =-1,npad =2,finfo=None)
				else: 
					sxprint("Missing CTF flag!")
					return
				mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
				fscc        = read_text_file(os.path.join(masterdir, "resolution_%02d.txt"%igrp),-1)
				nx_of_image = volref.get_xsize()
				if Tracker["constants"]["PWadjustment"]:	Tracker["PWadjustment"] = Tracker["PW_dict"][nx_of_image]
				else:										Tracker["PWadjustment"] = Tracker["constants"]["PWadjustment"]	
				try:
					lowpass = search_lowpass(fscc)
					falloff = 0.1
				except:
					lowpass= 0.4
					falloff= 0.1
				sxprint(lowpass)
				lowpass=round(lowpass,4)
				falloff=round(min(.1,falloff),4)
				Tracker["lowpass"]= lowpass
				Tracker["falloff"]= falloff
				if myid == main_node:
					refdata    =[None]*4
					refdata[0] = volref
					refdata[1] = Tracker
					refdata[2] = Tracker["constants"]["myid"]
					refdata[3] = Tracker["constants"]["nproc"]
					volref     = user_func(refdata)
					cutoff     = Tracker["constants"]["pixel_size"]/lowpass
					log_main.add("%d vol low pass filer %f   %f  cut to  %f Angstrom"%(igrp,Tracker["lowpass"],Tracker["falloff"],cutoff))
					volref.write_image(os.path.join(masterdir,"volf_final%d.hdf"%igrp))
		if myid==main_node:   log_main.add(" sxsort3d_P2 finishes. ")
		# Finish program
		mpi.mpi_barrier( mpi.MPI_COMM_WORLD )
		return