Ejemplo n.º 1
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """Estimate per-segment x-shifts of helical filaments under MPI.

    The segments in ``stack`` are grouped into filaments on the main node,
    the grouping is broadcast, and the filaments are distributed over the MPI
    processes. In every iteration the globally reduced average of the
    (shifted) segments is cross-correlated with each segment, and
    ``Util.helixshiftali`` returns one shift and one incline per filament,
    from which an x-shift is derived for every segment of that filament.
    After the last iteration the shifts are composed with the
    ``xform.align2d`` transforms originally read from the headers, and the
    ``xform.align2d`` and ``ID`` attributes are sent to the main node.

    Args:
        stack: image stack; the code reads the header attributes
            ``filament``, ``ptcl_source_coord`` and ``xform.align2d`` (and
            ``ctf`` / ``ctf_applied`` when ``CTF`` is true).
        maskfile: file with the mask image; when ``None`` a mask is built
            from ``model_blank`` and ``pad``.
        maxit: number of iterations (converted with ``int``). The loop has
            no early-termination test, so all iterations are always run.
        CTF: when true, images are multiplied by their CTF (unless
            ``ctf_applied`` is already 1) and the average is divided by the
            summed squared CTF plus ``1/snr``.
        snr: value whose reciprocal is added to the summed squared CTF.
        Fourvar: when true, the average is divided by the variance returned
            by ``varf2d_MPI``.
        search_rng: passed unchanged to ``Util.helixshiftali`` and also used
            to derive the maximum incline.

    Side effects:
        On the main node ``tavg.hdf`` is written once per iteration, progress
        messages are printed, and the collected header attributes are handed
        to ``recv_attr_dict`` / ``recv_attr_dict_bdb``.
    """

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    # NOTE(review): ftp is not referenced again in this function.
    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    # Only the main node reads the filament membership from the headers; it
    # flattens the per-filament lists (tfilaments) and records their lengths
    # (inidl) so that both can be broadcast as plain lists.
    if (myid == main_node):
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [0] * total_nfils
        for i in range(total_nfils):
            inidl[i] = len(filaments[i])
        linidl = sum(inidl)
        # nima (total number of segments) exists on the main node only; it is
        # used below only inside main-node branches.
        nima = linidl
        tfilaments = []
        for i in range(total_nfils):
            tfilaments += filaments[i]
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    # The sizes are broadcast first so that the other ranks can allocate
    # placeholder lists of the right length for the list broadcasts.
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    # Rebuild the list of filaments on every rank from the flat list.
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    # Balanced load: chunks_distribution receives [length, index] pairs and
    # the entry for this rank is selected; the second element of each
    # returned pair is used as the filament index.
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # Flatten the local filaments into one list of image numbers; indcs[i]
    # holds the [start, end) range of filament i within that list (and hence
    # within data).
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile == None:
        # Default mask: a blank of size (2*mrad+1) x ny filled with 1.0 and
        # padded to nx x ny with 0.0 (presumably a centred vertical strip -
        # TODO confirm the placement done by pad).
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # NOTE(review): info is only referenced in commented-out debugging lines.
    from sp_utilities import info

    # Tag every image with its stack index, subtract the first statistic
    # returned by Util.infomask under the mask (presumably the mean - TODO
    # confirm), and move the image to Fourier space (with CTF applied when
    # requested and not yet applied).
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf != 1:
                ERROR('Incorrectly set qctf flag', myid=myid)
            # NOTE(review): when ctf_applied is already 1 the image is not
            # Fourier transformed in this branch - confirm that such images
            # are expected to arrive in Fourier space.
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
    if CTF:
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to the summed squared CTF: the value is written at
            # even x indices and 0.0 at odd x indices (presumably the real
            # and imaginary parts of the complex layout - TODO confirm).
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Local sum of the images shifted by the current estimates, then a
        # reduction of the sums onto the main node.
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF: tavg = Util.divn_filter(avg, ctf_2_sum)
            else: tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            # Placeholder that receives the broadcast average below.
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                # NOTE(review): vav_r is not used afterwards.
                vav_r = Util.pack_complex_to_real(vav)
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # The masked average of every iteration is stored at position
            # Iter of tavg.hdf.
            tavg.write_image("tavg.hdf", Iter)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)

        if Fourvar: del vav
        # Every rank receives the real-space average and transforms it
        # itself before correlating.
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            """
			# Calculate filament average
			avg = EMData(nx, ny, 1, False)
			filnima = 0
			for im in xrange(indcs[ifil][0], indcs[ifil][1]):
				Util.add_img(avg, data[im])
				filnima += 1
			tavg = Util.mult_scalar(avg, 1.0/float(filnima))
			"""
            # Calculate the ccf between the global average and each segment
            # of this filament, windowed to nx x 1 (presumably the central
            # row of the ccf - TODO confirm), and collect the segment
            # coordinates.
            nsegms = indcs[ifil][1] - indcs[ifil][0]
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                ctx[im - indcs[ifil][0]] = Util.window(ccf(tavg, data[im]), nx,
                                                       1)
                pcoords[im - indcs[ifil][0]] = data[im].get_attr(
                    'ptcl_source_coord')
            # search for best x-shift
            # cents is the index of the middle segment, which serves as the
            # reference point for distances along the filament.
            cents = nsegms // 2

            # dst: the larger of the distances from the middle segment to the
            # first and to the last segment.
            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            # Largest incline considered, and kang derived from it: for
            # dst > 0, dst * tan(maxincline) equals
            # ny // 2 - 2 - search_rng, so kang is that value rounded via
            # int(x + 0.5).
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # Compiled alignment routine (original note: "C code for
            # alignment. @ming"). The three returned values are used as the
            # shift (sib), the incline factor (bang) and qm (presumably the
            # best score - it is not used afterwards).
            results = [0.0] * 3
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]
            qm = results[2]

            # A commented-out pure-Python search used to follow here: it
            # looped over shifts in [-search_rng, search_rng] and over
            # kang + 1 inclines, summed linearly interpolated ccf values of
            # all segments for both signs of the incline, and kept the
            # shift/incline with the largest sum.

            # Shift of every segment: the common shift sib plus the incline
            # factor times the distance from the middle segment, with the
            # sign flipped for segments before the middle one.
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                kim = im - indcs[ifil][0]
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if (kim < cents): xl = -dst * bang + sib
                else: xl = dst * bang + sib
                shift_x[im] = xl

            # Average shift: accumulate the shift of the middle segment of
            # each filament.
            sx_sum += shift_x[indcs[ifil][0] + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM,
                                main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): sx_sum is reset to 0.0 on every rank (including the
        # main node) before the broadcast, so the average shift is only
        # printed and the subtraction below leaves shift_x unchanged.
        # Confirm whether disabling the re-centering is intended.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine the original transform (init_params) and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    # The main node collects the listed attributes from all ranks; every
    # other rank sends its own.
    if myid == main_node:
        from sp_utilities import file_type
        if (file_type(stack) == "bdb"):
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node: print_end_msg("helical-shiftali_MPI")
Ejemplo n.º 2
0
def shiftali_MPI(stack,
                 maskfile=None,
                 maxit=100,
                 CTF=False,
                 snr=1.0,
                 Fourvar=False,
                 search_rng=-1,
                 oneDx=False,
                 search_rng_y=-1):
    """Iteratively align the images of ``stack`` by integer shifts under MPI.

    Every rank reads a contiguous range of images. In each iteration the
    global average is formed on the main node and broadcast, every image is
    cross-correlated with it, the integer peak position is taken as the
    shift, the rounded global mean shift is subtracted, and the images are
    shifted in Fourier space. The loop ends after ``maxit`` iterations or as
    soon as no rank found a non-zero shift. Finally the accumulated shifts
    are composed with the ``xform.align2d`` transforms read from the headers
    and the ``xform.align2d`` and ``ID`` attributes are sent to the main
    node.

    Args:
        stack: image stack; ``xform.align2d`` (and ``ctf`` when ``CTF`` is
            true) are read from the headers.
        maskfile: file with the mask image; when ``None`` a circular mask of
            radius ``min(nx, ny) // 2 - 2`` is created with ``model_circle``.
        maxit: maximum number of iterations (converted with ``int``).
        CTF: when true, images are multiplied by their CTF and the average is
            divided by the summed squared CTF plus ``snr``; the run is
            rejected through ``ERROR`` when the first image reports
            ``ctf_applied`` > 0 or lacks the attribute (default 2).
        snr: value added to the summed squared CTF.
        Fourvar: when true, the average is divided by the variance returned
            by ``varf2d_MPI``.
        search_rng: when > 0 the ccf is windowed to ``2 * search_rng + 1``
            pixels in x, otherwise to ``nx``.
        oneDx: when true only x-shifts are searched (ccf windowed to one
            row) and y-shifts stay zero.
        search_rng_y: like ``search_rng`` but for y (not used when ``oneDx``
            is true).
    """

    number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    # The main node counts the images and broadcasts the count.
    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            # dummy is not used afterwards; presumably the call only opens
            # the database (second argument presumably read-only - TODO
            # confirm).
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    # Each rank keeps only its own [image_start, image_end) slice.
    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # read nx and ctf_app (if CTF) and broadcast to all nodes
    if myid == main_node:
        ima = EMData()
        # Third argument True: presumably header-only read - TODO confirm.
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF: ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF: ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        # Because the default above is 2, a missing ctf_applied attribute
        # also ends up here.
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", myid=myid)

    if maskfile == None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from sp_global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # The ranks take turns reading; for bdb stacks a barrier after each
        # turn keeps the reads strictly sequential.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb": mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    # Tag every image with its stack index, subtract the first statistic
    # returned by Util.infomask under the mask (presumably the mean - TODO
    # confirm) and accumulate the CTF sums.
    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
    else:
        # Redundant with the assignment made when the sums were set up.
        ctf_2_sum = None
    if CTF:
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add snr at even x indices to the summed squared CTF.
            # NOTE(review): helicalshiftali_MPI in this file adds 1/snr at
            # the corresponding place; the two agree only for snr == 1.0 -
            # confirm which is intended.
            temp = EMData(nx, ny, 1, False)
            for i in range(0, nx, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im],
                               p['alpha'],
                               sx=p['tx'],
                               sy=p['ty'],
                               mirror=p['mirror'],
                               scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum = 0
    sy_sum = 0
    sx_sum_total = 0
    sy_sum_total = 0
    # shift_x / shift_y accumulate the shifts applied over all iterations;
    # ishift_x / ishift_y hold the shifts found in the current iteration.
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Local sum of the (already shifted) images, reduced onto the main
        # node.
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            # Placeholder that receives the broadcast average below.
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                # NOTE(review): vav_r is not used afterwards.
                vav_r = Util.pack_complex_to_real(vav)

            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)
            # Back to Fourier space before broadcasting.
            tavg = fft(tavg)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        sx_sum = 0
        sy_sum = 0
        # Size of the ccf window that is searched for the peak.
        if search_rng > 0: nwx = 2 * search_rng + 1
        else: nwx = nx

        if search_rng_y > 0: nwy = 2 * search_rng_y + 1
        else: nwy = ny

        # not_zero becomes 1 as soon as one local image has a non-zero shift.
        not_zero = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                # Element 3 of the first peak is used as the x offset here
                # (presumably the 1D layout of peak_search - TODO confirm).
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                # Elements 4 and 5 of the first peak are used as the x and y
                # offsets (presumably the 2D layout of peak_search - TODO
                # confirm).
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if not_zero == 0:
                if (not (ishift_x[im] == 0.0)) or (not (ishift_y[im] == 0.0)):
                    not_zero = 1

        # Sum the shifts over all ranks on the main node, then broadcast the
        # totals so that every rank subtracts the same mean.
        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node,
                                mpi.MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi.mpi_reduce(sy_sum, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                    main_node, mpi.MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total,
                                               source_node=main_node)

        # Subtract the rounded mean shift so that the set of images does not
        # drift as a whole, then apply the remaining shift in Fourier space.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y
        # stop if all shifts are zero
        # (not_zero was set from the shifts found before the mean was
        # subtracted)
        not_zero = mpi.mpi_reduce(not_zero, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                  main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))
            start_time = time()

        if not_zero_all == 0: break

    # The images are left in Fourier space: only header information is used
    # from here on.
    # combine shifts found with the original parameters
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        # combine t0 and t1
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    # The main node collects the listed attributes from all ranks; every
    # other rank sends its own.
    if myid == main_node:
        from sp_utilities import file_type
        if (file_type(stack) == "bdb"):
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)

    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node: print_end_msg("shiftali_MPI")
Ejemplo n.º 3
0
def align2d_scf(image, refim, xrng=-1, yrng=-1, ou=-1):
    """Align ``image`` to ``refim`` using their self-correlation functions.

    The rotation angle and mirror flag are found by matching the
    self-correlation function (``scf``) of the image, and of its mirrored
    version, against polar rings computed from the scf of the reference.
    The translation is then found from the cross-correlation of the rotated
    image with the reference, tried at ``alpha`` and ``alpha + 180``, and
    refined with a parabolic fit (``parabl``) over the 3x3 neighbourhood of
    the best peak.

    Args:
        image: image to be aligned.
        refim: reference image.
        xrng: shift search range in x; the ccf window is
            ``2 * (xrng + 1) + 1`` pixels wide, capped by the image size.
        yrng: shift search range in y; a negative value means "same as
            ``xrng``".
        ou: outer ring radius; a negative value selects
            ``min(nx // 2, ny // 2) - 1``.

    Returns:
        Tuple ``(alpha, sx, sy, mirror, peak)`` where ``peak`` is the peak
        value returned by the parabolic fit.
    """
    nx = image.get_xsize()
    # Bug fix: the y dimension used to be read with get_xsize(), so for
    # non-square images the default radius, the ring centre (cny) and the
    # y search window (nry) were derived from the wrong size.
    ny = image.get_ysize()
    if ou < 0:
        ou = min(old_div(nx, 2) - 1, old_div(ny, 2) - 1)
    if yrng < 0:
        yrng = xrng
    if ou < 2:
        sp_global_def.ERROR("Radius of the object (ou) has to be given",
                            "align2d_scf", 1)
    sci = sp_fundamentals.scf(image)
    scr = sp_fundamentals.scf(refim)
    first_ring = 1

    # center in SPIDER convention
    cnx = old_div(nx, 2) + 1
    cny = old_div(ny, 2) + 1
    # precalculate the weighted reference rings
    numr = Numrinit(first_ring, ou, 1, "H")
    wr = ringwe(numr, "H")
    crefim = EMAN2_cppwrap.Util.Polar2Dm(scr, cnx, cny, numr, "H")
    EMAN2_cppwrap.Util.Frngs(crefim, numr)
    EMAN2_cppwrap.Util.Applyws(crefim, numr, wr)

    # Only the angle and the peak value of the two rotational searches are
    # needed; the shifts are determined separately below.
    alpha1, _, _, _, peak1 = ornq(sci, crefim, [0.0], [0.0], 1.0, "H", numr,
                                  cnx, cny)
    alpha2, _, _, _, peak2 = ornq(sp_fundamentals.mirror(sci), crefim, [0.0],
                                  [0.0], 1.0, "H", numr, cnx, cny)

    if peak1 > peak2:
        mirr = 0
        alpha = alpha1
    else:
        mirr = 1
        alpha = -alpha2

    nrx = min(2 * (xrng + 1) + 1, ((old_div((nx - 2), 2)) * 2 + 1))
    nry = min(2 * (yrng + 1) + 1, ((old_div((ny - 2), 2)) * 2 + 1))
    frotim = sp_fundamentals.fft(refim)

    # The angle is tried both as found and rotated by 180 degrees; the
    # orientation with the higher ccf peak wins.
    ccf1 = EMAN2_cppwrap.Util.window(
        sp_fundamentals.ccf(
            sp_fundamentals.rot_shift2D(image, alpha, 0.0, 0.0, mirr), frotim),
        nrx,
        nry,
    )
    p1 = sp_utilities.peak_search(ccf1)

    ccf2 = EMAN2_cppwrap.Util.window(
        sp_fundamentals.ccf(
            sp_fundamentals.rot_shift2D(image, alpha + 180.0, 0.0, 0.0, mirr),
            frotim),
        nrx,
        nry,
    )
    p2 = sp_utilities.peak_search(ccf2)

    if p1[0][0] > p2[0][0]:
        best_peak = p1[0]
        best_ccf = ccf1
    else:
        alpha += 180.0
        best_peak = p2[0]
        best_ccf = ccf2

    # Integer shifts and the pixel position of the peak inside the window.
    sxs = -best_peak[4]
    sys_int = -best_peak[5]
    cx = int(best_peak[1])
    cy = int(best_peak[2])

    # Sub-pixel refinement from the 3x3 neighbourhood of the peak.
    z = sp_utilities.model_blank(3, 3)
    for i in range(3):
        for j in range(3):
            z[i, j] = best_ccf[i + cx - 1, j + cy - 1]
    XSH, YSH, PEAKV = parabl(z)

    if mirr == 1:
        sx = -sxs + XSH
    else:
        sx = sxs - XSH
    return alpha, sx, sys_int - YSH, mirr, PEAKV
Ejemplo n.º 4
0
def multalign2d_scf(image, refrings, frotim, numr, xrng=-1, yrng=-1, ou=-1):
    """Align ``image`` to several references using self-correlation functions.

    For every reference the rotation angle and mirror flag are obtained by
    matching the polar rings of the image's self-correlation function (and
    of its mirrored version) against ``refrings[i]``; the translation is
    obtained from the cross-correlation with ``frotim[i]``, tried at
    ``alpha`` and ``alpha + 180`` and refined with a parabolic fit
    (``parabl``). The reference with the highest fitted peak wins.

    Args:
        image: image to be aligned.
        refrings: sequence of precomputed reference rings, one per reference.
        frotim: sequence of references used for the cross-correlation, in
            the same order as ``refrings``.
        numr: ring description shared by the image and the reference rings.
        xrng: shift search range in x; the ccf window is
            ``2 * (xrng + 1) + 1`` pixels wide, capped by the image size.
        yrng: shift search range in y; a negative value means "same as
            ``xrng``".
        ou: outer radius; a negative value selects
            ``min(nx // 2, ny // 2) - 1``. It is only validated here, the
            rings themselves are defined by ``numr``.

    Returns:
        Tuple ``(sx, sy, iref, alpha, mirror, peak)`` for the best-matching
        reference, ``iref`` being its index in ``refrings``.
    """
    nx = image.get_xsize()
    # Bug fix: the y dimension used to be read with get_xsize(), so for
    # non-square images the ring centre (cny) and the y search window (nry)
    # were derived from the wrong size.
    ny = image.get_ysize()
    if ou < 0:
        ou = min(old_div(nx, 2) - 1, old_div(ny, 2) - 1)
    if yrng < 0:
        yrng = xrng
    if ou < 2:
        # Bug fix: the error used to be attributed to "align2d_scf".
        sp_global_def.ERROR("Radius of the object (ou) has to be given",
                            "multalign2d_scf", 1)
    sci = sp_fundamentals.scf(image)
    # center in SPIDER convention
    cnx = old_div(nx, 2) + 1
    cny = old_div(ny, 2) + 1

    # Polar rings of the image scf and of its mirrored version are computed
    # once and reused for every reference.
    cimage = EMAN2_cppwrap.Util.Polar2Dm(sci, cnx, cny, numr, "H")
    EMAN2_cppwrap.Util.Frngs(cimage, numr)
    mimage = EMAN2_cppwrap.Util.Polar2Dm(sp_fundamentals.mirror(sci), cnx, cny,
                                         numr, "H")
    EMAN2_cppwrap.Util.Frngs(mimage, numr)

    nrx = min(2 * (xrng + 1) + 1, ((old_div((nx - 2), 2)) * 2 + 1))
    nry = min(2 * (yrng + 1) + 1, ((old_div((ny - 2), 2)) * 2 + 1))

    totpeak = -1.0e23

    for iki in range(len(refrings)):
        # Find the angle for the image as is and for its mirrored version.
        retvals = EMAN2_cppwrap.Util.Crosrng_e(refrings[iki], cimage, numr, 0,
                                               0.0)
        alpha1 = ang_n(retvals["tot"], "H", numr[-1])
        peak1 = retvals["qn"]
        retvals = EMAN2_cppwrap.Util.Crosrng_e(refrings[iki], mimage, numr, 0,
                                               0.0)
        alpha2 = ang_n(retvals["tot"], "H", numr[-1])
        peak2 = retvals["qn"]

        if peak1 > peak2:
            mirr = 0
            alpha = alpha1
        else:
            mirr = 1
            alpha = -alpha2

        # The angle is tried both as found and rotated by 180 degrees; the
        # orientation with the higher ccf peak wins.
        ccf1 = EMAN2_cppwrap.Util.window(
            sp_fundamentals.ccf(
                sp_fundamentals.rot_shift2D(image, alpha, 0.0, 0.0, mirr),
                frotim[iki]),
            nrx,
            nry,
        )
        p1 = sp_utilities.peak_search(ccf1)

        ccf2 = EMAN2_cppwrap.Util.window(
            sp_fundamentals.ccf(
                sp_fundamentals.rot_shift2D(image, alpha + 180.0, 0.0, 0.0,
                                            mirr),
                frotim[iki],
            ),
            nrx,
            nry,
        )
        p2 = sp_utilities.peak_search(ccf2)

        if p1[0][0] > p2[0][0]:
            best_peak = p1[0]
            best_ccf = ccf1
        else:
            alpha += 180.0
            best_peak = p2[0]
            best_ccf = ccf2

        # Integer shifts and the pixel position of the peak in the window.
        sxs = -best_peak[4]
        sys_int = -best_peak[5]
        cx = int(best_peak[1])
        cy = int(best_peak[2])

        # Sub-pixel refinement from the 3x3 neighbourhood of the peak.
        z = sp_utilities.model_blank(3, 3)
        for i in range(3):
            for j in range(3):
                z[i, j] = best_ccf[i + cx - 1, j + cy - 1]
        XSH, YSH, PEAKV = parabl(z)

        # Keep the parameters of the reference with the highest fitted peak.
        if PEAKV > totpeak:
            totpeak = PEAKV
            iref = iki
            if mirr == 1:
                sx = -sxs + XSH
            else:
                sx = sxs - XSH
            sy = sys_int - YSH
            talpha = alpha
            tmirr = mirr
    return sx, sy, iref, talpha, tmirr, totpeak
Ejemplo n.º 5
0
def align2d_direct3(input_images,
                    refim,
                    xrng=1,
                    yrng=1,
                    psimax=180,
                    psistep=1,
                    ou=-1,
                    CTF=None):
    """Align images to a reference by exhaustive search over rotated copies.

    The reference is rotated in steps of ``psistep``, each rotation also
    with an extra 180 degrees, masked with a circular mask, and stored in
    Fourier space together with its mirrored version. Every input image is
    cross-correlated with all of these templates; the template whose
    parabola-refined ccf peak is highest determines angle, shifts and mirror
    flag, which are then inverted with ``inverse_transform2``.

    Args:
        input_images: sequence of images to align; an empty sequence yields
            an empty result.
        refim: reference image.
        xrng: shift search range in x (ccf window of ``2 * xrng + 1``).
        yrng: shift search range in y (ccf window of ``2 * yrng + 1``).
        psimax: angular search range; ``psimax / psistep`` steps are taken
            on each side.
        psistep: angular step.
        ou: radius of the circular mask; a negative value selects
            ``nx // 2 - 1``.
        CTF: when truthy, each image is multiplied by its ``ctf`` header
            attribute before correlation.

    Returns:
        A list with one ``[angle, sx, sy, mirror, peak]`` entry per input
        image.
    """
    # Robustness: nothing to align; the original failed with IndexError on
    # input_images[0].
    if len(input_images) == 0:
        return []

    nx = input_images[0].get_xsize()
    if ou < 0:
        ou = old_div(nx, 2) - 1
    mask = sp_utilities.model_circle(ou, nx, nx)
    nk = int(old_div(psimax, psistep))
    nm = 2 * nk + 1
    nc = nk + 1

    # refs[2*i] holds the reference rotated by (i - nc) * psistep and
    # refs[2*i + 1] the same rotation plus 180 degrees; each entry is the
    # pair [fft(template), fft(mirrored template)].
    refs = [None] * nm * 2
    for i in range(nm):
        temp = sp_fundamentals.rot_shift2D(refim, (i - nc) * psistep) * mask
        refs[2 * i] = [
            sp_fundamentals.fft(temp),
            sp_fundamentals.fft(sp_fundamentals.mirror(temp)),
        ]
        temp = sp_fundamentals.rot_shift2D(refim,
                                           (i - nc) * psistep + 180.0) * mask
        refs[2 * i + 1] = [
            sp_fundamentals.fft(temp),
            sp_fundamentals.fft(sp_fundamentals.mirror(temp)),
        ]
    del temp

    results = []
    for image in input_images:
        if CTF:
            ims = sp_filter.filt_ctf(sp_fundamentals.fft(image),
                                     image.get_attr("ctf"))
        else:
            ims = sp_fundamentals.fft(image)
        # Best-so-far state for this image. Bug fix: mir used to be
        # initialised once before the loop over images, so an image for which
        # no peak was accepted inherited the mirror flag of the previous
        # image while all other parameters were reset.
        ama = -1.0e23
        bang = 0.0
        bsx = 0.0
        bsy = 0.0
        mir = 0
        for i in range(nm * 2):
            for mirror_flag in [0, 1]:
                c = sp_fundamentals.ccf(ims, refs[i][mirror_flag])
                w = EMAN2_cppwrap.Util.window(c, 2 * xrng + 1, 2 * yrng + 1)
                pp = sp_utilities.peak_search(w)[0]
                px = int(pp[4])
                py = int(pp[5])
                # Peaks reported as value 1.0 at zero offset are skipped, as
                # in the original (presumably the "no peak found" result of
                # peak_search - TODO confirm).
                if pp[0] == 1.0 and px == 0 and py == 0:
                    continue
                # Sub-pixel refinement from the 3x3 neighbourhood of the peak.
                ww = sp_utilities.model_blank(3, 3)
                ux = int(pp[1])
                uy = int(pp[2])
                for k in range(3):
                    for l in range(3):
                        ww[k, l] = w[k + ux - 1, l + uy - 1]
                XSH, YSH, PEAKV = parabl(ww)
                if PEAKV > ama:
                    ama = PEAKV
                    bsx = px + round(XSH, 2)
                    bsy = py + round(YSH, 2)
                    bang = i
                    mir = mirror_flag
        # Convert the template index back to an angle; returned parameters
        # have to be inverted.
        bang = (old_div(bang, 2) - nc) * psistep + 180.0 * (bang % 2)
        bang, bsx, bsy, _ = sp_utilities.inverse_transform2(
            bang, (1 - 2 * mir) * bsx, bsy, mir)
        results.append([bang, bsx, bsy, mir, ama])
    return results