Beispiel #1
0
def filterlocal(ui, vi, m, falloff, myid, main_node, number_of_proc):
    """Filter a volume locally, driven by a local-resolution map (MPI version).

    The main node owns the input data: it rounds every value of ``ui`` to
    two decimals (in place) and broadcasts ``falloff``, the volume size and
    both volumes to all other ranks.  Every rank then sweeps the cutoff in
    steps of 0.01 and, for each cutoff that occurs inside the mask, copies
    the tanl-filtered density into those masked voxels whose rounded
    resolution equals the cutoff.  The z-slices are interleaved between the
    ranks and the partial results are reduced onto the main node.

    Args:
        ui: local-resolution volume; only needs to be valid on the main
            node, other ranks get a blank volume that is filled by broadcast.
        vi: volume to be filtered; same ownership rule as ``ui``.  It is
            Fourier transformed in place.
        m: mask volume, required on every rank; voxels with a value above
            0.5 are processed.
        falloff: falloff handed to ``filt_tanl``; the main node's value is
            used everywhere.
        myid: rank of the calling process.
        main_node: rank that owns the input volumes.
        number_of_proc: number of ranks sharing the z-slices.

    Returns:
        The filtered volume.  It is passed through
        ``reduce_EMData_to_root``, so presumably only the copy on the main
        node is complete -- confirm against that helper.
    """
    if myid == main_node:
        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()
        #  Round all resolution numbers to two digits
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
        dis = [nx, ny, nz]
    else:
        # Placeholders only; the real values arrive with the broadcasts below.
        falloff = 0.0
        dis = [0, 0, 0]
    falloff = sp_utilities.bcast_number_to_all(falloff, main_node)
    dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

    if myid != main_node:
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])

        # Receive buffers for the volumes owned by the main node.
        vi = sp_utilities.model_blank(nx, ny, nz)
        ui = sp_utilities.model_blank(nx, ny, nz)

    sp_utilities.bcast_EMData_to_all(vi, myid, main_node)
    sp_utilities.bcast_EMData_to_all(ui, myid, main_node)

    sp_fundamentals.fftip(vi)  #  volume to be filtered

    # st[2] and st[3] bound the cutoff sweep (presumably the minimum and
    # maximum of ui inside the mask -- confirm against Util.infomask).
    st = EMAN2_cppwrap.Util.infomask(ui, m, True)

    filteredvol = sp_utilities.model_blank(nx, ny, nz)
    cutoff = max(st[2] - 0.01, 0.0)
    while cutoff < st[3]:
        # Re-round on every step so that the equality test below compares
        # identically rounded numbers.
        cutoff = round(cutoff + 0.01, 2)
        pt = EMAN2_cppwrap.Util.infomask(
            sp_morphology.threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005),
            m,
            True,
        )  # Ideally, one would want to check only slices in question...
        if pt[0] != 0.0:
            # Filter only when this cutoff is actually present in the mask.
            vovo = sp_fundamentals.fft(filt_tanl(vi, cutoff, falloff))
            # Each rank handles the slices myid, myid + number_of_proc, ...
            for z in range(myid, nz, number_of_proc):
                for x in range(nx):
                    for y in range(ny):
                        if m.get_value_at(x, y, z) > 0.5:
                            if round(ui.get_value_at(x, y, z), 2) == cutoff:
                                filteredvol.set_value_at_fast(
                                    x, y, z, vovo.get_value_at(x, y, z)
                                )

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    sp_utilities.reduce_EMData_to_root(filteredvol, myid, main_node, mpi.MPI_COMM_WORLD)
    return filteredvol
Beispiel #2
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """Align the segments of helical filaments along x under MPI.

    The main node reads the "filament" and "ptcl_source_coord" attributes
    of ``stack``, groups the segments into filaments and broadcasts the
    grouping; the filaments are then distributed over the ranks.  In each
    iteration the shifted segments are averaged over all ranks, every
    segment is cross-correlated with that average, and
    ``Util.helixshiftali`` returns one x shift and one incline per
    filament, from which a per-segment x shift is derived.  At the end the
    shifts are combined with the ``xform.align2d`` transforms found in the
    headers and the headers are written back through the main node.

    Side effect: the main node writes the current average to "tavg.hdf"
    in every iteration.

    Args:
        stack: name of the image stack; its images must carry the
            "filament", "ptcl_source_coord" and "xform.align2d" header
            attributes (plus "ctf" and "ctf_applied" when ``CTF`` is set).
        maskfile: name of a mask image; when None a band of width
            ``2 * (min(nx, ny) // 2 - 2) + 1`` and full height, padded to
            ``nx`` x ``ny``, is used.
        maxit: number of iterations.
        CTF: if true, images are multiplied by their CTF (unless already
            applied) and the average is divided by the summed squared CTF
            plus ``1 / snr``.
        snr: signal-to-noise ratio used in the CTF-weighted average.
        Fourvar: if true, the average is divided by the Fourier variance
            returned by ``varf2d_MPI``.
        search_rng: shift search range handed to ``Util.helixshiftali``;
            it also enters the maximum incline.
    """
    # Must be imported before first use: a function-level import makes the
    # name local to the whole function, so importing it only near the end
    # turns any earlier call into an UnboundLocalError.
    from sp_utilities import file_type

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        # Flatten the filaments into one list plus a list of lengths so
        # that they can be broadcast as plain lists.
        inidl = [len(filament) for filament in filaments]
        linidl = sum(inidl)
        nima = linidl  # only defined (and only used) on the main node
        tfilaments = []
        for filament in filaments:
            tfilaments += filament
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    # Rebuild the per-filament lists from the flattened broadcast.
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    #  balanced load
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # indcs[i] = [first, last + 1] positions of filament i within the
    # locally loaded data.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    # Subtract the masked statistic st[0] from every image and move it to
    # Fourier space (CTF-multiplied when requested).
    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf != 1:
                ERROR('Incorrectly set qctf flag', myid=myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to the real part of every element of the summed
            # squared CTF (even x index = real, odd = imaginary part).
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                vav_r = Util.pack_complex_to_real(vav)
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)

        if Fourvar:
            del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            # Calculate 1D ccf between each segment and the global average
            nsegms = indcs[ifil][1] - indcs[ifil][0]
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                ctx[im - indcs[ifil][0]] = Util.window(ccf(tavg, data[im]), nx,
                                                       1)
                pcoords[im - indcs[ifil][0]] = data[im].get_attr(
                    'ptcl_source_coord')
            # search for best x-shift
            cents = nsegms // 2

            # Larger of the two distances from the central segment to the
            # ends of the filament.
            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # C code for alignment.  It replaces a pure-Python search (now
            # removed) that scanned all shifts in [-search_rng, search_rng]
            # and kang + 1 inclines of either sign, keeping the combination
            # with the largest sum of interpolated ccf values.
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])  # best shift of the central segment
            bang = results[1]  # tangent of the best incline
            qm = results[2]  # value of the best match

            # Shift of a segment = shift of the central segment plus or
            # minus its distance from the central segment times the incline.
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                kim = im - indcs[ifil][0]
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if kim < cents:
                    xl = -dst * bang + sib
                else:
                    xl = dst * bang + sib
                shift_x[im] = xl

            # Average shift
            sx_sum += shift_x[indcs[ifil][0] + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM,
                                main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): this unconditional reset discards the average that
        # was just computed, so the subtraction below is a no-op and the
        # average shift is only reported.  It looks like a deliberate
        # switch-off of the centering and is kept as is -- confirm.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine t0 and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if file_type(stack) == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node:
        print_end_msg("helical-shiftali_MPI")
Beispiel #3
0
def main():
    """Command-line driver: filter a volume locally by its resolution map.

    Positional arguments are ``inputvolume locresvolume [maskfile]
    outputfile``.  Without a mask file a sphere of ``--radius`` is used as
    the mask.  With ``--MPI`` the work is delegated to
    ``sp_filter.filterlocal`` and the main node writes the result;
    otherwise the filtering is done in this process, which assumes a cubic
    volume (only the x size is read).
    """
    arglist = list(sys.argv)
    progname = optparse.os.path.basename(arglist[0])
    usage = progname + """ inputvolume  locresvolume maskfile outputfile   --radius --falloff  --MPI

	    Locally filer a volume based on local resolution volume (sxlocres.py) within area outlined by the maskfile
	"""
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "if there is no maskfile, sphere with r=radius will be used, by default the radius is nx/2-1"
    )
    parser.add_option("--falloff",
                      type="float",
                      default=0.1,
                      help="falloff of tanl filter (default 0.1)")
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="use MPI version")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        sp_global_def.sxprint("See usage " + usage)
        sp_global_def.ERROR(
            "Wrong number of parameters. Please see usage information above.")
        return

    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    if options.MPI:
        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        main_node = 0

        if myid == main_node:
            # Use the parsed positionals, not sys.argv: an option given
            # before the file names would otherwise be taken for a file.
            vi = sp_utilities.get_im(args[0])
            ui = sp_utilities.get_im(args[1])
            radius = options.radius
            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            # Placeholders; size and radius arrive with the broadcasts and
            # filterlocal supplies the volumes on these ranks.
            radius = 0
            dis = [0, 0, 0]
            vi = None
            ui = None
        dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])
        radius = sp_utilities.bcast_number_to_all(radius, main_node)
        if len(args) == 3:
            if radius == -1:
                radius = min(nx, ny, nz) // 2 - 1
            m = sp_utilities.model_circle(radius, nx, ny, nz)
            outvol = args[2]

        elif len(args) == 4:
            if myid == main_node:
                m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            else:
                m = sp_utilities.model_blank(nx, ny, nz)
            outvol = args[3]
            sp_utilities.bcast_EMData_to_all(m, myid, main_node)

        filteredvol = sp_filter.filterlocal(ui, vi, m, options.falloff, myid,
                                            main_node, number_of_proc)

        if myid == main_node:
            filteredvol.write_image(outvol)

    else:
        vi = sp_utilities.get_im(args[0])
        ui = sp_utilities.get_im(
            args[1]
        )  # resolution volume, values are assumed to be from 0 to 0.5

        nn = vi.get_xsize()

        falloff = options.falloff

        if len(args) == 3:
            radius = options.radius
            if radius == -1:
                radius = nn // 2 - 1
            m = sp_utilities.model_circle(radius, nn, nn, nn)
            outvol = args[2]

        elif len(args) == 4:
            m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            outvol = args[3]

        sp_fundamentals.fftip(vi)  # this is the volume to be filtered

        #  Round all resolution numbers to two digits
        for x in range(nn):
            for y in range(nn):
                for z in range(nn):
                    ui.set_value_at_fast(x, y, z,
                                         round(ui.get_value_at(x, y, z), 2))
        # st[2] and st[3] bound the cutoff sweep (presumably the minimum
        # and maximum of ui inside the mask -- confirm against infomask).
        st = EMAN2_cppwrap.Util.infomask(ui, m, True)

        filteredvol = sp_utilities.model_blank(nn, nn, nn)
        cutoff = max(st[2] - 0.01, 0.0)
        while cutoff < st[3]:
            cutoff = round(cutoff + 0.01, 2)
            pt = EMAN2_cppwrap.Util.infomask(
                sp_morphology.threshold_outside(ui, cutoff - 0.00501,
                                                cutoff + 0.005), m, True)
            if pt[0] != 0.0:
                # Filter only when this cutoff is present inside the mask,
                # then copy the voxels that carry exactly this resolution.
                vovo = sp_fundamentals.fft(
                    sp_filter.filt_tanl(vi, cutoff, falloff))
                for x in range(nn):
                    for y in range(nn):
                        for z in range(nn):
                            if m.get_value_at(x, y, z) > 0.5:
                                if round(ui.get_value_at(x, y, z),
                                         2) == cutoff:
                                    filteredvol.set_value_at_fast(
                                        x, y, z, vovo.get_value_at(x, y, z))

        sp_global_def.write_command(optparse.os.path.dirname(outvol))
        filteredvol.write_image(outvol)
Beispiel #4
0
def filterlocal(ui, vi, m, falloff, myid, main_node, number_of_proc):
    """Filter a volume locally, driven by a local-resolution map (MPI version).

    The main node rounds every value of ``ui`` to two decimals (in place)
    and broadcasts ``falloff``, the volume size and both volumes.  Every
    rank then sweeps the cutoff in steps of 0.01 and, for each cutoff that
    occurs inside the mask, copies the tanl-filtered density into those
    masked voxels whose rounded resolution equals the cutoff.  The z-slices
    are interleaved between the ranks and the partial results are reduced
    onto the main node.

    Args:
        ui: local-resolution volume; only needs to be valid on the main
            node, other ranks get a blank volume that is filled by broadcast.
        vi: volume to be filtered; same ownership rule as ``ui``.  It is
            Fourier transformed in place.
        m: mask volume, required on every rank; voxels with a value above
            0.5 are processed.
        falloff: falloff handed to ``filt_tanl``; the main node's value is
            used everywhere.
        myid: rank of the calling process.
        main_node: rank that owns the input volumes.
        number_of_proc: number of ranks sharing the z-slices.

    Returns:
        The filtered volume.  It is passed through
        ``reduce_EMData_to_root``, so presumably only the copy on the main
        node is complete -- confirm against that helper.
    """
    from mpi import mpi_barrier, MPI_COMM_WORLD
    from sp_utilities import bcast_number_to_all, bcast_list_to_all, model_blank, bcast_EMData_to_all, reduce_EMData_to_root
    from sp_morphology import threshold_outside
    from sp_filter import filt_tanl
    from sp_fundamentals import fft, fftip

    if myid == main_node:
        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()
        #  Round all resolution numbers to two digits
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
        dis = [nx, ny, nz]
    else:
        # Placeholders only; the real values arrive with the broadcasts below.
        falloff = 0.0
        dis = [0, 0, 0]
    falloff = bcast_number_to_all(falloff, main_node)
    dis = bcast_list_to_all(dis, myid, source_node=main_node)

    if myid != main_node:
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])

        # Receive buffers for the volumes owned by the main node.
        vi = model_blank(nx, ny, nz)
        ui = model_blank(nx, ny, nz)

    bcast_EMData_to_all(vi, myid, main_node)
    bcast_EMData_to_all(ui, myid, main_node)

    fftip(vi)  #  volume to be filtered

    # st[2] and st[3] bound the cutoff sweep (presumably the minimum and
    # maximum of ui inside the mask -- confirm against Util.infomask).
    st = Util.infomask(ui, m, True)

    filteredvol = model_blank(nx, ny, nz)
    cutoff = max(st[2] - 0.01, 0.0)
    while cutoff < st[3]:
        # Re-round on every step so that the equality test below compares
        # identically rounded numbers.
        cutoff = round(cutoff + 0.01, 2)
        pt = Util.infomask(threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005), m, True)  # Ideally, one would want to check only slices in question...
        if pt[0] != 0.0:
            # Filter only when this cutoff is actually present in the mask.
            vovo = fft(filt_tanl(vi, cutoff, falloff))
            # Each rank handles the slices myid, myid + number_of_proc, ...
            for z in range(myid, nz, number_of_proc):
                for x in range(nx):
                    for y in range(ny):
                        if m.get_value_at(x, y, z) > 0.5:
                            if round(ui.get_value_at(x, y, z), 2) == cutoff:
                                filteredvol.set_value_at_fast(x, y, z, vovo.get_value_at(x, y, z))

    mpi_barrier(MPI_COMM_WORLD)
    reduce_EMData_to_root(filteredvol, myid, main_node, MPI_COMM_WORLD)
    return filteredvol
Beispiel #5
0
def resample(prjfile, outdir, bufprefix, nbufvol, nvol, seedbase,
             delta, d, snr, CTF, npad,
             MPI, myid, ncpu, verbose=0):
    """Reconstruct volumes from stratified random subsets of projections.

    Every projection is assigned to the reference direction (taken from
    ``even_angles(delta)``) that its projection direction matches best.
    In each iteration ``nbufvol`` selections are drawn, each keeping
    ``keep`` randomly chosen projections per reference direction, and the
    selections are passed on to ``resample_prepare``, ``resample_insert``
    and ``resample_finish``, which write to ``bsvol<myid>.hdf`` in
    ``outdir``.

    Args:
        prjfile: projection stack with "xform.projection" attributes.
        outdir: output directory; it must not exist yet and is created.
        bufprefix: passed through to ``resample_insert``.
        nbufvol: number of selections handled per iteration.
        nvol: total number of volumes; each rank runs
            ``nvol // ncpu // nbufvol`` iterations.
        seedbase: random seed; values below 1 request a non-reproducible
            seed.
        delta: angular step of the reference directions.
        d: fraction of the smallest stratum that is left out, i.e.
            ``keep = int(am * (1 - d) + 0.5)``.
        snr, CTF, npad: passed through to the reconstruction helpers.
        MPI: true when running under MPI.
        myid: rank of this process.
        ncpu: number of processes.
        verbose: when 1, progress is written to ``progress<myid>.txt``.
    """
    import os
    import random
    from sys import exit
    from numpy import array, reshape, zeros
    from sp_utilities import even_angles, getvec

    nprj = EMUtil.get_image_count(prjfile)

    if MPI:
        # Only rank 0 looks at the file system; its verdict is broadcast.
        if myid == 0 and os.path.exists(outdir):
            nx = 1
        else:
            nx = 0
        ny = bcast_number_to_all(nx, source_node=0)
        if ny == 1:
            ERROR('Output directory exists, please change the name and restart the program', "resample", 1, myid)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

        if myid == 0:
            os.makedirs(outdir)
            sp_global_def.write_command(outdir)
        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    else:
        if os.path.exists(outdir):
            ERROR('Output directory exists, please change the name and restart the program', "resample", 1, 0)
        os.makedirs(outdir)
        sp_global_def.write_command(outdir)
    if verbose == 1:
        finfo = open(os.path.join(outdir, "progress%04d.txt" % myid), "w")
    else:
        finfo = None

    # Unit vectors and theta angles of the reference directions.
    refa = even_angles(delta)
    nrefa = len(refa)
    refnormal = zeros((nrefa, 3), 'float32')

    tetref = [0.0] * nrefa
    for i in range(nrefa):
        tr = getvec(refa[i][0], refa[i][1])
        for j in range(3):
            refnormal[i][j] = tr[j]
        tetref[i] = refa[i][1]
    del refa

    # Rank 0 reads the projection directions; the others receive them.
    vct = array([0.0] * (3 * nprj), 'float32')
    tetprj = [0.0] * nprj
    if myid == 0:
        sxprint(" will read ", myid)
        tr = EMUtil.get_all_attributes(prjfile, 'xform.projection')
        for i in range(nprj):
            temp = tr[i].get_params("spider")
            tetprj[i] = temp["theta"]
            # Fold theta into [0, 90]; the sign of the direction is
            # ignored in the matching below as well (abs of dot product).
            if tetprj[i] > 90.0:
                tetprj[i] = 180.0 - tetprj[i]
            vct[3 * i + 0] = tr[i].at(2, 0)
            vct[3 * i + 1] = tr[i].at(2, 1)
            vct[3 * i + 2] = tr[i].at(2, 2)
        del tr
    if MPI:
        vct = mpi.mpi_bcast(vct, len(vct), mpi.MPI_FLOAT, 0, mpi.MPI_COMM_WORLD)
        from sp_utilities import bcast_list_to_all
        tetprj = bcast_list_to_all(tetprj, myid, 0)
    vct = reshape(vct, (nprj, 3))

    # Assign every projection to its best matching reference direction,
    # considering only references whose theta is within dspn.
    assignments = [[] for i in range(nrefa)]
    dspn = 1.25 * delta
    for k in range(nprj):
        best_s = -1.0
        best_i = -1
        for i in range(nrefa):
            if abs(tetprj[k] - tetref[i]) <= dspn:
                s = abs(refnormal[i][0] * vct[k][0] + refnormal[i][1] * vct[k][1] + refnormal[i][2] * vct[k][2])
                if s > best_s:
                    best_s = s
                    best_i = i
        # Append once per projection, after the search is complete, and
        # only if a reference was found: index -1 would silently put the
        # projection into the last stratum.
        if best_i >= 0:
            assignments[best_i].append(k)

    # am: size of the smallest stratum; mufur: sum of the reciprocal sizes
    # of all non-empty strata.
    am = len(assignments[0])
    mufur = 0.0
    for stratum in assignments:
        ti = len(stratum)
        am = min(am, ti)
        if ti > 0:
            mufur += 1.0 / ti

    del tetprj, tetref

    dp = 1.0 - d  # keep that many in each direction
    keep = int(am * dp + 0.5)
    mufur = keep * nrefa / (1.0 - mufur * keep / float(nrefa))
    if myid == 0:
        sxprint(" Number of projections ",nprj,".  Number of reference directions ",nrefa,",  multiplicative factor for the variance ",mufur)
        sxprint(" Minimum number of assignments ",am,"  Number of projections used per stratum ", keep," Number of projections in resampled structure ",int(am*dp +0.5)*nrefa)
    # am and keep are identical on all ranks, so all of them stop here;
    # stopping rank 0 alone would leave the others running.
    if am < 2 or am == keep:
        if myid == 0:
            sxprint("incorrect settings")
        exit()

    # Give every rank its own random stream.
    rank_offset = 17 * myid + 123
    if seedbase < 1:
        random.seed()
    else:
        random.seed(seedbase)
    jumpahead = getattr(random, "jumpahead", None)
    if jumpahead is not None:
        # Python 2: advance the generator state by a rank-specific amount.
        jumpahead(rank_offset)
    elif seedbase >= 1:
        # Python 3 has no jumpahead; derive a reproducible per-rank seed.
        random.seed("%s:%d" % (seedbase, rank_offset))

    volfile = os.path.join(outdir, "bsvol%04d.hdf" % myid)
    # Floor division keeps the iteration count an integer on Python 3.
    niter = nvol // ncpu // nbufvol
    try:
        for kiter in range(niter):
            if verbose == 1:
                finfo.write("Iteration %d: \n" % kiter)
                finfo.flush()

            iter_start = time()
            #  mults=1 means take given projection, mults=0 means omit

            mults = [[0] * nprj for i in range(nbufvol)]
            for i in range(nbufvol):
                for l in range(nrefa):
                    mass = assignments[l][:]
                    random.shuffle(mass)
                    for k in range(keep):
                        mults[i][mass[k]] = 1

            rectors, fftvols, wgtvols = resample_prepare(prjfile, nbufvol, snr, CTF, npad)
            resample_insert(bufprefix, fftvols, wgtvols, mults, CTF, npad, finfo)
            del mults
            resample_finish(rectors, fftvols, wgtvols, volfile, kiter, nprj, finfo)
            # Drop the references so the buffers can be freed before the
            # next iteration allocates new ones.
            rectors = None
            fftvols = None
            wgtvols = None
            if verbose == 1:
                finfo.write("time for iteration: %10.3f\n" % (time() - iter_start))
                finfo.flush()
    finally:
        if finfo is not None:
            finfo.close()
Beispiel #6
0
def main():
    """Command-line driver: compute the local resolution of a pair of volumes.

    Positional arguments are ``firstvolume secondvolume [maskfile]
    directory``.  Without a mask file a sphere of radius
    ``(size - wn) // 2`` is used as the mask.  With ``--MPI`` the work is
    delegated to ``sp_statistics.locres``; otherwise a windowed real-space
    correlation between the two band-pass filtered volumes is computed
    here shell by shell, and every masked voxel gets the frequency of the
    first shell in which that correlation drops below ``--cutoff``.  The
    serial path assumes a cubic volume (only the x size is read).  Results
    are written by ``output_volume`` into the output directory, which must
    not exist yet.
    """
    arglist = list(sys.argv)
    progname = os.path.basename(arglist[0])
    usage = progname + """ firstvolume  secondvolume  maskfile  directory  --prefix  --wn  --step  --cutoff  --radius  --fsc  --res_overall  --out_ang_res  --apix  --MPI

	Compute local resolution in real space within area outlined by the maskfile and within regions wn x wn x wn
	"""
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)

    parser.add_option("--prefix",
                      type="str",
                      default='localres',
                      help="Prefix for the output files. (default localres)")
    parser.add_option(
        "--wn",
        type="int",
        default=7,
        help=
        "Size of window within which local real-space FSC is computed. (default 7)"
    )
    parser.add_option(
        "--step",
        type="float",
        default=1.0,
        help="Shell step in Fourier size in pixels. (default 1.0)")
    parser.add_option("--cutoff",
                      type="float",
                      default=0.143,
                      help="Resolution cut-off for FSC. (default 0.143)")
    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help=
        "If there is no maskfile, sphere with r=radius will be used. By default, the radius is nx/2-wn (default -1)"
    )
    parser.add_option(
        "--fsc",
        type="string",
        default=None,
        help=
        "Save overall FSC curve (might be truncated). By default, the program does not save the FSC curve. (default none)"
    )
    parser.add_option(
        "--res_overall",
        type="float",
        default=-1.0,
        help=
        "Overall resolution at the cutoff level estimated by the user [abs units]. (default None)"
    )
    parser.add_option(
        "--out_ang_res",
        action="store_true",
        default=False,
        help=
        "Additionally creates a local resolution file in Angstroms. (default False)"
    )
    parser.add_option(
        "--apix",
        type="float",
        default=1.0,
        help=
        "Pixel size in Angstrom. Effective only with --out_ang_res options. (default 1.0)"
    )
    parser.add_option("--MPI",
                      action="store_true",
                      default=False,
                      help="Use MPI version.")

    (options, args) = parser.parse_args(arglist[1:])

    if len(args) < 3 or len(args) > 4:
        sxprint("Usage: " + usage)
        ERROR(
            "Invalid number of parameters used. Please see usage information above."
        )
        return

    if sp_global_def.CACHE_DISABLE:
        sp_utilities.disable_bdb_cache()

    res_overall = options.res_overall

    if options.MPI:

        number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
        myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
        main_node = 0
        sp_global_def.MPI = True
        cutoff = options.cutoff

        nk = int(options.wn)

        if myid == main_node:
            # Use the parsed positionals, not sys.argv: an option given
            # before the file names would otherwise be taken for a file.
            vi = sp_utilities.get_im(args[0])
            ui = sp_utilities.get_im(args[1])

            nx = vi.get_xsize()
            ny = vi.get_ysize()
            nz = vi.get_zsize()
            dis = [nx, ny, nz]
        else:
            # Placeholder of the same length as the list on the main node.
            dis = [0, 0, 0]

        sp_global_def.BATCH = True

        dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

        if myid != main_node:
            nx = int(dis[0])
            ny = int(dis[1])
            nz = int(dis[2])

            vi = sp_utilities.model_blank(nx, ny, nz)
            ui = sp_utilities.model_blank(nx, ny, nz)

        if len(args) == 3:
            m = sp_utilities.model_circle((min(nx, ny, nz) - nk) // 2, nx, ny,
                                          nz)
            outdir = args[2]

        elif len(args) == 4:
            if myid == main_node:
                m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            else:
                m = sp_utilities.model_blank(nx, ny, nz)
            outdir = args[3]
        # Only the main node touches the output directory; letting every
        # rank record the command could hit a directory that the main
        # node has not created yet.
        if myid == main_node:
            if os.path.exists(outdir):
                sp_global_def.ERROR('Output directory already exists!')
            else:
                os.makedirs(outdir)
            sp_global_def.write_command(outdir)
        sp_utilities.bcast_EMData_to_all(m, myid, main_node)

        freqvol, resolut = sp_statistics.locres(vi, ui, m, nk, cutoff,
                                                options.step, myid, main_node,
                                                number_of_proc)

        if myid == main_node:
            output_volume(freqvol, resolut, options.apix, outdir,
                          options.prefix, options.fsc, options.out_ang_res, nx,
                          ny, nz, res_overall)

    else:
        cutoff = options.cutoff
        vi = sp_utilities.get_im(args[0])
        ui = sp_utilities.get_im(args[1])

        nn = vi.get_xsize()
        nx = nn
        ny = nn
        nz = nn
        nk = int(options.wn)

        if len(args) == 3:
            m = sp_utilities.model_circle((nn - nk) // 2, nn, nn, nn)
            outdir = args[2]

        elif len(args) == 4:
            m = sp_morphology.binarize(sp_utilities.get_im(args[2]), 0.5)
            outdir = args[3]
        if os.path.exists(outdir):
            sp_global_def.ERROR('Output directory already exists!')
        else:
            os.makedirs(outdir)
        sp_global_def.write_command(outdir)

        # Complement of the mask: 1 outside, 0 inside.
        mc = sp_utilities.model_blank(nn, nn, nn, 1.0) - m

        vf = sp_fundamentals.fft(vi)
        uf = sp_fundamentals.fft(ui)

        # Number of shells and their width in absolute frequency units.
        lp = int(nn / 2 / options.step + 0.5)
        step = 0.5 / lp

        freqvol = sp_utilities.model_blank(nn, nn, nn)
        resolut = []
        for i in range(1, lp):
            fl = step * i
            fh = fl + step
            # Both volumes restricted to the shell [fl, fh], in real space.
            v = sp_fundamentals.fft(sp_filter.filt_tophatb(vf, fl, fh))
            u = sp_fundamentals.fft(sp_filter.filt_tophatb(uf, fl, fh))
            tmp1 = EMAN2_cppwrap.Util.muln_img(v, v)
            tmp2 = EMAN2_cppwrap.Util.muln_img(u, u)

            # Overall correlation of this shell within the mask.
            do = EMAN2_cppwrap.Util.infomask(
                sp_morphology.square_root(
                    sp_morphology.threshold(
                        EMAN2_cppwrap.Util.muln_img(tmp1, tmp2))), m, True)[0]

            tmp3 = EMAN2_cppwrap.Util.muln_img(u, v)
            dp = EMAN2_cppwrap.Util.infomask(tmp3, m, True)[0]
            resolut.append([i, (fl + fh) / 2.0, dp / do])

            # Local version: the same quantities after box convolution
            # with a window of size nk.
            tmp1 = EMAN2_cppwrap.Util.box_convolution(tmp1, nk)
            tmp2 = EMAN2_cppwrap.Util.box_convolution(tmp2, nk)
            tmp3 = EMAN2_cppwrap.Util.box_convolution(tmp3, nk)

            EMAN2_cppwrap.Util.mul_img(tmp1, tmp2)

            tmp1 = sp_morphology.square_root(sp_morphology.threshold(tmp1))

            # Force numerator and denominator to 1 outside the mask so the
            # division below is well defined there.
            EMAN2_cppwrap.Util.mul_img(tmp1, m)
            EMAN2_cppwrap.Util.add_img(tmp1, mc)

            EMAN2_cppwrap.Util.mul_img(tmp3, m)
            EMAN2_cppwrap.Util.add_img(tmp3, mc)

            EMAN2_cppwrap.Util.div_img(tmp3, tmp1)

            EMAN2_cppwrap.Util.mul_img(tmp3, m)
            freq = (fl + fh) / 2.0
            # Stop as soon as no masked voxel is left without a frequency.
            bailout = True
            for x in range(nn):
                for y in range(nn):
                    for z in range(nn):
                        if m.get_value_at(x, y, z) > 0.5:
                            if freqvol.get_value_at(x, y, z) == 0.0:
                                bailout = False
                                if tmp3.get_value_at(x, y, z) < cutoff:
                                    freqvol.set_value_at(x, y, z, freq)
            if bailout:
                break

        output_volume(freqvol, resolut, options.apix, outdir, options.prefix,
                      options.fsc, options.out_ang_res, nx, ny, nz,
                      res_overall)