def helicalshiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1):
    """Refine x-only shifts of helical segments, filament by filament, under MPI.

    Segments are grouped into filaments by ``ordersegments`` from the
    "filament" and "ptcl_source_coord" header attributes.  Whole filaments
    (never split) are distributed over the MPI ranks with
    ``chunks_distribution``.

    Each image first has its stored ``xform.align2d`` applied with
    ``rot_shift2D``.  Its masked mean is then subtracted and it is Fourier
    transformed; with CTF it is also CTF-multiplied unless ``ctf_applied``
    is already 1.

    Each of ``maxit`` iterations does the following:
      * The main node forms the average of all x-shifted segments.  With CTF
        it is filtered by (sum CTF^2 + 1/snr); otherwise it is divided by the
        image count.
      * The average is masked in real space and broadcast.
      * For every filament, a 1-D cross-correlation (``Util.window(ccf(...),
        nx, 1)``) of each segment with the average is passed to
        ``Util.helixshiftali``.  That routine returns an x offset ``sib`` and
        an incline ``bang``.
      * Each segment's shift is ``sib +/- dst * bang``, where ``dst`` is its
        in-micrograph distance to the filament's central segment.

    The final x shifts are composed with the initial transforms.  The
    "xform.align2d" and "ID" header attributes are then written back to
    ``stack``.

    Args:
        stack: Image stack name; read with EMData.read_images / EMUtil.
        maskfile: Mask image file.  If None, a band of ones
            2*(min(nx, ny)//2 - 2) + 1 pixels wide and ny tall is padded to
            nx x ny with zeros.
        maxit: Number of iterations (always all executed; there is no early
            stop).
        CTF: Use CTF information from the "ctf" header attribute.
        snr: Signal-to-noise ratio; 1/snr is added to sum CTF^2.
        Fourvar: Divide the average by the Fourier variance from varf2d_MPI.
        search_rng: Forwarded to Util.helixshiftali; also used to compute the
            maximum incline.

    Side effects:
        On the main node, the masked average of every iteration is written to
        "tavg.hdf" (image index = iteration number).
    """
    # Import file_type before its first use.  Any import inside a function
    # body makes the name local to the whole function, so the previous
    # import near the end made the first call below raise UnboundLocalError.
    from sp_utilities import file_type

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)

    # The main node orders the segments into filaments and flattens them.
    # The layout can then be broadcast as two plain lists: filament lengths
    # (inidl) and concatenated image IDs (tfilaments).
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        inidl = [len(fil) for fil in filaments]
        linidl = sum(inidl)
        nima = linidl  # only ever used on the main node
        tfilaments = []
        for fil in filaments:
            tfilaments += fil
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)

    # Rebuild the per-filament lists of image IDs on every rank.
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    # Balanced load: give each rank a set of whole filaments.
    temp = chunks_distribution([[len(filaments[i]), i] for i in range(len(filaments))], nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # Flatten this rank's filaments.  indcs[f] = [first, last + 1] is the
    # index range of filament f inside `data`.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # Apply the initial xform.align2d parameters stored in the header; they
    # are kept so the refined shift can be composed with them at the end.
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'], p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # Subtract the masked mean and move every image to Fourier space,
    # applying the CTF where that has not been done yet.
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf == 1:
                # The CTF is already applied, but the image must still be in
                # Fourier space like all the others.  The original left it in
                # real space, where it was later added to Fourier images.
                data[im] = fft(data[im])
            else:
                ERROR('Incorrectly set qctf flag', myid=myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Wiener regularisation: add 1/snr to the real part of every
            # Fourier coefficient of sum CTF^2 (imaginary parts get 0).
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1

        # Global average of all currently shifted segments (Fourier space).
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            # NOTE(review): here the main node holds a Fourier image while
            # the other ranks hold a real nx x ny blank.  Confirm that
            # bcast_EMData_to_all tolerates this size mismatch.
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)

        if Fourvar:
            del vav
        # tavg is real-space on every rank here; broadcast it, then go to
        # Fourier space everywhere for the cross-correlations.
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            first, last = indcs[ifil][0], indcs[ifil][1]
            nsegms = last - first

            # 1-D ccf between each segment and the global average.
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(first, last):
                ctx[im - first] = Util.window(ccf(tavg, data[im]), nx, 1)
                pcoords[im - first] = data[im].get_attr('ptcl_source_coord')

            # Search for the best x-shift / incline of the whole filament.
            cents = nsegms // 2
            # Distance from the central segment to the farther filament end.
            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 + (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 + (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # The C routine replaced a commented-out Python search over x
            # offsets and inclines.  Its result is used as
            # [x offset, incline (tangent), score]; the score is not needed.
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline, kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]

            # Segments before the centre move by -dst*bang, the rest by
            # +dst*bang, around the common offset sib.
            for im in range(first, last):
                kim = im - first
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 + (pcoords[cents][1] - pcoords[kim][1])**2)
                if kim < cents:
                    shift_x[im] = -dst * bang + sib
                else:
                    shift_x[im] = dst * bang + sib

            # Accumulate the central segment's shift for the average shift.
            sx_sum += shift_x[first + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM, main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): this unconditional reset discards the average computed
        # on the main node, so the re-centering below always subtracts 0.
        # It looks like a deliberate switch-off; confirm before changing it.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # Combine the shifts found with the original parameters.
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)

    # Write out headers; under MPI, writing has to be done sequentially.
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata, nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node:
        print_end_msg("helical-shiftali_MPI")
def shiftali_MPI(stack, maskfile=None, maxit=100, CTF=False, snr=1.0, Fourvar=False, search_rng=-1, oneDx=False, search_rng_y=-1):
    """Iteratively refine integer translations of all images in a stack, under MPI.

    Images are split evenly over the ranks with MPI_start_end.  Each image
    has its stored ``xform.align2d`` applied with ``rot_shift2D``.  Its
    masked mean is subtracted, it is Fourier transformed, and with CTF it is
    CTF-multiplied.

    Each iteration does the following:
      * The main node forms the average of all images.  With CTF it is
        filtered by ``ctf_2_sum``; otherwise it is divided by the image count.
      * The average is masked in real space and broadcast.
      * Each image's integer shift is the (negated) peak position of its ccf
        with the average, found within a ``nwx x nwy`` window (``nwx x 1``
        when ``oneDx``).
      * The rounded mean shift over all images is subtracted.
      * Images are Fourier-shifted by the remaining amount and the shift is
        accumulated.

    Iteration stops after ``maxit`` iterations, or when every image's raw
    peak shift in an iteration is zero.  The accumulated shifts are composed
    with the initial transforms.  The "xform.align2d" and "ID" header
    attributes are then written back to ``stack``.

    Args:
        stack: Image stack name (bdb stacks are opened first on the main node).
        maskfile: Mask image file.  If None, a circle of radius
            min(nx, ny)//2 - 2 is used.
        maxit: Maximum number of iterations.
        CTF: Use the "ctf" header attribute.  The data must not have
            ``ctf_applied`` > 0; this is reported through ERROR.
        snr: Added to the real part of ``ctf_2_sum`` as regularisation
            (see the note where it is used).
        Fourvar: Divide the average by the Fourier variance from varf2d_MPI.
        search_rng: Half-width of the x search window; <= 0 means full width.
        oneDx: Search x shifts only.
        search_rng_y: Half-height of the y search window; <= 0 means full
            height.
    """
    # Import file_type before its first use.  Any import inside a function
    # body makes the name local to the whole function, so the previous
    # import near the end made the first call below raise UnboundLocalError.
    from sp_utilities import file_type

    number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # Read nx, ny and ctf_applied (if CTF) on the main node; broadcast them.
    if myid == main_node:
        ima = EMData()
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF:
            ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF:
            ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", myid=myid)

    if maskfile is None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from sp_global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # Ranks read one after another; a barrier separates the reads
        # for bdb stacks.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb":
                mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # NOTE(review): two things to confirm here.
            # 1. This adds `snr` itself as the regularisation term.
            #    helicalshiftali_MPI adds 1/snr instead; the two agree only
            #    for snr == 1.0.
            # 2. The loop covers columns 0..nx-1 only.  If EMData(nx, ny, 1,
            #    False) allocates nx+2 floats per row (as helicalshiftali_MPI's
            #    range(0, nx + 2, 2) suggests), the last complex column gets
            #    no regularisation.
            temp = EMData(nx, ny, 1, False)
            for i in range(0, nx, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # Apply the initial xform.align2d parameters stored in the header; they
    # are kept so the refined shifts can be composed with them at the end.
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], sx=p['tx'], sy=p['ty'], mirror=p['mirror'], scale=p['scale'])

    # Fourier transform all images, and apply the CTF if CTF.
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum = 0
    sy_sum = 0
    sx_sum_total = 0
    sy_sum_total = 0  # stays 0 when oneDx (never updated in the loop)
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1

        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
            # Normalise and mask tavg in real space, then return it to
            # Fourier space.
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg = fft(tavg)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        sx_sum = 0
        sy_sum = 0
        nwx = 2 * search_rng + 1 if search_rng > 0 else nx
        nwy = 2 * search_rng_y + 1 if search_rng_y > 0 else ny

        # not_zero becomes 1 on this rank as soon as any image has a
        # non-zero integer shift in this iteration.
        not_zero = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if not_zero == 0:
                if ishift_x[im] != 0.0 or ishift_y[im] != 0.0:
                    not_zero = 1

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node, mpi.MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi.mpi_reduce(sy_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node, mpi.MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total, source_node=main_node)

        # Remove the (rounded) mean shift so the stack is not drifting as a whole.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y

        # Stop if all shifts are zero on every rank.
        not_zero = mpi.mpi_reduce(not_zero, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))

        if not_zero_all == 0:
            break

    # Combine the shifts found with the original parameters.  The images stay
    # in Fourier space; only header information is written.
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # Write out headers; under MPI, writing has to be done sequentially.
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start, image_end, number_of_proc)
    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("shiftali_MPI")
def align2d_scf(image, refim, xrng=-1, yrng=-1, ou=-1):
    """Align ``image`` to ``refim``: rotation and mirror from scf, shift from ccf.

    Procedure:
      1. The rotation and mirror are found by polar ("H" mode) ring alignment
         (``ornq``) of ``sp_fundamentals.scf`` of both images.  The unmirrored
         and mirrored scf of ``image`` are compared and the higher peak wins.
      2. ``image`` is rotated by ``alpha`` and by ``alpha + 180``.  Each
         result is cross-correlated with ``refim`` inside an ``nrx x nry``
         window, and the higher ccf peak decides the final angle.
      3. The shift is refined to sub-pixel accuracy with a parabolic fit
         (``parabl``) on the 3x3 neighbourhood of the ccf peak.

    Args:
        image: Image to align.
        refim: Reference image.
        xrng: Half-range of the x shift search.  Note that the default -1
            gives a 1-pixel-wide window (nrx == 1).
        yrng: Half-range of the y shift search; < 0 means "same as xrng".
        ou: Outer ring radius; < 0 means min(nx//2, ny//2) - 1.  Values
            below 2 are reported through sp_global_def.ERROR.

    Returns:
        (alpha, sx, sy, mirror, peak), where peak is the value returned by
        the parabolic fit.

    NOTE(review): the 3x3 neighbourhood lookup does not guard against a
    peak on the edge of the ccf window.
    """
    nx = image.get_xsize()
    # Previously get_xsize() was used here too, which is wrong for
    # non-square images.
    ny = image.get_ysize()
    if ou < 0:
        ou = min(old_div(nx, 2) - 1, old_div(ny, 2) - 1)
    if yrng < 0:
        yrng = xrng
    if ou < 2:
        sp_global_def.ERROR("Radius of the object (ou) has to be given", "align2d_scf", 1)
    sci = sp_fundamentals.scf(image)
    scr = sp_fundamentals.scf(refim)
    first_ring = 1

    # center in SPIDER convention
    cnx = old_div(nx, 2) + 1
    cny = old_div(ny, 2) + 1

    # Precalculate the weighted Fourier-transformed polar rings of the reference.
    numr = Numrinit(first_ring, ou, 1, "H")
    wr = ringwe(numr, "H")
    crefim = EMAN2_cppwrap.Util.Polar2Dm(scr, cnx, cny, numr, "H")
    EMAN2_cppwrap.Util.Frngs(crefim, numr)
    EMAN2_cppwrap.Util.Applyws(crefim, numr, wr)

    # Rotational search on the scf, without and with mirroring.  Only the
    # angle and peak of ornq's result are used here.
    alpha1, _, _, _, peak1 = ornq(sci, crefim, [0.0], [0.0], 1.0, "H", numr, cnx, cny)
    alpha2, _, _, _, peak2 = ornq(sp_fundamentals.mirror(sci), crefim, [0.0], [0.0], 1.0, "H", numr, cnx, cny)
    if peak1 > peak2:
        mirr = 0
        alpha = alpha1
    else:
        mirr = 1
        alpha = -alpha2

    # Window sizes for the translational search, capped at the image size.
    nrx = min(2 * (xrng + 1) + 1, ((old_div((nx - 2), 2)) * 2 + 1))
    nry = min(2 * (yrng + 1) + 1, ((old_div((ny - 2), 2)) * 2 + 1))
    frotim = sp_fundamentals.fft(refim)

    # The scf angle is ambiguous by 180 degrees; test both.
    ccf1 = EMAN2_cppwrap.Util.window(
        sp_fundamentals.ccf(sp_fundamentals.rot_shift2D(image, alpha, 0.0, 0.0, mirr), frotim),
        nrx, nry,
    )
    p1 = sp_utilities.peak_search(ccf1)
    ccf2 = EMAN2_cppwrap.Util.window(
        sp_fundamentals.ccf(sp_fundamentals.rot_shift2D(image, alpha + 180.0, 0.0, 0.0, mirr), frotim),
        nrx, nry,
    )
    p2 = sp_utilities.peak_search(ccf2)

    if p1[0][0] > p2[0][0]:
        best_ccf, best = ccf1, p1[0]
    else:
        alpha += 180.0
        best_ccf, best = ccf2, p2[0]
    sxs = -best[4]
    sys_ = -best[5]
    cx = int(best[1])
    cy = int(best[2])

    # Sub-pixel refinement on the 3x3 neighbourhood of the peak.
    z = sp_utilities.model_blank(3, 3)
    for i in range(3):
        for j in range(3):
            z[i, j] = best_ccf[i + cx - 1, j + cy - 1]
    XSH, YSH, PEAKV = parabl(z)

    if mirr == 1:
        sx = -sxs + XSH
    else:
        sx = sxs - XSH
    return alpha, sx, sys_ - YSH, mirr, PEAKV
def multalign2d_scf(image, refrings, frotim, numr, xrng=-1, yrng=-1, ou=-1):
    """Align ``image`` against several references; return the best match.

    For every reference ``iki``:
      * The rotation (and mirror) comes from ``Util.Crosrng_e`` between
        ``refrings[iki]`` and the polar rings of the scf of ``image`` (and of
        its mirror).
      * ``image`` is rotated by ``alpha`` and ``alpha + 180``.  Each result
        is cross-correlated with ``frotim[iki]`` inside an ``nrx x nry``
        window, and the shift is refined with a parabolic fit (``parabl``)
        on the 3x3 peak neighbourhood.

    The reference with the largest fitted peak wins.

    Args:
        image: Image to align.
        refrings: Polar rings of the references.  These are presumably
            prepared like ``cimage`` below (Polar2Dm + Frngs of the scf);
            confirm with the caller.
        frotim: Per-reference images used as the second ``ccf`` argument,
            presumably their FFTs.
        numr: Ring specification used for the polar rings.
        xrng: Half-range of the x shift search.
        yrng: Half-range of the y shift search; < 0 means "same as xrng".
        ou: Outer radius; < 0 means min(nx//2, ny//2) - 1.  Values below 2
            are reported through sp_global_def.ERROR.

    Returns:
        (sx, sy, iref, alpha, mirror, peak) for the best reference.

    NOTE(review): with an empty ``refrings`` the result variables are never
    set, and the return raises UnboundLocalError.
    """
    nx = image.get_xsize()
    # Previously get_xsize() was used here too, which is wrong for
    # non-square images.
    ny = image.get_ysize()
    if ou < 0:
        ou = min(old_div(nx, 2) - 1, old_div(ny, 2) - 1)
    if yrng < 0:
        yrng = xrng
    if ou < 2:
        sp_global_def.ERROR("Radius of the object (ou) has to be given", "multalign2d_scf", 1)
    sci = sp_fundamentals.scf(image)

    # center in SPIDER convention
    cnx = old_div(nx, 2) + 1
    cny = old_div(ny, 2) + 1

    # Fourier-transformed polar rings of the scf and of its mirror.
    cimage = EMAN2_cppwrap.Util.Polar2Dm(sci, cnx, cny, numr, "H")
    EMAN2_cppwrap.Util.Frngs(cimage, numr)
    mimage = EMAN2_cppwrap.Util.Polar2Dm(sp_fundamentals.mirror(sci), cnx, cny, numr, "H")
    EMAN2_cppwrap.Util.Frngs(mimage, numr)

    # Window sizes for the translational search, capped at the image size.
    nrx = min(2 * (xrng + 1) + 1, ((old_div((nx - 2), 2)) * 2 + 1))
    nry = min(2 * (yrng + 1) + 1, ((old_div((ny - 2), 2)) * 2 + 1))

    totpeak = -1.0e23
    for iki in range(len(refrings)):
        # Find angle (unmirrored vs mirrored).
        retvals = EMAN2_cppwrap.Util.Crosrng_e(refrings[iki], cimage, numr, 0, 0.0)
        alpha1 = ang_n(retvals["tot"], "H", numr[-1])
        peak1 = retvals["qn"]
        retvals = EMAN2_cppwrap.Util.Crosrng_e(refrings[iki], mimage, numr, 0, 0.0)
        alpha2 = ang_n(retvals["tot"], "H", numr[-1])
        peak2 = retvals["qn"]
        if peak1 > peak2:
            mirr = 0
            alpha = alpha1
        else:
            mirr = 1
            alpha = -alpha2

        # The scf angle is ambiguous by 180 degrees; test both.
        ccf1 = EMAN2_cppwrap.Util.window(
            sp_fundamentals.ccf(sp_fundamentals.rot_shift2D(image, alpha, 0.0, 0.0, mirr), frotim[iki]),
            nrx, nry,
        )
        p1 = sp_utilities.peak_search(ccf1)
        ccf2 = EMAN2_cppwrap.Util.window(
            sp_fundamentals.ccf(sp_fundamentals.rot_shift2D(image, alpha + 180.0, 0.0, 0.0, mirr), frotim[iki]),
            nrx, nry,
        )
        p2 = sp_utilities.peak_search(ccf2)

        if p1[0][0] > p2[0][0]:
            best_ccf, best = ccf1, p1[0]
        else:
            alpha += 180.0
            best_ccf, best = ccf2, p2[0]
        sxs = -best[4]
        sys_ = -best[5]
        cx = int(best[1])
        cy = int(best[2])

        # Sub-pixel refinement on the 3x3 neighbourhood of the peak.
        z = sp_utilities.model_blank(3, 3)
        for i in range(3):
            for j in range(3):
                z[i, j] = best_ccf[i + cx - 1, j + cy - 1]
        XSH, YSH, PEAKV = parabl(z)

        if PEAKV > totpeak:
            totpeak = PEAKV
            iref = iki
            if mirr == 1:
                sx = -sxs + XSH
            else:
                sx = sxs - XSH
            sy = sys_ - YSH
            talpha = alpha
            tmirr = mirr

    return sx, sy, iref, talpha, tmirr, totpeak
def align2d_direct3(input_images, refim, xrng=1, yrng=1, psimax=180, psistep=1, ou=-1, CTF=None):
    """Brute-force 2-D alignment of each image against rotated copies of ``refim``.

    References are built as follows:
      * ``refim`` is rotated by the in-plane angles
        ``(i - nc) * psistep`` for i in 0..2*nk, and also by each angle
        + 180, where nk = int(psimax / psistep).
      * Each rotated copy is masked with a circle of radius ``ou``.
      * Both the copy and its mirror are Fourier transformed.

    For each image (CTF-filtered first when ``CTF`` is truthy):
      * It is cross-correlated with every reference inside a
        ``(2*xrng+1) x (2*yrng+1)`` window.
      * The peak is refined with a parabolic fit (``parabl``) on its 3x3
        neighbourhood.
      * The best candidate is inverted with ``inverse_transform2`` to give
        the image's parameters.

    Args:
        input_images: Sequence of images (all assumed to be nx x nx, as
            taken from the first one).
        refim: Reference image.
        xrng, yrng: Half-sizes of the shift search window.
        psimax: Maximum in-plane angle searched on either side of 0 (plus
            the 180-degree-flipped copies).
        psistep: Angular step.
        ou: Mask radius; < 0 means nx//2 - 1.
        CTF: If truthy, each image is filtered with its "ctf" header
            attribute.

    Returns:
        A list with one ``[alpha, sx, sy, mirror, peak]`` per input image.

    NOTE(review): the 3x3 neighbourhood lookup does not guard against a
    peak on the edge of the search window.
    """
    nx = input_images[0].get_xsize()
    if ou < 0:
        ou = old_div(nx, 2) - 1
    mask = sp_utilities.model_circle(ou, nx, nx)
    nk = int(old_div(psimax, psistep))
    nm = 2 * nk + 1
    # Index of the zero rotation among the nm = 2*nk + 1 angles.  Using nk
    # makes the searched range symmetric (-nk..nk steps).  The previous
    # nk + 1 shifted it to -(nk+1)..(nk-1).  The decoding below uses the
    # same nc, so encoding and decoding stay consistent.
    nc = nk
    refs = [None] * nm * 2
    for i in range(nm):
        temp = sp_fundamentals.rot_shift2D(refim, (i - nc) * psistep) * mask
        refs[2 * i] = [
            sp_fundamentals.fft(temp),
            sp_fundamentals.fft(sp_fundamentals.mirror(temp)),
        ]
        temp = sp_fundamentals.rot_shift2D(refim, (i - nc) * psistep + 180.0) * mask
        refs[2 * i + 1] = [
            sp_fundamentals.fft(temp),
            sp_fundamentals.fft(sp_fundamentals.mirror(temp)),
        ]
    del temp

    results = []
    for image in input_images:
        if CTF:
            ims = sp_filter.filt_ctf(sp_fundamentals.fft(image), image.get_attr("ctf"))
        else:
            ims = sp_fundamentals.fft(image)
        ama = -1.0e23
        bang = 0.0
        bsx = 0.0
        bsy = 0.0
        mir = 0  # reset per image so nothing carries over from the previous one
        for i in range(nm * 2):
            for mirror_flag in [0, 1]:
                c = sp_fundamentals.ccf(ims, refs[i][mirror_flag])
                w = EMAN2_cppwrap.Util.window(c, 2 * xrng + 1, 2 * yrng + 1)
                pp = sp_utilities.peak_search(w)[0]
                px = int(pp[4])
                py = int(pp[5])
                if pp[0] == 1.0 and px == 0 and py == 0:
                    # Skip this candidate.  The previous `pass` fell through
                    # to the comparison with XSH/YSH/PEAKV left over from an
                    # earlier candidate.  For the very first candidate those
                    # names were unbound, which raised UnboundLocalError.
                    continue
                ww = sp_utilities.model_blank(3, 3)
                ux = int(pp[1])
                uy = int(pp[2])
                for k in range(3):
                    for m in range(3):
                        ww[k, m] = w[k + ux - 1, m + uy - 1]
                XSH, YSH, PEAKV = parabl(ww)
                if PEAKV > ama:
                    ama = PEAKV
                    bsx = px + round(XSH, 2)
                    bsy = py + round(YSH, 2)
                    bang = i
                    mir = mirror_flag
        # Decode the reference index into an angle.  Even indices are the
        # plain rotations; odd indices are the +180 copies.
        bang = (old_div(bang, 2) - nc) * psistep + 180.0 * (bang % 2)
        # The returned parameters have to be inverted.
        bang, bsx, bsy, _ = sp_utilities.inverse_transform2(bang, (1 - 2 * mir) * bsx, bsy, mir)
        results.append([bang, bsx, bsy, mir, ama])
    return results