Beispiel #1
0
def apply_ctf(projection_2d, ctf):
    """
    Return a copy of a 2D projection filtered by a contrast transfer function.

    This is a thin wrapper around ``fltr.filt_ctf`` that relies on that
    function's default values for every option beyond the image and the CTF.

    :param projection_2d: 2D projection to filter
    :param ctf: CTF parameters handed to ``fltr.filt_ctf``
    :return: the CTF-filtered projection produced by ``fltr.filt_ctf``
    """
    return fltr.filt_ctf(projection_2d, ctf)
Beispiel #2
0
def gen_rings_ctf(prjref, nx, ctf, numr):
    """
    Convert a set of projections to Fourier rings, multiplying each by a CTF.

    For every projection in ``prjref`` (in order) the image is passed through
    ``filt_ctf(proj, ctf, True)``, resampled with ``Util.Polar2Dm`` around the
    centre ``(nx // 2 + 1, nx // 2 + 1)``, normalized with
    ``Util.Normalize_ring``, processed with ``Util.Frngs`` and weighted with
    ``Util.Applyws`` using ``ringwe(numr, "F")``.  Each resulting ring image
    receives the projection's 'phi', 'theta' and 'psi' attributes plus the
    components n1, n2, n3 of the unit vector
    (sin(theta)cos(phi), sin(theta)sin(phi), cos(theta)), angles in degrees.

    :param prjref: sequence of projections carrying 'phi', 'theta', 'psi'
    :param nx: image size used to derive the polar centre
    :param ctf: CTF parameters handed to ``filt_ctf``
    :param numr: ring description handed to the ``Util`` ring routines
    :return: list of ring images, one per projection
    """
    from math import sin, cos, pi
    from sp_fundamentals import fft
    from sp_alignment import ringwe
    from sp_filter import filt_ctf

    ring_weights = ringwe(numr, "F")
    centre = nx // 2 + 1
    deg_to_rad = pi / 180.0

    rings = []  # reference projections as Fourier-ring images
    for proj in prjref:
        # original note: interpolation "currently set to quadratic"
        ring = Util.Polar2Dm(filt_ctf(proj, ctf, True), centre, centre, numr,
                             "F")
        Util.Normalize_ring(ring, numr, 0)
        Util.Frngs(ring, numr)
        Util.Applyws(ring, numr, ring_weights)
        rings.append(ring)

        phi = proj.get_attr('phi')
        theta = proj.get_attr('theta')
        psi = proj.get_attr('psi')
        phi_rad = phi * deg_to_rad
        theta_rad = theta * deg_to_rad
        ring.set_attr_dict({
            "n1": sin(theta_rad) * cos(phi_rad),
            "n2": sin(theta_rad) * sin(phi_rad),
            "n3": cos(theta_rad),
            "phi": phi,
            "theta": theta,
            "psi": psi
        })

    return rings
Beispiel #3
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """
    Refine the x-shifts of helical filament segments, distributed over MPI.

    Segments are grouped into filaments with ``ordersegments`` (from the
    'filament' and 'ptcl_source_coord' header attributes), and whole
    filaments are load-balanced across processes with ``chunks_distribution``.
    Each image is first transformed with its stored 'xform.align2d'
    parameters and moved to Fourier space.  Every iteration:

    * builds the global average of the currently x-shifted segments;
    * cross-correlates each segment with that average and keeps a
      1-pixel-high window of width nx;
    * lets ``Util.helixshiftali`` choose, per filament, a shift and a slope;
    * places each segment's shift on that line.

    After ``maxit`` iterations the x-shift is composed with the initial
    'xform.align2d', and the 'xform.align2d' and 'ID' attributes are written
    back to ``stack``.

    Side effect: the main node writes the masked average of each iteration
    to "tavg.hdf".

    :param stack: name of the input image stack
    :param maskfile: optional mask image. When None, a strip of ones of
        width 2 * (min(nx, ny) // 2 - 2) + 1 and full height, zero-padded to
        nx x ny, is used.
    :param maxit: number of iterations; all of them are always run
    :param CTF: if True, images with 'ctf_applied' == 0 are CTF-filtered
        (a value of 1 means already filtered; anything else is an error).
        The average is then divided with ``Util.divn_filter`` by the
        accumulated CTF image sum plus 1/snr.
    :param snr: signal-to-noise ratio used in the CTF denominator
    :param Fourvar: if True, the average is additionally divided by the
        Fourier variance returned by ``varf2d_MPI``
    :param search_rng: search range handed to ``Util.helixshiftali``; it also
        lowers the maximum slope considered
    """
    # Bind file_type locally up front. A function-level import anywhere in
    # the body makes the name local to the whole function, so the former
    # import inside the final write-out branch made this first call raise
    # UnboundLocalError.
    from sp_utilities import file_type

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    # The main node groups segments into filaments and flattens the groups
    # into plain lists for broadcasting:
    # - inidl holds the segment count of each filament;
    # - tfilaments holds the concatenated segment indices.
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [len(fil) for fil in filaments]
        linidl = sum(inidl)
        nima = linidl
        tfilaments = []
        for fil in filaments:
            tfilaments += fil
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    # Rebuild the per-filament index lists from the flattened broadcast.
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    # Balanced load: distribute whole filaments (weighted by their segment
    # counts) and keep only this process's share.
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # Concatenate the local filaments. indcs[ifil] = [first, last + 1) is the
    # range of filament ifil inside the local data list.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # Subtract the masked mean, move every image to Fourier space (applying
    # the CTF where it has not been applied yet) and accumulate CTF sums.
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf == 1:
                # CTF already applied: the image still has to be brought to
                # Fourier space like every other image. Previously it stayed
                # in real space and broke the Fourier-space averaging below.
                data[im] = fft(data[im])
            else:
                ERROR('Incorrectly set qctf flag', myid=myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to the even x-indices (and 0 to the odd ones) of the
            # CTF sum used as the Util.divn_filter denominator. Presumably
            # these are the real/imaginary parts of the complex layout.
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Global average of the currently x-shifted segments.
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            ibeg = indcs[ifil][0]
            iend = indcs[ifil][1]
            # 1D ccf between each segment of the filament and the average
            nsegms = iend - ibeg
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(ibeg, iend):
                ctx[im - ibeg] = Util.window(ccf(tavg, data[im]), nx, 1)
                pcoords[im - ibeg] = data[im].get_attr('ptcl_source_coord')

            # The central segment anchors the line along which shifts vary.
            cents = nsegms // 2

            # Largest 'ptcl_source_coord' distance from the central segment
            # to either end of the filament.
            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # C implementation of the shift/slope search. It replaced an
            # earlier Python loop that, for every shift in
            # [-search_rng, search_rng] and every slope step up to kang,
            # summed the linearly interpolated ccf values of all segments
            # along the line through the central segment (both slope signs)
            # and kept the best. Presumably Util.helixshiftali returns
            # (shift, slope, score) the same way; TODO confirm.
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]

            # Place every segment on the chosen line: shift sib at the
            # central segment, changing by bang per unit of distance, with
            # the sign flipped for segments before the centre.
            for im in range(ibeg, iend):
                kim = im - ibeg
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if kim < cents:
                    xl = -dst * bang + sib
                else:
                    xl = dst * bang + sib
                shift_x[im] = xl

            # Accumulate the central segment's shift for the average shift
            sx_sum += shift_x[ibeg + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM,
                                main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): this reset discards the average computed above. The
        # broadcast therefore always distributes 0.0 and the re-centering
        # loop below is a no-op. It is kept to preserve existing results;
        # remove it if re-centering by the average shift is wanted.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine t0 and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node:
        print_end_msg("helical-shiftali_MPI")
Beispiel #4
0
def shiftali_MPI(stack,
                 maskfile=None,
                 maxit=100,
                 CTF=False,
                 snr=1.0,
                 Fourvar=False,
                 search_rng=-1,
                 oneDx=False,
                 search_rng_y=-1):
    """
    Iteratively align all images of a stack translationally to their common
    average, distributed over MPI.

    Each process handles a contiguous range of images (``MPI_start_end``).
    Images are transformed with their stored 'xform.align2d' parameters and
    Fourier transformed, CTF-filtered when ``CTF`` is set.  Every iteration:

    * computes the global average;
    * takes each image's integer shift from the ``peak_search`` result on
      its cross-correlation with the average;
    * subtracts the rounded mean shift;
    * applies the remaining shift to the images.

    Iteration stops after ``maxit`` rounds, or as soon as no process found a
    non-zero shift.  The accumulated shifts are composed with the initial
    'xform.align2d' (keeping its scale), and 'xform.align2d' and 'ID' are
    written back to ``stack``.

    :param stack: name of the input image stack
    :param maskfile: optional mask image. The default is a circle of radius
        min(nx, ny) // 2 - 2.
    :param maxit: maximum number of iterations
    :param CTF: if True, images are CTF-filtered and the average is divided
        with ``Util.divn_filter`` by the accumulated CTF image sum plus
        ``snr``. The first image must have 'ctf_applied' == 0; a missing
        attribute is also rejected.
    :param snr: value added to the CTF denominator
    :param Fourvar: if True, the average is divided by the Fourier variance
        from ``varf2d_MPI``
    :param search_rng: half-width of the x search window (full width nx
        when <= 0)
    :param oneDx: if True, search along x only; y shifts stay 0
    :param search_rng_y: half-width of the y search window (full height ny
        when <= 0)
    """
    # Bind file_type locally up front. A function-level import anywhere in
    # the body makes the name local to the whole function, so the former
    # import inside the final write-out branch made this first call raise
    # UnboundLocalError.
    from sp_utilities import file_type

    number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    # Each process works on a contiguous range of image indices.
    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # read nx and ctf_app (if CTF) and broadcast to all nodes
    if myid == main_node:
        ima = EMData()
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF:
            ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF:
            ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        # Only the first image is checked; a missing flag defaults to 2.
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", myid=myid)

    if maskfile is None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from sp_global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # For bdb stacks the processes take turns reading, with a barrier
        # after each turn.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb":
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    # Tag images, subtract the masked mean and accumulate CTF sums.
    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # NOTE(review): helicalshiftali_MPI adds 1/snr here and covers
            # x-indices 0..nx inclusive. This version adds snr itself and
            # stops before index nx. Both agree only for snr == 1 and differ
            # at the last column; confirm which is intended before changing.
            temp = EMData(nx, ny, 1, False)
            for i in range(0, nx, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im],
                               p['alpha'],
                               sx=p['tx'],
                               sy=p['ty'],
                               mirror=p['mirror'],
                               scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum = 0
    sy_sum = 0
    # Must exist before the loop: in oneDx mode the main node never updates
    # sy_sum_total but still reads it below.
    sx_sum_total = 0
    sy_sum_total = 0
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))

            # normalize and mask tavg in real space, then back to Fourier
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg = fft(tavg)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        sx_sum = 0
        sy_sum = 0
        # Window sizes for the cross-correlation peak search.
        if search_rng > 0:
            nwx = 2 * search_rng + 1
        else:
            nwx = nx

        if search_rng_y > 0:
            nwy = 2 * search_rng_y + 1
        else:
            nwy = ny

        not_zero = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if ishift_x[im] != 0 or ishift_y[im] != 0:
                not_zero = 1

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node,
                                mpi.MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi.mpi_reduce(sy_sum, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                    main_node, mpi.MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total,
                                               source_node=main_node)

        # Remove the (rounded) mean shift so the average does not drift.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y
        # stop if all shifts are zero
        not_zero = mpi.mpi_reduce(not_zero, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                  main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))

        if not_zero_all == 0:
            break

    # Only header information is written below, so the images are not
    # transformed back to real space.
    # combine shifts found with the original parameters
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        # combine t0 and t1
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)

    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("shiftali_MPI")
Beispiel #5
0
def main():

    progname = os.path.basename(sys.argv[0])
    usage = progname + " proj_stack output_averages --MPI"
    parser = OptionParser(usage, version=SPARXVERSION)

    parser.add_option("--img_per_group",
                      type="int",
                      default=100,
                      help="number of images per group")
    parser.add_option("--radius",
                      type="int",
                      default=-1,
                      help="radius for alignment")
    parser.add_option(
        "--xr",
        type="string",
        default="2 1",
        help="range for translation search in x direction, search is +/xr")
    parser.add_option(
        "--yr",
        type="string",
        default="-1",
        help=
        "range for translation search in y direction, search is +/yr (default = same as xr)"
    )
    parser.add_option(
        "--ts",
        type="string",
        default="1 0.5",
        help=
        "step size of the translation search in both directions, search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional"
    )
    parser.add_option(
        "--iter",
        type="int",
        default=30,
        help="number of iterations within alignment (default = 30)")
    parser.add_option(
        "--num_ali",
        type="int",
        default=5,
        help="number of alignments performed for stability (default = 5)")
    parser.add_option("--thld_err",
                      type="float",
                      default=1.0,
                      help="threshold of pixel error (default = 1.732)")
    parser.add_option(
        "--grouping",
        type="string",
        default="GRP",
        help=
        "do grouping of projections: PPR - per projection, GRP - different size groups, exclusive (default), GEV - grouping equal size"
    )
    parser.add_option(
        "--delta",
        type="float",
        default=-1.0,
        help="angular step for reference projections (required for GEV method)"
    )
    parser.add_option(
        "--fl",
        type="float",
        default=0.3,
        help="cut-off frequency of hyperbolic tangent low-pass Fourier filter")
    parser.add_option(
        "--aa",
        type="float",
        default=0.2,
        help="fall-off of hyperbolic tangent low-pass Fourier filter")
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="Consider CTF correction during the alignment ")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args()

    myid = mpi.mpi_comm_rank(MPI_COMM_WORLD)
    number_of_proc = mpi.mpi_comm_size(MPI_COMM_WORLD)
    main_node = 0

    if len(args) == 2:
        stack = args[0]
        outdir = args[1]
    else:
        sp_global_def.ERROR("Incomplete list of arguments",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return
    if not options.MPI:
        sp_global_def.ERROR("Non-MPI not supported!",
                            "sxproj_stability.main",
                            1,
                            myid=myid)
        return

    if sp_global_def.CACHE_DISABLE:
        from sp_utilities import disable_bdb_cache
        disable_bdb_cache()
    sp_global_def.BATCH = True

    img_per_grp = options.img_per_group
    radius = options.radius
    ite = options.iter
    num_ali = options.num_ali
    thld_err = options.thld_err

    xrng = get_input_from_string(options.xr)
    if options.yr == "-1":
        yrng = xrng
    else:
        yrng = get_input_from_string(options.yr)

    step = get_input_from_string(options.ts)

    if myid == main_node:
        nima = EMUtil.get_image_count(stack)
        img = get_image(stack)
        nx = img.get_xsize()
        ny = img.get_ysize()
    else:
        nima = 0
        nx = 0
        ny = 0
    nima = bcast_number_to_all(nima)
    nx = bcast_number_to_all(nx)
    ny = bcast_number_to_all(ny)
    if radius == -1: radius = nx / 2 - 2
    mask = model_circle(radius, nx, nx)

    st = time()
    if options.grouping == "GRP":
        if myid == main_node:
            sxprint("  A  ", myid, "  ", time() - st)
            proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
            proj_params = []
            for i in range(nima):
                dp = proj_attr[i].get_params("spider")
                phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp[
                    "psi"], -dp["tx"], -dp["ty"]
                proj_params.append([phi, theta, psi, s2x, s2y])

            # Here is where the grouping is done, I didn't put enough annotation in the group_proj_by_phitheta,
            # So I will briefly explain it here
            # proj_list  : Returns a list of list of particle numbers, each list contains img_per_grp particle numbers
            #              except for the last one. Depending on the number of particles left, they will either form a
            #              group or append themselves to the last group
            # angle_list : Also returns a list of list, each list contains three numbers (phi, theta, delta), (phi,
            #              theta) is the projection angle of the center of the group, delta is the range of this group
            # mirror_list: Also returns a list of list, each list contains img_per_grp True or False, which indicates
            #              whether it should take mirror position.
            # In this program angle_list and mirror list are not of interest.

            proj_list_all, angle_list, mirror_list = group_proj_by_phitheta(
                proj_params, img_per_grp=img_per_grp)
            del proj_params
            sxprint("  B  number of groups  ", myid, "  ", len(proj_list_all),
                    time() - st)
        mpi_barrier(MPI_COMM_WORLD)

        # Number of groups, actually there could be one or two more groups, since the size of the remaining group varies
        # we will simply assign them to main node.
        n_grp = nima / img_per_grp - 1

        # Divide proj_list_all equally to all nodes, and becomes proj_list
        proj_list = []
        for i in range(n_grp):
            proc_to_stay = i % number_of_proc
            if proc_to_stay == main_node:
                if myid == main_node: proj_list.append(proj_list_all[i])
            elif myid == main_node:
                mpi_send(len(proj_list_all[i]), 1, MPI_INT, proc_to_stay,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                mpi_send(proj_list_all[i], len(proj_list_all[i]), MPI_INT,
                         proc_to_stay, SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            elif myid == proc_to_stay:
                img_per_grp = mpi_recv(1, MPI_INT, main_node,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                img_per_grp = int(img_per_grp[0])
                temp = mpi_recv(img_per_grp, MPI_INT, main_node,
                                SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                proj_list.append(list(map(int, temp)))
                del temp
            mpi_barrier(MPI_COMM_WORLD)
        sxprint("  C  ", myid, "  ", time() - st)
        if myid == main_node:
            # Assign the remaining groups to main_node
            for i in range(n_grp, len(proj_list_all)):
                proj_list.append(proj_list_all[i])
            del proj_list_all, angle_list, mirror_list

    #   Compute stability per projection projection direction, equal number assigned, thus overlaps
    elif options.grouping == "GEV":

        if options.delta == -1.0:
            ERROR(
                "Angular step for reference projections is required for GEV method"
            )
            return

        from sp_utilities import even_angles, nearestk_to_refdir, getvec
        refproj = even_angles(options.delta)
        img_begin, img_end = MPI_start_end(len(refproj), number_of_proc, myid)
        # Now each processor keeps its own share of reference projections
        refprojdir = refproj[img_begin:img_end]
        del refproj

        ref_ang = [0.0] * (len(refprojdir) * 2)
        for i in range(len(refprojdir)):
            ref_ang[i * 2] = refprojdir[0][0]
            ref_ang[i * 2 + 1] = refprojdir[0][1] + i * 0.1

        sxprint("  A  ", myid, "  ", time() - st)
        proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
        #  the solution below is very slow, do not use it unless there is a problem with the i/O
        """
		for i in xrange(number_of_proc):
			if myid == i:
				proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
			mpi_barrier(MPI_COMM_WORLD)
		"""
        sxprint("  B  ", myid, "  ", time() - st)

        proj_ang = [0.0] * (nima * 2)
        for i in range(nima):
            dp = proj_attr[i].get_params("spider")
            proj_ang[i * 2] = dp["phi"]
            proj_ang[i * 2 + 1] = dp["theta"]
        sxprint("  C  ", myid, "  ", time() - st)
        asi = Util.nearestk_to_refdir(proj_ang, ref_ang, img_per_grp)
        del proj_ang, ref_ang
        proj_list = []
        for i in range(len(refprojdir)):
            proj_list.append(asi[i * img_per_grp:(i + 1) * img_per_grp])
        del asi
        sxprint("  D  ", myid, "  ", time() - st)
        #from sys import exit
        #exit()

    #   Compute stability per projection
    elif options.grouping == "PPR":
        sxprint("  A  ", myid, "  ", time() - st)
        proj_attr = EMUtil.get_all_attributes(stack, "xform.projection")
        sxprint("  B  ", myid, "  ", time() - st)
        proj_params = []
        for i in range(nima):
            dp = proj_attr[i].get_params("spider")
            phi, theta, psi, s2x, s2y = dp["phi"], dp["theta"], dp[
                "psi"], -dp["tx"], -dp["ty"]
            proj_params.append([phi, theta, psi, s2x, s2y])
        img_begin, img_end = MPI_start_end(nima, number_of_proc, myid)
        sxprint("  C  ", myid, "  ", time() - st)
        from sp_utilities import nearest_proj
        proj_list, mirror_list = nearest_proj(
            proj_params, img_per_grp,
            list(range(img_begin, img_begin + 1)))  #range(img_begin, img_end))
        refprojdir = proj_params[img_begin:img_end]
        del proj_params, mirror_list
        sxprint("  D  ", myid, "  ", time() - st)

    else:
        ERROR("Incorrect projection grouping option")
        return

    ###########################################################################################################
    # Begin stability test
    from sp_utilities import get_params_proj, read_text_file
    #if myid == 0:
    #	from utilities import read_text_file
    #	proj_list[0] = map(int, read_text_file("lggrpp0.txt"))

    from sp_utilities import model_blank
    aveList = [model_blank(nx, ny)] * len(proj_list)
    if options.grouping == "GRP":
        refprojdir = [[0.0, 0.0, -1.0]] * len(proj_list)
    for i in range(len(proj_list)):
        sxprint("  E  ", myid, "  ", time() - st)
        class_data = EMData.read_images(stack, proj_list[i])
        #print "  R  ",myid,"  ",time()-st
        if options.CTF:
            from sp_filter import filt_ctf
            for im in range(len(class_data)):  #  MEM LEAK!!
                atemp = class_data[im].copy()
                btemp = filt_ctf(atemp, atemp.get_attr("ctf"), binary=1)
                class_data[im] = btemp
                #class_data[im] = filt_ctf(class_data[im], class_data[im].get_attr("ctf"), binary=1)
        for im in class_data:
            try:
                t = im.get_attr(
                    "xform.align2d")  # if they are there, no need to set them!
            except:
                try:
                    t = im.get_attr("xform.projection")
                    d = t.get_params("spider")
                    set_params2D(im, [0.0, -d["tx"], -d["ty"], 0, 1.0])
                except:
                    set_params2D(im, [0.0, 0.0, 0.0, 0, 1.0])
        #print "  F  ",myid,"  ",time()-st
        # Here, we perform realignment num_ali times
        all_ali_params = []
        for j in range(num_ali):
            if (xrng[0] == 0.0 and yrng[0] == 0.0):
                avet = ali2d_ras(class_data,
                                 randomize=True,
                                 ir=1,
                                 ou=radius,
                                 rs=1,
                                 step=1.0,
                                 dst=90.0,
                                 maxit=ite,
                                 check_mirror=True,
                                 FH=options.fl,
                                 FF=options.aa)
            else:
                avet = within_group_refinement(class_data, mask, True, 1,
                                               radius, 1, xrng, yrng, step,
                                               90.0, ite, options.fl,
                                               options.aa)
            ali_params = []
            for im in range(len(class_data)):
                alpha, sx, sy, mirror, scale = get_params2D(class_data[im])
                ali_params.extend([alpha, sx, sy, mirror])
            all_ali_params.append(ali_params)
        #aveList[i] = avet
        #print "  G  ",myid,"  ",time()-st
        del ali_params
        # We determine the stability of this group here.
        # stable_set contains all particles deemed stable, it is a list of list
        # each list has two elements, the first is the pixel error, the second is the image number
        # stable_set is sorted based on pixel error
        #from utilities import write_text_file
        #write_text_file(all_ali_params, "all_ali_params%03d.txt"%myid)
        stable_set, mir_stab_rate, average_pix_err = multi_align_stability(
            all_ali_params, 0.0, 10000.0, thld_err, False, 2 * radius + 1)
        #print "  H  ",myid,"  ",time()-st
        if (len(stable_set) > 5):
            stable_set_id = []
            members = []
            pix_err = []
            # First put the stable members into attr 'members' and 'pix_err'
            for s in stable_set:
                # s[1] - number in this subset
                stable_set_id.append(s[1])
                # the original image number
                members.append(proj_list[i][s[1]])
                pix_err.append(s[0])
            # Then put the unstable members into attr 'members' and 'pix_err'
            from sp_fundamentals import rot_shift2D
            avet.to_zero()
            if options.grouping == "GRP":
                aphi = 0.0
                atht = 0.0
                vphi = 0.0
                vtht = 0.0
            l = -1
            for j in range(len(proj_list[i])):
                #  Here it will only work if stable_set_id is sorted in the increasing number, see how l progresses
                if j in stable_set_id:
                    l += 1
                    avet += rot_shift2D(class_data[j], stable_set[l][2][0],
                                        stable_set[l][2][1],
                                        stable_set[l][2][2],
                                        stable_set[l][2][3])
                    if options.grouping == "GRP":
                        phi, theta, psi, sxs, sy_s = get_params_proj(
                            class_data[j])
                        if (theta > 90.0):
                            phi = (phi + 540.0) % 360.0
                            theta = 180.0 - theta
                        aphi += phi
                        atht += theta
                        vphi += phi * phi
                        vtht += theta * theta
                else:
                    members.append(proj_list[i][j])
                    pix_err.append(99999.99)
            aveList[i] = avet.copy()
            if l > 1:
                l += 1
                aveList[i] /= l
                if options.grouping == "GRP":
                    aphi /= l
                    atht /= l
                    vphi = (vphi - l * aphi * aphi) / l
                    vtht = (vtht - l * atht * atht) / l
                    from math import sqrt
                    refprojdir[i] = [
                        aphi, atht,
                        (sqrt(max(vphi, 0.0)) + sqrt(max(vtht, 0.0))) / 2.0
                    ]

            # Here more information has to be stored, PARTICULARLY WHAT IS THE REFERENCE DIRECTION
            aveList[i].set_attr('members', members)
            aveList[i].set_attr('refprojdir', refprojdir[i])
            aveList[i].set_attr('pixerr', pix_err)
        else:
            sxprint(" empty group ", i, refprojdir[i])
            aveList[i].set_attr('members', [-1])
            aveList[i].set_attr('refprojdir', refprojdir[i])
            aveList[i].set_attr('pixerr', [99999.])

    del class_data

    if myid == main_node:
        km = 0
        for i in range(number_of_proc):
            if i == main_node:
                for im in range(len(aveList)):
                    aveList[im].write_image(args[1], km)
                    km += 1
            else:
                nl = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                              MPI_COMM_WORLD)
                nl = int(nl[0])
                for im in range(nl):
                    ave = recv_EMData(i, im + i + 70000)
                    nm = mpi_recv(1, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                                  MPI_COMM_WORLD)
                    nm = int(nm[0])
                    members = mpi_recv(nm, MPI_INT, i, SPARX_MPI_TAG_UNIVERSAL,
                                       MPI_COMM_WORLD)
                    ave.set_attr('members', list(map(int, members)))
                    members = mpi_recv(nm, MPI_FLOAT, i,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    ave.set_attr('pixerr', list(map(float, members)))
                    members = mpi_recv(3, MPI_FLOAT, i,
                                       SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                    ave.set_attr('refprojdir', list(map(float, members)))
                    ave.write_image(args[1], km)
                    km += 1
    else:
        mpi_send(len(aveList), 1, MPI_INT, main_node, SPARX_MPI_TAG_UNIVERSAL,
                 MPI_COMM_WORLD)
        for im in range(len(aveList)):
            send_EMData(aveList[im], main_node, im + myid + 70000)
            members = aveList[im].get_attr('members')
            mpi_send(len(members), 1, MPI_INT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            mpi_send(members, len(members), MPI_INT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            members = aveList[im].get_attr('pixerr')
            mpi_send(members, len(members), MPI_FLOAT, main_node,
                     SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            try:
                members = aveList[im].get_attr('refprojdir')
                mpi_send(members, 3, MPI_FLOAT, main_node,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
            except:
                mpi_send([-999.0, -999.0, -999.0], 3, MPI_FLOAT, main_node,
                         SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)

    sp_global_def.BATCH = False
    mpi_barrier(MPI_COMM_WORLD)
Beispiel #6
0
def generate_helimic(refvol,
                     outdir,
                     pixel,
                     CTF=False,
                     Cs=2.0,
                     voltage=200.0,
                     ampcont=10.0,
                     nonoise=False,
                     rand_seed=14567):
    """
    Generate three synthetic micrographs from a reference volume.

    For idef in 3, 4, 5 (defocus = idef * 0.6), one projection of the volume is
    computed and windowed to 320 x nz. It is then padded into a 2048 x 2048
    micrograph at a horizontal offset of 750 * (idef - 4) pixels. Optionally
    Gaussian noise is added, the CTF is applied, and low-pass filtered noise is
    added. The result is written to "<outdir>/mic<idef-3>.hdf".

    :param refvol: reference volume, read with get_im
    :param outdir: output directory; it must not exist yet (it is created here)
    :param pixel: pixel size entry passed to generate_ctf
    :param CTF: if True, multiply each micrograph by a CTF
    :param Cs: Cs entry passed to generate_ctf
    :param voltage: voltage entry passed to generate_ctf
    :param ampcont: amplitude-contrast entry passed to generate_ctf
    :param nonoise: if True, skip both noise-addition steps
    :param rand_seed: seed passed to seed() and Util.set_randnum_seed
    """
    from sp_utilities import model_blank, model_gauss_noise, pad, get_im, generate_ctf
    from sp_projection import prgs, prep_vol
    from sp_filter import filt_gaussl, filt_ctf

    if os.path.exists(outdir):
        ERROR(
            "Output directory exists, please change the name and restart the program"
        )
        return

    os.mkdir(outdir)
    seed(rand_seed)
    Util.set_randnum_seed(rand_seed)

    # One [phi, theta, psi, sx, sy] entry per micrograph (only phi and theta are used).
    angles = [[60.0 * i, 90.0 - i * 5, 0.0, 0.0, 0.0] for i in range(3)]

    volfts = get_im(refvol)
    nz = volfts.get_zsize()
    volfts, kbx, kby, kbz = prep_vol(volfts)

    for idef in range(3, 6):
        mic = model_blank(2048, 2048)
        defocus = idef * 0.6  ##@ming
        if CTF:
            # generate_ctf parameter order:
            # [defocus, Cs, voltage, pixel size, B-factor, amp. contrast,
            #  astigmatism amplitude, astigmatism angle].
            # Previously Cs/voltage/pixel were hard-coded (2, 200, 1.84) and the
            # function arguments were ignored.
            # ##@ming: astigmatism amplitude is set to 20% of the defocus
            # (a sensible range is 10-22%).
            ctf = generate_ctf(
                [defocus, Cs, voltage, pixel, 0.0, ampcont, defocus * 0.2, 80])

        # i = -1, 0, 1: in-plane angle and placement vary per micrograph.
        i = idef - 4
        psi = 90 + 10 * i
        proj = prgs(volfts, kbz,
                    [angles[idef - 3][0], angles[idef - 3][1], psi, 0.0, 0.0],
                    kbx, kby)
        proj = Util.window(proj, 320, nz)
        mic += pad(proj, 2048, 2048, 1, 0.0, 750 * i, 20 * i, 0)

        if not nonoise:
            mic += model_gauss_noise(30.0, 2048, 2048)
        if CTF:
            mic = filt_ctf(mic, ctf)
        if not nonoise:
            mic += filt_gaussl(model_gauss_noise(17.5, 2048, 2048), 0.3)

        mic.write_image("%s/mic%1d.hdf" % (outdir, idef - 3), 0)
Beispiel #7
0
def ali2d_single_iter(
    data,
    numr,
    wr,
    cs,
    tavg,
    cnx,
    cny,
    xrng,
    yrng,
    step,
    nomirror=False,
    mode="F",
    CTF=False,
    random_method="",
    T=1.0,
    ali_params="xform.align2d",
    delta=0.0,
):
    """
    Run one iteration of 2D alignment of every image in ``data`` to ``tavg``.

    The reference ``tavg`` (or scf(tavg) when random_method == "SCF") is
    converted once to polar rings (Polar2Dm, Frngs, Applyws with weights
    ``wr``). Each image is then aligned to it, and its new 2D parameters are
    written back to the header attribute named by ``ali_params``.

    If CTF is True, the CTF stored in each image's "ctf" attribute is applied
    to the image (not to the reference) before the search. The SCF branch
    passes the unfiltered data[im] to multalign2d_scf.

    The search method is selected by ``random_method``:
      "SHC"  - Util.shc. Parameters are only updated when the returned peak
               exceeds the image's "previousmax" attribute.
      "SCF"  - multalign2d_scf.
      "PCP"  - ormq_fast (see NOTE(review) in that branch).
      other  - ornq when nomirror is True, ormq otherwise.

    :param cs: current centering shift. Before the search, each image's
               parameters are combined with a shift of (-cs[0], -cs[1]).
    :param delta: forwarded to ormq / ormq_fast.
    :param T: not used in this function.
    :return: (sx_sum, sy_sum, nope):
             - sx_sum: sum of the new x shifts, with the sign flipped for
               mirrored images;
             - sy_sum: sum of the new y shifts;
             - nope: number of SHC images whose parameters were not improved.
    """

    maxrin = numr[-1]  #  length (not used below)
    ou = numr[-3]  #  maximum radius
    if random_method == "SCF":
        # SCF search: keep the reference FFT for multalign2d_scf, use integer
        # search ranges, and build the polar rings from scf(tavg).
        frotim = [sp_fundamentals.fft(tavg)]
        xrng = int(xrng + 0.5)
        yrng = int(yrng + 0.5)
        cimage = EMAN2_cppwrap.Util.Polar2Dm(sp_fundamentals.scf(tavg), cnx,
                                             cny, numr, mode)
        EMAN2_cppwrap.Util.Frngs(cimage, numr)
        EMAN2_cppwrap.Util.Applyws(cimage, numr, wr)
    else:
        # 2D alignment using rotational ccf in polar coords and quadratic interpolation
        cimage = EMAN2_cppwrap.Util.Polar2Dm(tavg, cnx, cny, numr, mode)
        EMAN2_cppwrap.Util.Frngs(cimage, numr)
        EMAN2_cppwrap.Util.Applyws(cimage, numr, wr)

    # Running sums of the new shifts (returned), plus the latest per-image
    # result (mn = mirror, sxn/syn = shifts) that feeds them.
    sx_sum = 0.0
    sy_sum = 0.0
    sxn = 0.0
    syn = 0.0
    mn = 0
    nope = 0  # SHC images whose parameters were not improved
    # Clamp bound for the inverse-transformed shifts (sxi, syi), derived from
    # the gap between the center cnx and the outermost ring radius ou.
    mashi = cnx - ou - 2
    for im in range(len(data)):
        if CTF:
            # Apply CTF to image
            ctf_params = data[im].get_attr("ctf")
            ima = sp_filter.filt_ctf(data[im], ctf_params, True)
        else:
            ima = data[im]

        if random_method == "PCP":
            # NOTE(review): on this path data[im] is indexed as data[im][0][0];
            # the layout of these entries is not visible here.
            sxi = data[im][0][0].get_attr("sxi")
            syi = data[im][0][0].get_attr("syi")
            nx = ny = data[im][0][0].get_attr("inx")
        else:
            nx = ima.get_xsize()
            ny = ima.get_ysize()
            alpha, sx, sy, mirror, dummy = sp_utilities.get_params2D(
                data[im], ali_params)
            # Fold the centering shift -cs into the current parameters.
            alpha, sx, sy, dummy = sp_utilities.combine_params2(
                alpha, sx, sy, mirror, 0.0, -cs[0], -cs[1], 0)
            alphai, sxi, syi, scalei = sp_utilities.inverse_transform2(
                alpha, sx, sy)
            #  introduce constraints on parameters to accomodate use of cs centering
            sxi = min(max(sxi, -mashi), mashi)
            syi = min(max(syi, -mashi), mashi)

        #  The search range procedure was adjusted for 3D searches, so since in 2D the order of operations is inverted, we have to invert ranges
        txrng = search_range(nx, ou, sxi, xrng, "ali2d_single_iter")
        txrng = [txrng[1], txrng[0]]
        tyrng = search_range(ny, ou, syi, yrng, "ali2d_single_iter")
        tyrng = [tyrng[1], tyrng[0]]
        # print im, "B",cnx,sxi,syi,txrng, tyrng
        # align current image to the reference
        if random_method == "SHC":
            """Multiline Comment0"""
            #  For shc combining of shifts is problematic as the image may randomly slide away and never come back.
            #  A possibility would be to reject moves that results in too large departure from the center.
            #  On the other hand, one cannot simply do searches around the proper center all the time,
            #    as if xr is decreased, the image cannot be brought back if the established shifts are further than new range
            olo = EMAN2_cppwrap.Util.shc(
                ima,
                [cimage],
                txrng,
                tyrng,
                step,
                -1.0,
                mode,
                numr,
                cnx + sxi,
                cny + syi,
                "c1",
            )
            ##olo = Util.shc(ima, [cimage], xrng, yrng, step, -1.0, mode, numr, cnx, cny, "c1")
            # olo[0..3] = angle, sx, sy, mirror; olo[5] is compared with the
            # stored "previousmax" to decide whether to accept the move.
            if data[im].get_attr("previousmax") < olo[5]:
                # [angt, sxst, syst, mirrort, peakt] = ormq(ima, cimage, xrng, yrng, step, mode, numr, cnx+sxi, cny+syi, delta)
                # print  angt, sxst, syst, mirrort, peakt,olo
                angt = olo[0]
                sxst = olo[1]
                syst = olo[2]
                mirrort = int(olo[3])
                # combine parameters and set them to the header, ignore previous angle and mirror
                [alphan, sxn, syn,
                 mn] = sp_utilities.combine_params2(0.0, -sxi, -syi, 0, angt,
                                                    sxst, syst, mirrort)
                sp_utilities.set_params2D(data[im],
                                          [alphan, sxn, syn, mn, 1.0],
                                          ali_params)
                ##set_params2D(data[im], [angt, sxst, syst, mirrort, 1.0], ali_params)
                data[im].set_attr("previousmax", olo[5])
            else:
                # Did not find a better peak, but we have to set shifted parameters, as the average shifted
                sp_utilities.set_params2D(data[im],
                                          [alpha, sx, sy, mirror, 1.0],
                                          ali_params)
                nope += 1
                # Contribute nothing to the shift sums for this image.
                mn = 0
                sxn = 0.0
                syn = 0.0
        elif random_method == "SCF":
            sxst, syst, iref, angt, mirrort, totpeak = multalign2d_scf(
                data[im], [cimage], frotim, numr, xrng, yrng, ou=ou)
            [alphan, sxn, syn,
             mn] = sp_utilities.combine_params2(0.0, -sxi, -syi, 0, angt, sxst,
                                                syst, mirrort)
            sp_utilities.set_params2D(data[im], [alphan, sxn, syn, mn, 1.0],
                                      ali_params)
        elif random_method == "PCP":
            # NOTE(review): this branch looks broken as written.
            # - `rings` is not defined in this function.
            # - `sx` / `sy` are only assigned on the non-PCP setup path above,
            #   so on the PCP path they are not assigned in this function.
            # - mn/sxn/syn are never updated here, so the shift sums below
            #   reuse the previous image's values.
            # Confirm the intended behaviour before relying on "PCP".
            [angt, sxst, syst, mirrort,
             peakt] = ormq_fast(data[im], cimage, txrng, tyrng, step, numr,
                                mode, delta)
            sxst = rings[0][0][0].get_attr("sxi")
            syst = rings[0][0][0].get_attr("syi")
            sp_global_def.sxprint(sxst, syst, sx, sy)
            dummy, sxs, sys, dummy = sp_utilities.inverse_transform2(
                -angt, sx + sxst, sy + syst)
            sp_utilities.set_params2D(data[im][0][0],
                                      [angt, sxs, sys, mirrort, 1.0],
                                      ali_params)
        else:
            if nomirror:
                [angt, sxst, syst, mirrort,
                 peakt] = ornq(ima, cimage, txrng, tyrng, step, mode, numr,
                               cnx + sxi, cny + syi)
            else:
                [angt, sxst, syst, mirrort, peakt] = ormq(
                    ima,
                    cimage,
                    txrng,
                    tyrng,
                    step,
                    mode,
                    numr,
                    cnx + sxi,
                    cny + syi,
                    delta,
                )
            # combine parameters and set them to the header, ignore previous angle and mirror
            [alphan, sxn, syn,
             mn] = sp_utilities.combine_params2(0.0, -sxi, -syi, 0, angt, sxst,
                                                syst, mirrort)
            sp_utilities.set_params2D(data[im], [alphan, sxn, syn, mn, 1.0],
                                      ali_params)

        # Accumulate shifts; x is negated for mirrored results (mn != 0).
        if mn == 0:
            sx_sum += sxn
        else:
            sx_sum -= sxn
        sy_sum += syn

    return sx_sum, sy_sum, nope
Beispiel #8
0
def align2d_direct3(input_images,
                    refim,
                    xrng=1,
                    yrng=1,
                    psimax=180,
                    psistep=1,
                    ou=-1,
                    CTF=None):
    """
    Align each image in ``input_images`` to ``refim`` by direct cross-correlation.

    The masked reference is rotated over a grid of in-plane angles, both with
    and without an extra 180 degrees. For every rotation, the rotated
    reference and its mirror are Fourier transformed. Each image is
    cross-correlated with all of them. The peak is searched in a
    (2*xrng+1) x (2*yrng+1) window of the ccf and refined with a 3x3
    parabolic fit (``parabl``). The rotation/mirror/shift with the highest
    refined peak value is kept.

    :param input_images: list of 2D images; the x size of the first one sets nx.
    :param refim: reference image.
    :param xrng: half-width of the x search window.
    :param yrng: half-width of the y search window.
    :param psimax: angular half-range of the rotational search.
    :param psistep: angular step of the rotational search.
    :param ou: radius of the circular mask; nx//2 - 1 when negative.
    :param CTF: if truthy, each image's FFT is filtered with the CTF stored in
        its "ctf" attribute before correlation.
    :return: one [angle, sx, sy, mirror, peak] list per input image. The best
        parameters are passed through inverse_transform2 before returning.
    """
    nx = input_images[0].get_xsize()
    if ou < 0:
        ou = old_div(nx, 2) - 1
    mask = sp_utilities.model_circle(ou, nx, nx)
    nk = int(old_div(psimax, psistep))
    nm = 2 * nk + 1
    nc = nk + 1
    # refs[2*k]   : reference rotated by (k - nc) * psistep
    # refs[2*k+1] : the same rotation plus 180 degrees
    # Each entry is [fft(rotated), fft(mirror(rotated))], indexed by mirror flag.
    refs = [None] * nm * 2
    for i in range(nm):
        temp = sp_fundamentals.rot_shift2D(refim, (i - nc) * psistep) * mask
        refs[2 * i] = [
            sp_fundamentals.fft(temp),
            sp_fundamentals.fft(sp_fundamentals.mirror(temp)),
        ]
        temp = sp_fundamentals.rot_shift2D(refim,
                                           (i - nc) * psistep + 180.0) * mask
        refs[2 * i + 1] = [
            sp_fundamentals.fft(temp),
            sp_fundamentals.fft(sp_fundamentals.mirror(temp)),
        ]
    del temp

    results = []
    for image in input_images:
        if CTF:
            ims = sp_filter.filt_ctf(sp_fundamentals.fft(image),
                                     image.get_attr("ctf"))
        else:
            ims = sp_fundamentals.fft(image)
        # Best-so-far for this image. `mir` is reset here too: it used to be
        # initialised once before the loop, so an image with no accepted peak
        # inherited the previous image's mirror flag.
        ama = -1.0e23
        bang = 0.0
        bsx = 0.0
        bsy = 0.0
        mir = 0
        for i in range(nm * 2):
            for mirror_flag in [0, 1]:
                c = sp_fundamentals.ccf(ims, refs[i][mirror_flag])
                w = EMAN2_cppwrap.Util.window(c, 2 * xrng + 1, 2 * yrng + 1)
                pp = sp_utilities.peak_search(w)[0]
                px = int(pp[4])
                py = int(pp[5])
                if pp[0] == 1.0 and px == 0 and py == 0:
                    pass  # XSH, YSH, PEAKV = 0.,0.,0.
                else:
                    # Parabolic refinement on the 3x3 neighbourhood of the peak.
                    ww = sp_utilities.model_blank(3, 3)
                    ux = int(pp[1])
                    uy = int(pp[2])
                    for k in range(3):
                        for l in range(3):
                            ww[k, l] = w[k + ux - 1, l + uy - 1]
                    XSH, YSH, PEAKV = parabl(ww)
                    if PEAKV > ama:
                        ama = PEAKV
                        bsx = px + round(XSH, 2)
                        bsy = py + round(YSH, 2)
                        bang = i
                        mir = mirror_flag
        # Decode the winning reference index back to an angle, using the same
        # layout as the refs construction above.
        bang = (old_div(bang, 2) - nc) * psistep + 180.0 * (bang % 2)
        # returned parameters have to be inverted
        bang, bsx, bsy, _ = sp_utilities.inverse_transform2(
            bang, (1 - 2 * mir) * bsx, bsy, mir)
        results.append([bang, bsx, bsy, mir, ama])
    return results
Beispiel #9
0
def main():
    """
    Command-line entry point: test the stability of 2D alignment of a stack.

    Usage: <prog> stack [output_average] [options]

    Workflow:
    - The stack is optionally phase-flipped with its CTF (--CTF).
    - It is aligned --num_ali times: ali2d_ras when both xr and yr start with
      0, within_group_refinement otherwise.
    - The collected 2D parameters go to multi_align_stability, and summary
      statistics are printed.
    - If stable particles exist, their average is computed with the
      attributes 'members', 'pix_err' and 'pixerr'. It is written to
      output_average only when that argument is given.
    """
    from sp_utilities import get_input_from_string
    progname = os.path.basename(sys.argv[0])
    usage = progname + " stack output_average --radius=particle_radius --xr=xr --yr=yr --ts=ts --thld_err=thld_err --num_ali=num_ali --fl=fl --aa=aa --CTF --verbose --stables"
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option("--radius",
                      type="int",
                      default=-1,
                      help=" particle radius for alignment")
    parser.add_option(
        "--xr",
        type="string",
        default="2 1",
        help=
        "range for translation search in x direction, search is +/xr (default 2,1)"
    )
    parser.add_option(
        "--yr",
        type="string",
        default="-1",
        help=
        "range for translation search in y direction, search is +/yr (default = same as xr)"
    )
    parser.add_option(
        "--ts",
        type="string",
        default="1 0.5",
        help=
        "step size of the translation search in both directions, search is -xr, -xr+ts, 0, xr-ts, xr, can be fractional (default: 1,0.5)"
    )
    # Help texts below now state the actual defaults (they previously said
    # 0.75 for --thld_err and 0.3 for --fl).
    parser.add_option("--thld_err",
                      type="float",
                      default=0.7,
                      help="threshld of pixel error (default = 0.7)")
    parser.add_option(
        "--num_ali",
        type="int",
        default=5,
        help="number of alignments performed for stability (default = 5)")
    parser.add_option("--maxit",
                      type="int",
                      default=30,
                      help="number of iterations for each xr (default = 30)")
    parser.add_option(
        "--fl",
        type="float",
        default=0.45,
        help=
        "cut-off frequency of hyperbolic tangent low-pass Fourier filter (default = 0.45)"
    )
    parser.add_option(
        "--aa",
        type="float",
        default=0.2,
        help=
        "fall-off of hyperbolic tangent low-pass Fourier filter (default = 0.2)"
    )
    parser.add_option("--CTF",
                      action="store_true",
                      default=False,
                      help="Use CTF correction during the alignment ")
    parser.add_option("--verbose",
                      action="store_true",
                      default=False,
                      help="print individual pixel error (default = False)")
    parser.add_option(
        "--stables",
        action="store_true",
        default=False,
        help="output the stable particles number in file (default = False)")
    parser.add_option(
        "--method",
        type="string",
        default=" ",
        help="SHC (standard method is default when flag is ommitted)")

    (options, args) = parser.parse_args()

    if len(args) not in (1, 2):
        sxprint("Usage: " + usage)
        sxprint("Please run \'" + progname + " -h\' for detailed options")
        ERROR(
            "Invalid number of parameters used. Please see usage information above."
        )
        return

    if sp_global_def.CACHE_DISABLE:
        from sp_utilities import disable_bdb_cache
        disable_bdb_cache()

    from sp_applications import within_group_refinement, ali2d_ras
    from sp_pixel_error import multi_align_stability
    from sp_utilities import write_text_file, write_text_row
    from sp_utilities import model_circle, get_params2D, set_params2D
    from sp_fundamentals import rot_shift2D

    sp_global_def.BATCH = True

    xrng = get_input_from_string(options.xr)
    if options.yr == "-1":
        yrng = xrng
    else:
        yrng = get_input_from_string(options.yr)
    step = get_input_from_string(options.ts)

    class_data = EMData.read_images(args[0])

    nx = class_data[0].get_xsize()
    ou = options.radius
    num_ali = options.num_ali
    if ou == -1:
        # Floor division: ou is a ring radius and is used for an image size
        # (2*ou + 1). Under Python 3, `/` would make it a float.
        ou = nx // 2 - 2
    mask = model_circle(ou, nx, nx)

    if options.CTF:
        from sp_filter import filt_ctf
        for im in range(len(class_data)):
            #  Flip phases
            class_data[im] = filt_ctf(class_data[im],
                                      class_data[im].get_attr("ctf"),
                                      binary=1)

    # Make sure every image carries initial 2D parameters: keep existing
    # xform.align2d, else derive shifts from xform.projection, else zeros.
    for im in class_data:
        im.set_attr("previousmax", -1.0e10)
        try:
            im.get_attr("xform.align2d")  # if they are there, no need to set them!
        except Exception:
            try:
                t = im.get_attr("xform.projection")
                d = t.get_params("spider")
                set_params2D(im, [0.0, -d["tx"], -d["ty"], 0, 1.0])
            except Exception:
                set_params2D(im, [0.0, 0.0, 0.0, 0, 1.0])

    # Repeat the alignment num_ali times; each run contributes a flat
    # [alpha, sx, sy, mirror, ...] list to all_ali_params.
    all_ali_params = []
    for ii in range(num_ali):
        ali_params = []
        if options.verbose:
            ALPHA = []
            SX = []
            SY = []
            MIRROR = []
        if (xrng[0] == 0.0 and yrng[0] == 0.0):
            avet = ali2d_ras(class_data, randomize = True, ir = 1, ou = ou, rs = 1, step = 1.0, dst = 90.0, \
              maxit = options.maxit, check_mirror = True, FH=options.fl, FF=options.aa)
        else:
            avet = within_group_refinement(class_data, mask, True, 1, ou, 1, xrng, yrng, step, 90.0, \
              maxit = options.maxit, FH=options.fl, FF=options.aa, method = options.method)
        for im in class_data:
            alpha, sx, sy, mirror, scale = get_params2D(im)
            ali_params.extend([alpha, sx, sy, mirror])
            if options.verbose:
                ALPHA.append(alpha)
                SX.append(sx)
                SY.append(sy)
                MIRROR.append(mirror)
        all_ali_params.append(ali_params)
        if options.verbose:
            write_text_file([ALPHA, SX, SY, MIRROR],
                            "ali_params_run_%d" % ii)

    stable_set, mir_stab_rate, pix_err = multi_align_stability(
        all_ali_params, 0.0, 10000.0, options.thld_err, options.verbose,
        2 * ou + 1)
    sxprint("%4s %20s %20s %20s %30s %6.2f" %
            ("", "Size of set", "Size of stable set", "Mirror stab rate",
             "Pixel error prior to pruning the set above threshold of",
             options.thld_err))
    sxprint("Average stat: %10d %20d %20.2f   %15.2f" %
            (len(class_data), len(stable_set), mir_stab_rate, pix_err))

    if len(stable_set) > 0:
        # Each stable_set entry is used as
        # [pixel error, image index, [alpha, sx, sy, mirror]].
        if options.stables:
            stab_mem = [[int(s[1]), s[0], j] for j, s in enumerate(stable_set)]
            write_text_row(stab_mem, "stable_particles.txt")

        stable_set_id = [s[1] for s in stable_set]
        particle_pixerr = [s[0] for s in stable_set]

        avet.to_zero()
        sxprint("average parameters:  angle, x-shift, y-shift, mirror")
        for l, j in enumerate(stable_set_id):
            params = stable_set[l][2]
            sxprint(" %4d  %4d  %12.2f %12.2f %12.2f        %1d" %
                    (l, j, params[0], params[1], params[2], int(params[3])))
            avet += rot_shift2D(class_data[j], params[0], params[1],
                                params[2], params[3])
        avet /= len(stable_set_id)
        avet.set_attr('members', stable_set_id)
        avet.set_attr('pix_err', pix_err)
        avet.set_attr('pixerr', particle_pixerr)
        # A single positional argument is accepted above, so only write the
        # average when an output name was actually given. Previously this
        # raised IndexError on args[1].
        if len(args) > 1:
            avet.write_image(args[1])

    sp_global_def.BATCH = False