Example #1
0
def filterlocal(ui, vi, m, falloff, myid, main_node, number_of_proc):
    """Filter a volume with a per-voxel cutoff taken from a local map.

    The main node rounds every value of ``ui`` to two decimals (in place) and
    broadcasts ``falloff``, the volume dimensions, ``vi`` and ``ui`` to all
    other processes.  The cutoff is then stepped in increments of 0.01 from
    ``st[2]`` up to ``st[3]``, where ``st = Util.infomask(ui, m, True)``
    (presumably the min and max of ``ui`` inside the mask -- TODO confirm).
    For every cutoff that occurs inside the mask, ``vi`` is filtered with
    ``filt_tanl(vi, cutoff, falloff)`` and the filtered values are copied into
    the output at the voxels whose rounded ``ui`` value equals that cutoff.

    Args:
        ui: map of per-voxel cutoff ("resolution") values; only needs to be
            valid on ``main_node``, it is replaced by a broadcast copy on the
            other processes.
        vi: volume to be filtered; same broadcasting rule as ``ui``.  It is
            Fourier transformed in place by ``fftip``.
        m: mask; voxels with a value > 0.5 are processed.  It is used on every
            process without being broadcast, so it must be available on all
            of them.
        falloff: fall-off passed to ``filt_tanl``; the value given on
            ``main_node`` is the one used everywhere.
        myid: rank of the calling process.
        main_node: rank that owns the input data and receives the result.
        number_of_proc: number of processes; z-slices are assigned to ranks
            in a round-robin fashion (``z = myid, myid + number_of_proc, ...``).

    Returns:
        The output volume.  Each process fills only its own z-slices before
        ``reduce_EMData_to_root`` is called; presumably only ``main_node``
        holds the complete result afterwards -- TODO confirm.
    """
    if myid == main_node:
        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()
        #  Round all resolution numbers to two digits
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
        dis = [nx, ny, nz]
    else:
        # Placeholders, overwritten by the broadcasts below.
        falloff = 0.0
        dis = [0, 0, 0]
    falloff = sp_utilities.bcast_number_to_all(falloff, main_node)
    dis = sp_utilities.bcast_list_to_all(dis, myid, source_node=main_node)

    if myid != main_node:
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])

        # Receive buffers of the right size for the broadcasts below.
        vi = sp_utilities.model_blank(nx, ny, nz)
        ui = sp_utilities.model_blank(nx, ny, nz)

    sp_utilities.bcast_EMData_to_all(vi, myid, main_node)
    sp_utilities.bcast_EMData_to_all(ui, myid, main_node)

    sp_fundamentals.fftip(vi)  #  volume to be filtered

    st = EMAN2_cppwrap.Util.infomask(ui, m, True)
    upper = st[3]

    filteredvol = sp_utilities.model_blank(nx, ny, nz)
    cutoff = max(st[2] - 0.01, 0.0)
    while cutoff < upper:
        cutoff = round(cutoff + 0.01, 2)
        # Is there any voxel inside the mask whose value falls into the
        # current 0.01-wide bin?  Ideally, one would want to check only the
        # slices owned by this process...
        pt = EMAN2_cppwrap.Util.infomask(
            sp_morphology.threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005),
            m,
            True,
        )
        if pt[0] != 0.0:
            # Filter the whole volume once for this cutoff, then pick the
            # voxels that belong to it from the slices owned by this rank.
            vovo = sp_fundamentals.fft(filt_tanl(vi, cutoff, falloff))
            for z in range(myid, nz, number_of_proc):
                for x in range(nx):
                    for y in range(ny):
                        if (
                            m.get_value_at(x, y, z) > 0.5
                            and round(ui.get_value_at(x, y, z), 2) == cutoff
                        ):
                            filteredvol.set_value_at_fast(
                                x, y, z, vovo.get_value_at(x, y, z)
                            )

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    sp_utilities.reduce_EMData_to_root(filteredvol, myid, main_node, mpi.MPI_COMM_WORLD)
    return filteredvol
Example #2
0
def helicalshiftali_MPI(stack,
                        maskfile=None,
                        maxit=100,
                        CTF=False,
                        snr=1.0,
                        Fourvar=False,
                        search_rng=-1):
    """Iteratively determine x-shifts of helical segments, filament by filament (MPI).

    Outline of what the code does:

    1. The main node reads the ``filament`` and ``ptcl_source_coord``
       attributes of all images in ``stack``, groups the segments into
       filaments with ``ordersegments`` and broadcasts the grouping.
    2. Whole filaments are distributed over the processes with
       ``chunks_distribution``; each process reads its own segments.
    3. The ``xform.align2d`` transform stored in each header is applied, the
       masked statistic ``st[0]`` (presumably the average -- TODO confirm) is
       subtracted and the images are Fourier transformed (and multiplied by
       the CTF if ``CTF`` is set and ``ctf_applied`` is 0).
    4. For ``maxit`` iterations a global average of the currently shifted
       images is computed, and for every filament ``Util.helixshiftali`` is
       used to find one x-shift and one incline for the whole filament from
       the 1D cross-correlations of its segments with the average.
    5. The resulting shift is combined with the initial transform and written
       back to the ``xform.align2d`` header attribute of the stack.

    Args:
        stack: name of the image stack (its type is queried with
            ``file_type``; "bdb" stacks use a dedicated header writer).
        maskfile: name of a mask image file.  If ``None`` a vertical strip of
            width ``2 * (min(nx, ny) // 2 - 2) + 1`` padded to ``nx`` is used.
        maxit: number of iterations (converted with ``int``).
        CTF: if true, use the ``ctf`` header attribute of every image.
        snr: only used with ``CTF``; ``1 / snr`` is added to the sum of
            squared CTFs that the average is divided by.
        Fourvar: if true, the average is additionally divided by the Fourier
            variance returned by ``varf2d_MPI``.
        search_rng: search range in x passed to ``Util.helixshiftali``.

    Returns:
        None.  Results are written to the headers of ``stack``; the main node
        also writes the per-iteration averages to ``tavg.hdf``.
    """
    # This import must come first: ``file_type`` is (re)imported inside this
    # function, which makes it a local name for the whole body.  With the
    # import at the end of the function, the call below raised
    # UnboundLocalError.
    from sp_utilities import file_type

    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("helical-shiftali_MPI")

    max_iter = int(maxit)
    if myid == main_node:
        infils = EMUtil.get_all_attributes(stack, "filament")
        ptlcoords = EMUtil.get_all_attributes(stack, 'ptcl_source_coord')
        filaments = ordersegments(infils, ptlcoords)
        total_nfils = len(filaments)
        # Number of segments per filament and the flattened list of segment
        # indices; together they allow the nested list to be broadcast.
        inidl = [len(fil) for fil in filaments]
        linidl = sum(inidl)
        nima = linidl
        tfilaments = []
        for i in range(total_nfils):
            tfilaments += filaments[i]
        del filaments
    else:
        total_nfils = 0
        linidl = 0
    total_nfils = bcast_number_to_all(total_nfils, source_node=main_node)
    if myid != main_node:
        inidl = [-1] * total_nfils
    inidl = bcast_list_to_all(inidl, myid, source_node=main_node)
    linidl = bcast_number_to_all(linidl, source_node=main_node)
    if myid != main_node:
        tfilaments = [-1] * linidl
    tfilaments = bcast_list_to_all(tfilaments, myid, source_node=main_node)
    # Rebuild the list of filaments (list of lists of segment indices).
    filaments = []
    iendi = 0
    for i in range(total_nfils):
        isti = iendi
        iendi = isti + inidl[i]
        filaments.append(tfilaments[isti:iendi])
    del tfilaments, inidl

    if myid == main_node:
        print_msg("total number of filaments: %d" % total_nfils)
    if total_nfils < nproc:
        ERROR(
            'number of CPUs (%i) is larger than the number of filaments (%i), please reduce the number of CPUs used'
            % (nproc, total_nfils),
            myid=myid)

    #  balanced load: keep only the filaments assigned to this process
    temp = chunks_distribution([[len(filaments[i]), i]
                                for i in range(len(filaments))],
                               nproc)[myid:myid + 1][0]
    filaments = [filaments[temp[i][1]] for i in range(len(temp))]
    nfils = len(filaments)

    # list_of_particles: all segment indices handled by this process;
    # indcs[ifil] = [first, last + 1] positions of filament ifil in that list.
    list_of_particles = []
    indcs = []
    k = 0
    for i in range(nfils):
        list_of_particles += filaments[i]
        k1 = k + len(filaments[i])
        indcs.append([k, k1])
        k = k1
    data = EMData.read_images(stack, list_of_particles)
    ldata = len(data)
    sxprint("ldata=", ldata)
    nx = data[0].get_xsize()
    ny = data[0].get_ysize()
    if maskfile is None:
        mrad = min(nx, ny) // 2 - 2
        mask = pad(model_blank(2 * mrad + 1, ny, 1, 1.0), nx, ny, 1, 0.0)
    else:
        mask = get_im(maskfile)

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(ldata):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im], p['alpha'], p['tx'], p['ty'],
                               p['mirror'], p['scale'])

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None
        ctf_abs_sum = None

    for im in range(ldata):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            qctf = data[im].get_attr("ctf_applied")
            if qctf == 0:
                data[im] = filt_ctf(fft(data[im]), ctf_params)
                data[im].set_attr('ctf_applied', 1)
            elif qctf != 1:
                ERROR('Incorrectly set qctf flag', myid=myid)
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)
        else:
            data[im] = fft(data[im])

    del list_of_particles

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # Add 1/snr to the even x-columns of the summed squared CTF and
            # zero to the odd ones (presumably real/imaginary parts of the
            # complex image layout -- TODO confirm).
            temp = EMData(nx, ny, 1, False)
            tsnr = 1. / snr
            for i in range(0, nx + 2, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, tsnr)
                    temp.set_value_at(i + 1, j, 0.0)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0
    shift_x = [0.0] * ldata

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        # Global average of all segments with their current shifts applied.
        avg = EMData(nx, ny, 1, False)
        for im in range(ldata):
            Util.add_img(avg, fshift(data[im], shift_x[im]))

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF: tavg = Util.divn_filter(avg, ctf_2_sum)
            else: tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = model_blank(nx, ny)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                vav_r = Util.pack_complex_to_real(vav)
            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            tavg.write_image("tavg.hdf", Iter)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)
        tavg = fft(tavg)

        sx_sum = 0.0
        nxc = nx // 2

        for ifil in range(nfils):
            # Calculate 1D ccf between each segment and the global average
            # (an earlier variant used a per-filament average instead).
            nsegms = indcs[ifil][1] - indcs[ifil][0]
            ctx = [None] * nsegms
            pcoords = [None] * nsegms
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                ctx[im - indcs[ifil][0]] = Util.window(ccf(tavg, data[im]), nx,
                                                       1)
                pcoords[im - indcs[ifil][0]] = data[im].get_attr(
                    'ptcl_source_coord')
            # search for best x-shift
            cents = nsegms // 2

            # Largest distance (in micrograph coordinates) from the central
            # segment to either end of the filament.
            dst = sqrt(
                max((pcoords[cents][0] - pcoords[0][0])**2 +
                    (pcoords[cents][1] - pcoords[0][1])**2,
                    (pcoords[cents][0] - pcoords[-1][0])**2 +
                    (pcoords[cents][1] - pcoords[-1][1])**2))
            maxincline = atan2(ny // 2 - 2 - float(search_rng), dst)
            kang = int(dst * tan(maxincline) + 0.5)

            # Search done in C code (@ming).  The result is used as
            # [x-shift of the central segment, tangent of the incline, score].
            # Per the pure-Python reference that used to accompany this call,
            # integer shifts in [-search_rng, search_rng] and kang + 1
            # inclines of either sign are scanned, maximising the sum of the
            # linearly interpolated ccf values over all segments.
            results = Util.helixshiftali(ctx, pcoords, nsegms, maxincline,
                                         kang, search_rng, nxc)
            sib = int(results[0])
            bang = results[1]
            qm = results[2]

            # Shift of each segment: the central shift plus/minus the incline
            # times the distance from the central segment.
            for im in range(indcs[ifil][0], indcs[ifil][1]):
                kim = im - indcs[ifil][0]
                dst = sqrt((pcoords[cents][0] - pcoords[kim][0])**2 +
                           (pcoords[cents][1] - pcoords[kim][1])**2)
                if (kim < cents): xl = -dst * bang + sib
                else: xl = dst * bang + sib
                shift_x[im] = xl

            # Average shift
            sx_sum += shift_x[indcs[ifil][0] + cents]

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_FLOAT, mpi.MPI_SUM,
                                main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            sx_sum = float(sx_sum[0]) / total_nfils
            print_msg("Average shift  %6.2f\n" % (sx_sum))
        else:
            sx_sum = 0.0
        # NOTE(review): the average computed above is discarded here, so the
        # subtraction below is a no-op and the shifts are never re-centred.
        # Kept as found because it may be deliberate -- confirm with the
        # authors before removing this line.
        sx_sum = 0.0
        sx_sum = bcast_number_to_all(sx_sum, source_node=main_node)
        for im in range(ldata):
            shift_x[im] -= sx_sum

    # combine shifts found with the original parameters
    for im in range(ldata):
        t1 = Transform()
        t1.set_params({"type": "2D", "tx": shift_x[im]})
        # combine t0 and t1
        tt = t1 * init_params[im]
        data[im].set_attr("xform.align2d", tt)
    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, 0, ldata,
                               nproc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, 0, ldata, nproc)
    else:
        send_attr_dict(main_node, data, par_str, 0, ldata)
    if myid == main_node: print_end_msg("helical-shiftali_MPI")
Example #3
0
def shiftali_MPI(stack,
                 maskfile=None,
                 maxit=100,
                 CTF=False,
                 snr=1.0,
                 Fourvar=False,
                 search_rng=-1,
                 oneDx=False,
                 search_rng_y=-1):
    """Iterative translational alignment of a stack against its own average (MPI).

    Every process reads a contiguous share of the images, applies the
    ``xform.align2d`` transform stored in the headers and Fourier transforms
    the images.  In each iteration the global average is computed, every
    image is cross-correlated with it, the integer peak position is taken as
    the shift, the (rounded) average shift is subtracted and the images are
    shifted in Fourier space.  Iterations stop after ``maxit`` rounds or as
    soon as no image on any process received a non-zero shift.  The
    accumulated shifts are combined with the initial transforms and written
    to the ``xform.align2d`` header attribute of the stack.

    Args:
        stack: name of the image stack ("bdb" stacks get special handling).
        maskfile: name of a mask image file.  If ``None`` a circular mask of
            radius ``min(nx, ny) // 2 - 2`` is used.
        maxit: maximum number of iterations (converted with ``int``).
        CTF: if true, use the ``ctf`` header attribute; the data must not be
            ctf-applied already (``ctf_applied`` > 0 is an error).
        snr: only used with ``CTF``; added to the even x-columns of the summed
            squared CTF that the average is divided by.
        Fourvar: if true, the average is additionally divided by the Fourier
            variance returned by ``varf2d_MPI``.
        search_rng: if > 0, the ccf is windowed to ``2 * search_rng + 1``
            pixels in x, otherwise the full width is searched.
        oneDx: if true, only a 1D search along x is done.
        search_rng_y: as ``search_rng`` but for y (ignored when ``oneDx``).

    Returns:
        None.  Results are written to the headers of ``stack``.
    """
    # This import must come first: ``file_type`` is (re)imported inside this
    # function, which makes it a local name for the whole body.  With the
    # import at the end of the function, the call below raised
    # UnboundLocalError.
    from sp_utilities import file_type

    number_of_proc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)
    main_node = 0

    ftp = file_type(stack)

    if myid == main_node:
        print_begin_msg("shiftali_MPI")

    max_iter = int(maxit)

    if myid == main_node:
        if ftp == "bdb":
            from EMAN2db import db_open_dict
            dummy = db_open_dict(stack, True)
        nima = EMUtil.get_image_count(stack)
    else:
        nima = 0
    nima = bcast_number_to_all(nima, source_node=main_node)
    list_of_particles = list(range(nima))

    # Contiguous share of the images handled by this process.
    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)
    list_of_particles = list_of_particles[image_start:image_end]

    # read nx and ctf_app (if CTF) and broadcast to all nodes
    if myid == main_node:
        ima = EMData()
        ima.read_image(stack, list_of_particles[0], True)
        nx = ima.get_xsize()
        ny = ima.get_ysize()
        if CTF: ctf_app = ima.get_attr_default('ctf_applied', 2)
        del ima
    else:
        nx = 0
        ny = 0
        if CTF: ctf_app = 0
    nx = bcast_number_to_all(nx, source_node=main_node)
    ny = bcast_number_to_all(ny, source_node=main_node)
    if CTF:
        ctf_app = bcast_number_to_all(ctf_app, source_node=main_node)
        if ctf_app > 0:
            ERROR("data cannot be ctf-applied", myid=myid)

    if maskfile is None:
        mrad = min(nx, ny)
        mask = model_circle(mrad // 2 - 2, nx, ny)
    else:
        mask = get_im(maskfile)

    if CTF:
        from sp_filter import filt_ctf
        from sp_morphology import ctf_img
        ctf_abs_sum = EMData(nx, ny, 1, False)
        ctf_2_sum = EMData(nx, ny, 1, False)
    else:
        ctf_2_sum = None

    from sp_global_def import CACHE_DISABLE
    if CACHE_DISABLE:
        data = EMData.read_images(stack, list_of_particles)
    else:
        # Processes take turns reading; for bdb stacks a barrier after every
        # turn serialises the access.
        for i in range(number_of_proc):
            if myid == i:
                data = EMData.read_images(stack, list_of_particles)
            if ftp == "bdb": mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(len(data)):
        data[im].set_attr('ID', list_of_particles[im])
        st = Util.infomask(data[im], mask, False)
        data[im] -= st[0]
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            ctfimg = ctf_img(nx, ctf_params, ny=ny)
            Util.add_img2(ctf_2_sum, ctfimg)
            Util.add_img_abs(ctf_abs_sum, ctfimg)

    if CTF:
        reduce_EMData_to_root(ctf_2_sum, myid, main_node)
        reduce_EMData_to_root(ctf_abs_sum, myid, main_node)
        if myid != main_node:
            del ctf_2_sum
            del ctf_abs_sum
        else:
            # NOTE(review): this adds ``snr`` itself, whereas
            # helicalshiftali_MPI adds ``1 / snr`` at the same place.  The two
            # agree only for the default snr=1.0 -- confirm which is intended.
            temp = EMData(nx, ny, 1, False)
            for i in range(0, nx, 2):
                for j in range(ny):
                    temp.set_value_at(i, j, snr)
            Util.add_img(ctf_2_sum, temp)
            del temp

    total_iter = 0

    # apply initial xform.align2d parameters stored in header
    init_params = []
    for im in range(len(data)):
        t = data[im].get_attr('xform.align2d')
        init_params.append(t)
        p = t.get_params("2d")
        data[im] = rot_shift2D(data[im],
                               p['alpha'],
                               sx=p['tx'],
                               sy=p['ty'],
                               mirror=p['mirror'],
                               scale=p['scale'])

    # fourier transform all images, and apply ctf if CTF
    for im in range(len(data)):
        if CTF:
            ctf_params = data[im].get_attr("ctf")
            data[im] = filt_ctf(fft(data[im]), ctf_params)
        else:
            data[im] = fft(data[im])

    sx_sum = 0
    sy_sum = 0
    sx_sum_total = 0
    sy_sum_total = 0
    # shift_*: shifts accumulated over all iterations;
    # ishift_*: shifts found in the current iteration.
    shift_x = [0.0] * len(data)
    shift_y = [0.0] * len(data)
    ishift_x = [0.0] * len(data)
    ishift_y = [0.0] * len(data)

    for Iter in range(max_iter):
        if myid == main_node:
            start_time = time()
            print_msg("Iteration #%4d\n" % (total_iter))
        total_iter += 1
        avg = EMData(nx, ny, 1, False)
        for im in data:
            Util.add_img(avg, im)

        reduce_EMData_to_root(avg, myid, main_node)

        if myid == main_node:
            if CTF:
                tavg = Util.divn_filter(avg, ctf_2_sum)
            else:
                tavg = Util.mult_scalar(avg, 1.0 / float(nima))
        else:
            tavg = EMData(nx, ny, 1, False)

        if Fourvar:
            bcast_EMData_to_all(tavg, myid, main_node)
            vav, rvar = varf2d_MPI(myid, data, tavg, mask, "a", CTF)

        if myid == main_node:
            if Fourvar:
                tavg = fft(Util.divn_img(fft(tavg), vav))
                vav_r = Util.pack_complex_to_real(vav)

            # normalize and mask tavg in real space
            tavg = fft(tavg)
            stat = Util.infomask(tavg, mask, False)
            tavg -= stat[0]
            Util.mul_img(tavg, mask)
            # For testing purposes: shift tavg to some random place and see if the centering is still correct
            #tavg = rot_shift3D(tavg,sx=3,sy=-4)
            tavg = fft(tavg)

        if Fourvar: del vav
        bcast_EMData_to_all(tavg, myid, main_node)

        sx_sum = 0
        sy_sum = 0
        # Size of the ccf window in which the peak is searched.
        if search_rng > 0: nwx = 2 * search_rng + 1
        else: nwx = nx

        if search_rng_y > 0: nwy = 2 * search_rng_y + 1
        else: nwy = ny

        not_zero = 0
        for im in range(len(data)):
            if oneDx:
                ctx = Util.window(ccf(data[im], tavg), nwx, 1)
                p1 = peak_search(ctx)
                p1_x = -int(p1[0][3])
                ishift_x[im] = p1_x
                sx_sum += p1_x
            else:
                p1 = peak_search(Util.window(ccf(data[im], tavg), nwx, nwy))
                p1_x = -int(p1[0][4])
                p1_y = -int(p1[0][5])
                ishift_x[im] = p1_x
                ishift_y[im] = p1_y
                sx_sum += p1_x
                sy_sum += p1_y

            if not_zero == 0:
                if (not (ishift_x[im] == 0.0)) or (not (ishift_y[im] == 0.0)):
                    not_zero = 1

        sx_sum = mpi.mpi_reduce(sx_sum, 1, mpi.MPI_INT, mpi.MPI_SUM, main_node,
                                mpi.MPI_COMM_WORLD)

        if not oneDx:
            sy_sum = mpi.mpi_reduce(sy_sum, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                    main_node, mpi.MPI_COMM_WORLD)

        if myid == main_node:
            sx_sum_total = int(sx_sum[0])
            if not oneDx:
                sy_sum_total = int(sy_sum[0])
        else:
            sx_sum_total = 0
            sy_sum_total = 0

        sx_sum_total = bcast_number_to_all(sx_sum_total, source_node=main_node)

        if not oneDx:
            sy_sum_total = bcast_number_to_all(sy_sum_total,
                                               source_node=main_node)

        # Subtract the rounded average shift so the set as a whole stays put.
        sx_ave = round(float(sx_sum_total) / nima)
        sy_ave = round(float(sy_sum_total) / nima)
        for im in range(len(data)):
            p1_x = ishift_x[im] - sx_ave
            p1_y = ishift_y[im] - sy_ave
            params2 = {
                "filter_type": Processor.fourier_filter_types.SHIFT,
                "x_shift": p1_x,
                "y_shift": p1_y,
                "z_shift": 0.0
            }
            data[im] = Processor.EMFourierFilter(data[im], params2)
            shift_x[im] += p1_x
            shift_y[im] += p1_y
        # stop if all shifts are zero
        not_zero = mpi.mpi_reduce(not_zero, 1, mpi.MPI_INT, mpi.MPI_SUM,
                                  main_node, mpi.MPI_COMM_WORLD)
        if myid == main_node:
            not_zero_all = int(not_zero[0])
        else:
            not_zero_all = 0
        not_zero_all = bcast_number_to_all(not_zero_all, source_node=main_node)

        if myid == main_node:
            print_msg("Time of iteration = %12.2f\n" % (time() - start_time))
            start_time = time()

        if not_zero_all == 0: break

    # The images are left in Fourier space: only header information is used below.
    # combine shifts found with the original parameters
    for im in range(len(data)):
        t0 = init_params[im]
        t1 = Transform()
        t1.set_params({
            "type": "2D",
            "alpha": 0,
            "scale": t0.get_scale(),
            "mirror": 0,
            "tx": shift_x[im],
            "ty": shift_y[im]
        })
        # combine t0 and t1
        tt = t1 * t0
        data[im].set_attr("xform.align2d", tt)

    # write out headers and STOP, under MPI writing has to be done sequentially
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    par_str = ["xform.align2d", "ID"]
    if myid == main_node:
        if ftp == "bdb":
            from sp_utilities import recv_attr_dict_bdb
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            from sp_utilities import recv_attr_dict
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)

    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node: print_end_msg("shiftali_MPI")
Example #4
0
def filterlocal(ui, vi, m, falloff, myid, main_node, number_of_proc):
    """Filter a volume with a per-voxel cutoff taken from a local map.

    The main node rounds every value of ``ui`` to two decimals (in place) and
    broadcasts ``falloff``, the volume dimensions, ``vi`` and ``ui`` to all
    other processes.  The cutoff is then stepped in increments of 0.01 from
    ``st[2]`` up to ``st[3]``, where ``st = Util.infomask(ui, m, True)``
    (presumably the min and max of ``ui`` inside the mask -- TODO confirm).
    For every cutoff that occurs inside the mask, ``vi`` is filtered with
    ``filt_tanl(vi, cutoff, falloff)`` and the filtered values are copied into
    the output at the voxels whose rounded ``ui`` value equals that cutoff.

    Args:
        ui: map of per-voxel cutoff ("resolution") values; only needs to be
            valid on ``main_node``, it is replaced by a broadcast copy on the
            other processes.
        vi: volume to be filtered; same broadcasting rule as ``ui``.  It is
            Fourier transformed in place by ``fftip``.
        m: mask; voxels with a value > 0.5 are processed.  It is used on every
            process without being broadcast, so it must be available on all
            of them.
        falloff: fall-off passed to ``filt_tanl``; the value given on
            ``main_node`` is the one used everywhere.
        myid: rank of the calling process.
        main_node: rank that owns the input data and receives the result.
        number_of_proc: number of processes; z-slices are assigned to ranks
            in a round-robin fashion (``z = myid, myid + number_of_proc, ...``).

    Returns:
        The output volume.  Each process fills only its own z-slices before
        ``reduce_EMData_to_root`` is called; presumably only ``main_node``
        holds the complete result afterwards -- TODO confirm.
    """
    # Only the names actually used below are imported.
    from mpi import mpi_barrier, MPI_COMM_WORLD
    from sp_utilities import bcast_number_to_all, bcast_list_to_all, model_blank, bcast_EMData_to_all, reduce_EMData_to_root
    from sp_morphology import threshold_outside
    from sp_filter import filt_tanl
    from sp_fundamentals import fft, fftip

    if myid == main_node:
        nx = vi.get_xsize()
        ny = vi.get_ysize()
        nz = vi.get_zsize()
        #  Round all resolution numbers to two digits
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    ui.set_value_at_fast(x, y, z, round(ui.get_value_at(x, y, z), 2))
        dis = [nx, ny, nz]
    else:
        # Placeholders, overwritten by the broadcasts below.
        falloff = 0.0
        dis = [0, 0, 0]
    falloff = bcast_number_to_all(falloff, main_node)
    dis = bcast_list_to_all(dis, myid, source_node=main_node)

    if myid != main_node:
        nx = int(dis[0])
        ny = int(dis[1])
        nz = int(dis[2])

        # Receive buffers of the right size for the broadcasts below.
        vi = model_blank(nx, ny, nz)
        ui = model_blank(nx, ny, nz)

    bcast_EMData_to_all(vi, myid, main_node)
    bcast_EMData_to_all(ui, myid, main_node)

    fftip(vi)  #  volume to be filtered

    st = Util.infomask(ui, m, True)
    upper = st[3]

    filteredvol = model_blank(nx, ny, nz)
    cutoff = max(st[2] - 0.01, 0.0)
    while cutoff < upper:
        cutoff = round(cutoff + 0.01, 2)
        # Is there any voxel inside the mask whose value falls into the
        # current 0.01-wide bin?  Ideally, one would want to check only the
        # slices owned by this process...
        pt = Util.infomask(threshold_outside(ui, cutoff - 0.00501, cutoff + 0.005), m, True)
        if pt[0] != 0.0:
            # Filter the whole volume once for this cutoff, then pick the
            # voxels that belong to it from the slices owned by this rank.
            vovo = fft(filt_tanl(vi, cutoff, falloff))
            for z in range(myid, nz, number_of_proc):
                for x in range(nx):
                    for y in range(ny):
                        if (
                            m.get_value_at(x, y, z) > 0.5
                            and round(ui.get_value_at(x, y, z), 2) == cutoff
                        ):
                            filteredvol.set_value_at_fast(x, y, z, vovo.get_value_at(x, y, z))

    mpi_barrier(MPI_COMM_WORLD)
    reduce_EMData_to_root(filteredvol, myid, main_node, MPI_COMM_WORLD)
    return filteredvol
Example #5
0
def recons3d_4nn_ctf_MPI(
    myid,
    prjlist,
    snr=1.0,
    sign=1,
    symmetry="c1",
    finfo=None,
    npad=2,
    xysize=-1,
    zsize=-1,
    mpi_comm=None,
    smearstep=0.0,
):
    """
    Calculate a CTF-corrected 3-D reconstruction from the projections held by
    this process and combine the partial results of all processes.

    Every process inserts its own projections into an "nn4_ctf" (or, when a
    target size is given, "nn4_ctf_rect") reconstructor; the Fourier volume
    and the weights are then reduced and the process with ``myid == 0``
    finishes the reconstruction.

    Input
        myid: rank of the calling process; rank 0 finishes the reconstruction.
        prjlist: list of projection images, or an image iterator providing
            goToNext() / goToPrev() / image().  Non-square images are padded
            to squares of the larger dimension.
        snr: Signal-to-Noise Ratio of the data (passed to the reconstructor).
        sign: sign of the CTF (passed to the reconstructor).
        symmetry: point-group symmetry to be enforced (passed to the
            reconstructor).
        finfo: optional file-like object; progress messages are written to it.
        npad: padding factor (passed to the reconstructor).
        xysize, zsize: target x/y and z sizes; -1 means "same as the image
            size".  If either is given, the rectangular reconstructor is used
            with the corresponding size ratios.
        mpi_comm: MPI communicator, defaults to MPI_COMM_WORLD.
        smearstep: if > 0, a "smear" attribute built from this step is set on
            the Fourier volume before the reconstructor is set up.

    Output
        On rank 0 the ``fftvol`` object after ``r.finish(True)`` (presumably
        the reconstructed volume -- TODO confirm); on all other ranks a blank
        volume of the target size.
    """

    if mpi_comm is None:
        mpi_comm = mpi.MPI_COMM_WORLD

    if isinstance(prjlist, list):
        prjlist = sp_utilities.iterImagesList(prjlist)
    if not prjlist.goToNext():
        sp_global_def.ERROR("empty input list", "recons3d_4nn_ctf_MPI", 1)
    # Peek at the first image to get the size, then rewind the iterator.
    imgsize = prjlist.image().get_xsize()
    if prjlist.image().get_ysize() != imgsize:
        imgsize = max(imgsize, prjlist.image().get_ysize())
        dopad = True
    else:
        dopad = False
    prjlist.goToPrev()

    fftvol = EMAN2_cppwrap.EMData()

    if smearstep > 0.0:
        # The smear list holds quadruplets [x, y, z, weight].
        ns = 1
        smear = []
        for j in range(-ns, ns + 1):
            if j != 0:
                for i in range(-ns, ns + 1):
                    for k in range(-ns, ns + 1):
                        smear += [
                            i * smearstep, j * smearstep, k * smearstep, 1.0
                        ]
        # Deal with theta = 0.0 cases: entries with j == 0 are collapsed onto
        # the first coordinate, weighted by how many (i, k) pairs give the
        # same sum i + k.
        shift_sums = []
        for i in range(-ns, ns + 1):
            for k in range(-ns, ns + 1):
                shift_sums.append(i + k)
        for i in range(-2 * ns, 2 * ns + 1, 1):
            smear += [i * smearstep, 0.0, 0.0, float(shift_sums.count(i))]
        fftvol.set_attr("smear", smear)

    weight = EMAN2_cppwrap.EMData()
    if xysize == -1 and zsize == -1:
        params = {
            "size": imgsize,
            "npad": npad,
            "snr": snr,
            "sign": sign,
            "symmetry": symmetry,
            "fftvol": fftvol,
            "weight": weight,
        }
        r = EMAN2_cppwrap.Reconstructors.get("nn4_ctf", params)
    else:
        # At least one target size is given; a missing one keeps ratio 1.
        if xysize != -1:
            rx = ry = old_div(float(xysize), imgsize)
        else:
            rx = ry = 1.0
        if zsize != -1:
            rz = old_div(float(zsize), imgsize)
        else:
            rz = 1.0
        # The original code passed the undefined name ``sizeprojection`` here
        # (flagged by PAP 10/22/2014), which raised NameError.  The projection
        # size is ``imgsize``.
        # NOTE(review): the reconstructor may expect this value under the key
        # "sizeprojection" rather than "size" -- confirm against the
        # nn4_ctf_rect parameter list.
        params = {
            "size": imgsize,
            "npad": npad,
            "snr": snr,
            "sign": sign,
            "symmetry": symmetry,
            "fftvol": fftvol,
            "weight": weight,
            "xratio": rx,
            "yratio": ry,
            "zratio": rz,
        }
        r = EMAN2_cppwrap.Reconstructors.get("nn4_ctf_rect", params)
    r.setup()

    nimg = 0
    while prjlist.goToNext():
        prj = prjlist.image()
        if dopad:
            prj = sp_utilities.pad(prj, imgsize, imgsize, 1, "circumference")
        insert_slices(r, prj)
        if finfo is not None:
            nimg += 1
            finfo.write(" %4d inserted\n" % (nimg))
            finfo.flush()
    # Do NOT ``del sp_utilities.pad`` here: that removes the function from
    # the shared module for every later caller (and fails on a second call).
    if finfo is not None:
        finfo.write("begin reduce\n")
        finfo.flush()

    sp_utilities.reduce_EMData_to_root(fftvol, myid, comm=mpi_comm)
    sp_utilities.reduce_EMData_to_root(weight, myid, comm=mpi_comm)

    if finfo is not None:
        finfo.write("after reduce\n")
        finfo.flush()

    if myid == 0:
        r.finish(True)
    else:
        # Non-root ranks return an empty volume of the target size.
        out_xy = imgsize if xysize == -1 else xysize
        out_z = imgsize if zsize == -1 else zsize
        fftvol = sp_utilities.model_blank(out_xy, out_xy, out_z)
    return fftvol
Example #6
0
def recons3d_trl_struct_MPI(
    myid,
    main_node,
    prjlist,
    paramstructure,
    refang,
    rshifts_shrank,
    delta,
    upweighted=True,
    mpi_comm=None,
    CTF=True,
    target_size=-1,
    avgnorm=1.0,
    norm_per_particle=None,
):
    """
    Insert probability-weighted projections into an "nn4_ctfw" reconstructor
    and reduce the accumulated Fourier volume and weights onto main_node.

    Input
        myid: rank of the calling process.
        main_node: rank that receives the reduced data and calls finish().
        prjlist: indexable collection of projection images; every image must
            carry the "bckgnoise" and "ctf" attributes.  May be empty, in
            which case this process only takes part in the reduction.
        paramstructure: paramstructure[im][2] is a list of
            (code, probability) entries for image im.  code // 1000 is the
            combined direction index and code % 1000 indexes rshifts_shrank.
            The direction index is split as index % 100000 (psi step count)
            and index // 100000 (row of refang).
        refang: list of [phi, theta, psi] reference angles.
        rshifts_shrank: list of (sx, sy) shifts applied with fshift.
        delta: psi increment multiplied by the psi step count.
        upweighted: when False, each accumulated slice is filtered with the
            image's "bckgnoise" table before insertion.
        mpi_comm: MPI communicator; defaults to mpi.MPI_COMM_WORLD.
        CTF: converted to the reconstructor's integer "do_ctf" flag.
        target_size: size passed to the reconstructor and used for refvol.
        avgnorm, norm_per_particle: slice weight is
            total_probability * avgnorm / norm_per_particle[im];
            norm_per_particle defaults to 1.0 for every image.

    Output
        (fftvol, weight, refvol) on main_node, (None, None, None) elsewhere.
    """

    if mpi_comm is None:
        mpi_comm = mpi.MPI_COMM_WORLD

    refvol = sp_utilities.model_blank(target_size)
    refvol.set_attr("fudge", 1.0)

    fftvol = EMAN2_cppwrap.EMData()
    weight = EMAN2_cppwrap.EMData()

    params = {
        "size": target_size,
        "npad": 2,
        "snr": 1.0,
        "sign": 1,
        "symmetry": "c1",
        "refvol": refvol,
        "fftvol": fftvol,
        "weight": weight,
        "do_ctf": 1 if CTF else 0,
    }
    r = EMAN2_cppwrap.Reconstructors.get("nn4_ctfw", params)
    r.setup()

    if prjlist:
        # Identity test: "== None" is ambiguous when an array is passed.
        if norm_per_particle is None:
            norm_per_particle = len(prjlist) * [1.0]

        nny = prjlist[0].get_ysize()
        nshifts = len(rshifts_shrank)
        for im in range(len(prjlist)):
            prj = prjlist[im]
            #  parse projection structure, generate three lists:
            #  [ipsi+iang], [ishift], [probability]
            orientations = paramstructure[im][2]
            ipsiandiang = [old_div(entry[0], 1000) for entry in orientations]
            allshifts = [entry[0] % 1000 for entry in orientations]
            probs = [entry[1] for entry in orientations]
            #  Find unique projection directions
            tdir = list(set(ipsiandiang))
            bckgn = prj.get_attr("bckgnoise")
            ct = prj.get_attr("ctf")
            #  Shifted copies of the image are cached per shift index so each
            #  shift is computed at most once per image.
            data = [None] * nshifts
            for direction in tdir:
                #  Entries sharing this direction differ only by their shift.
                lshifts = sp_utilities.findall(direction, ipsiandiang)
                toprab = 0.0
                for ki in lshifts:
                    toprab += probs[ki]
                recdata = EMAN2_cppwrap.EMData(nny, nny, 1, False)
                recdata.set_attr("is_complex", 0)
                for ki in lshifts:
                    lpt = allshifts[ki]
                    if data[lpt] is None:
                        data[lpt] = sp_fundamentals.fshift(
                            prj, rshifts_shrank[lpt][0],
                            rshifts_shrank[lpt][1])
                        data[lpt].set_attr("is_complex", 0)
                    #  Accumulate with the probability normalized within
                    #  this direction.
                    EMAN2_cppwrap.Util.add_img(
                        recdata,
                        EMAN2_cppwrap.Util.mult_scalar(
                            data[lpt], old_div(probs[ki], toprab)),
                    )
                recdata.set_attr_dict({
                    "padffted": 1,
                    "is_fftpad": 1,
                    "is_fftodd": 0,
                    "is_complex_ri": 1,
                    "is_complex": 1,
                })
                if not upweighted:
                    recdata = sp_filter.filt_table(recdata, bckgn)
                recdata.set_attr_dict({"bckgnoise": bckgn, "ctf": ct})
                ipsi = direction % 100000
                iang = old_div(direction, 100000)
                r.insert_slice(
                    recdata,
                    EMAN2_cppwrap.Transform({
                        "type": "spider",
                        "phi": refang[iang][0],
                        "theta": refang[iang][1],
                        "psi": refang[iang][2] + ipsi * delta,
                    }),
                    old_div(toprab * avgnorm, norm_per_particle[im]),
                )
        #  Release per-image temporaries before the reduction.  recdata is
        #  only bound when at least one orientation was processed, so it is
        #  rebound rather than deleted to avoid a NameError.
        del bckgn, tdir, ipsiandiang, allshifts, probs
        recdata = None

    sp_utilities.reduce_EMData_to_root(fftvol, myid, main_node, comm=mpi_comm)
    sp_utilities.reduce_EMData_to_root(weight, myid, main_node, comm=mpi_comm)

    if myid == main_node:
        r.finish(True)
    mpi.mpi_barrier(mpi_comm)

    if myid == main_node:
        return fftvol, weight, refvol
    else:
        return None, None, None
Example #7
0
def recons3d_4nnw_MPI(
    myid,
    prjlist,
    bckgdata,
    snr=1.0,
    sign=1,
    symmetry="c1",
    finfo=None,
    npad=2,
    xysize=-1,
    zsize=-1,
    mpi_comm=None,
    smearstep=0.0,
    fsc=None,
):
    """
    recons3d_4nnw_MPI - insert projections, tagged with per-micrograph
    background noise, into a nearest-neighbour reconstructor and reduce the
    result onto process 0.

    Input
        myid: rank of the calling process; rank 0 finishes the reconstruction
        prjlist: list of projections to be included in the reconstruction or image iterator
        bckgdata = [get_im("tsd.hdf"),read_text_file("data_stamp.txt")]
            bckgdata[0] is an image whose row i holds the noise profile that
            belongs to entry i of the stamp list bckgdata[1]
        snr: Signal-to-Noise Ratio of the data
        sign: sign of the CTF
        symmetry: point-group symmetry to be enforced, each projection will enter the reconstruction in all symmetry-related directions.
        finfo: optional open file used for progress messages
        npad: padding factor passed to the reconstructor
        xysize, zsize: requested output dimensions; -1 means "same as the
            projection size".  If either is set, the "nn4_ctf_rect"
            reconstructor is used instead of "nn4_ctfw".
        mpi_comm: MPI communicator, defaults to mpi.MPI_COMM_WORLD
        smearstep: when positive, a "smear" attribute is attached to fftvol
        fsc: optional FSC curve; clipped to [0, 0.999] and stored in refvol
    Output
        fftvol: on rank 0 the volume held by the reconstructor after
            finish(); on the other ranks a blank volume of the output size
    """
    if mpi_comm is None:
        mpi_comm = mpi.MPI_COMM_WORLD

    if isinstance(prjlist, list):
        prjlist = sp_utilities.iterImagesList(prjlist)
    if not prjlist.goToNext():
        sp_global_def.ERROR("empty input list", "recons3d_4nnw_MPI", 1)
    first = prjlist.image()
    first_nx = first.get_xsize()
    first_ny = first.get_ysize()
    #  Non-square projections are padded to the larger dimension.
    imgsize = max(first_nx, first_ny)
    dopad = first_nx != first_ny
    prjlist.goToPrev()

    #  Do the FSC shtick.
    bnx = old_div(imgsize * npad, 2) + 1
    if fsc:
        t = [min(max(value, 0.0), 0.999) for value in fsc]
        t = sp_utilities.reshape_1d(t, len(t), npad * len(t))
        refvol = sp_utilities.model_blank(bnx, 1, 1, 0.0)
        # NOTE(review): only the first len(fsc) resampled values are copied,
        # although t was stretched by npad - confirm this is intended.
        for i in range(len(fsc)):
            refvol.set_value_at(i, t[i])
    else:
        refvol = sp_utilities.model_blank(bnx, 1, 1, 1.0)
    refvol.set_attr("fudge", 1.0)

    fftvol = EMAN2_cppwrap.EMData()
    weight = EMAN2_cppwrap.EMData()

    if smearstep > 0.0:
        ns = 1
        smear = []
        for j in range(-ns, ns + 1):
            if j != 0:
                for i in range(-ns, ns + 1):
                    for k in range(-ns, ns + 1):
                        smear += [
                            i * smearstep, j * smearstep, k * smearstep, 1.0
                        ]
        # Deal with theta = 0.0 cases: the weight of each offset is the
        # number of (i, k) pairs that sum to it.
        pair_sums = []
        for i in range(-ns, ns + 1):
            for k in range(-ns, ns + 1):
                pair_sums.append(i + k)
        for i in range(-2 * ns, 2 * ns + 1, 1):
            smear += [i * smearstep, 0.0, 0.0, float(pair_sums.count(i))]
        fftvol.set_attr("smear", smear)

    if xysize == -1 and zsize == -1:
        params = {
            "size": imgsize,
            "npad": npad,
            "snr": snr,
            "sign": sign,
            "symmetry": symmetry,
            "refvol": refvol,
            "fftvol": fftvol,
            "weight": weight,
        }
        r = EMAN2_cppwrap.Reconstructors.get("nn4_ctfw", params)
    else:
        if xysize != -1:
            rx = old_div(float(xysize), imgsize)
            ry = old_div(float(xysize), imgsize)
        else:
            rx = 1.0
            ry = 1.0
        if zsize != -1:
            rz = old_div(float(zsize), imgsize)
        else:
            rz = 1.0
        #  The original referenced an undefined name "sizeprojection" here
        #  (NameError).  The projection size is passed under the
        #  "sizeprojection" key, as recons3d_4nn_MPI does for "nn4_rect".
        params = {
            "sizeprojection": imgsize,
            "npad": npad,
            "snr": snr,
            "sign": sign,
            "symmetry": symmetry,
            "fftvol": fftvol,
            "weight": weight,
            "xratio": rx,
            "yratio": ry,
            "zratio": rz,
        }
        r = EMAN2_cppwrap.Reconstructors.get("nn4_ctf_rect", params)
    r.setup()

    #  Split the noise image into one 1-D profile per row.
    nnx = bckgdata[0].get_xsize()
    nny = bckgdata[0].get_ysize()
    bckgnoise = []
    for i in range(nny):
        row = sp_utilities.model_blank(nnx)
        for k in range(nnx):
            row[k] = bckgdata[0].get_value_at(k, i)
        bckgnoise.append(row)

    datastamp = bckgdata[1]
    nimg = 0
    while prjlist.goToNext():
        prj = prjlist.image()
        #  NOTE: the original forced a division by zero so that the
        #  "ptcl_source_image" lookup was always skipped; the stamp is
        #  therefore always the rounded CTF defocus.  That behaviour is kept.
        try:
            stmp = round(prj.get_attr("ctf").defocus, 4)
        except Exception:
            sp_global_def.ERROR(
                "Either ptcl_source_image or ctf has to be present in the header.",
                "recons3d_4nnw_MPI",
                1,
                myid,
            )
        try:
            indx = datastamp.index(stmp)
        except Exception:
            sp_global_def.ERROR("Problem with indexing ptcl_source_image.",
                                "recons3d_4nnw_MPI", 1, myid)

        if dopad:
            prj = sp_utilities.pad(prj, imgsize, imgsize, 1, "circumference")

        prj.set_attr("bckgnoise", bckgnoise[indx])
        insert_slices(r, prj)
        if finfo is not None:
            nimg += 1
            finfo.write(" %4d inserted\n" % (nimg))
            finfo.flush()
    #  (The original executed "del sp_utilities.pad" here, which removed the
    #  function from the module and broke every later call.)
    if finfo is not None:
        finfo.write("begin reduce\n")
        finfo.flush()

    sp_utilities.reduce_EMData_to_root(fftvol, myid, comm=mpi_comm)
    sp_utilities.reduce_EMData_to_root(weight, myid, comm=mpi_comm)

    if finfo is not None:
        finfo.write("after reduce\n")
        finfo.flush()

    if myid == 0:
        r.finish(True)
    else:
        #  Other ranks return a blank volume with the output dimensions.
        out_xy = imgsize if xysize == -1 else xysize
        out_z = imgsize if zsize == -1 else zsize
        fftvol = sp_utilities.model_blank(out_xy, out_xy, out_z)
    return fftvol
Example #8
0
def recons3d_4nn_MPI(
    myid,
    prjlist,
    symmetry="c1",
    finfo=None,
    snr=1.0,
    npad=2,
    xysize=-1,
    zsize=-1,
    mpi_comm=None,
):
    """
    recons3d_4nn_MPI - insert projections into a nearest-neighbour
    reconstructor and reduce the result onto process 0.

    Input
        myid: rank of the calling process; rank 0 finishes the reconstruction
        prjlist: list of projections or an image iterator
        symmetry: point-group symmetry passed to the reconstructor
        finfo: optional open file used for progress messages
        snr: Signal-to-Noise Ratio; only passed to the cubic "nn4"
            reconstructor
        npad: padding factor passed to the reconstructor
        xysize, zsize: requested output dimensions; -1 means "same as the
            projection size".  If either is set, "nn4_rect" is used.
        mpi_comm: MPI communicator, defaults to mpi.MPI_COMM_WORLD
    Output
        fftvol: on rank 0 the volume held by the reconstructor after
            finish(); on the other ranks a blank volume of the output size
    """
    if mpi_comm is None:
        mpi_comm = mpi.MPI_COMM_WORLD

    if isinstance(prjlist, list):
        prjlist = sp_utilities.iterImagesList(prjlist)

    if not prjlist.goToNext():
        sp_global_def.ERROR("empty input list", "recons3d_4nn_MPI", 1)

    first = prjlist.image()
    first_nx = first.get_xsize()
    first_ny = first.get_ysize()
    #  Non-square projections are padded to the larger dimension.
    imgsize = max(first_nx, first_ny)
    dopad = first_nx != first_ny
    prjlist.goToPrev()

    fftvol = EMAN2_cppwrap.EMData()
    weight = EMAN2_cppwrap.EMData()
    if xysize == -1 and zsize == -1:
        params = {
            "size": imgsize,
            "npad": npad,
            "symmetry": symmetry,
            "fftvol": fftvol,
            "weight": weight,
            "snr": snr,
        }
        r = EMAN2_cppwrap.Reconstructors.get("nn4", params)
    else:
        #  An unset dimension keeps the projection size (ratio 1.0).
        if xysize != -1:
            rx = old_div(float(xysize), imgsize)
            ry = old_div(float(xysize), imgsize)
        else:
            rx = 1.0
            ry = 1.0
        if zsize != -1:
            rz = old_div(float(zsize), imgsize)
        else:
            rz = 1.0
        params = {
            "sizeprojection": imgsize,
            "npad": npad,
            "symmetry": symmetry,
            "fftvol": fftvol,
            "weight": weight,
            "xratio": rx,
            "yratio": ry,
            "zratio": rz,
        }
        r = EMAN2_cppwrap.Reconstructors.get("nn4_rect", params)
    r.setup()

    nimg = 0
    while prjlist.goToNext():
        prj = prjlist.image()
        if dopad:
            prj = sp_utilities.pad(prj, imgsize, imgsize, 1, "circumference")
        insert_slices(r, prj)
        if finfo is not None:
            nimg += 1
            finfo.write("Image %4d inserted.\n" % (nimg))
            finfo.flush()

    if finfo is not None:
        finfo.write("Begin reducing ...\n")
        finfo.flush()

    sp_utilities.reduce_EMData_to_root(fftvol, myid, comm=mpi_comm)
    sp_utilities.reduce_EMData_to_root(weight, myid, comm=mpi_comm)

    if myid == 0:
        r.finish(True)
    else:
        #  Other ranks return a blank volume with the output dimensions.
        out_xy = imgsize if xysize == -1 else xysize
        out_z = imgsize if zsize == -1 else zsize
        fftvol = sp_utilities.model_blank(out_xy, out_xy, out_z)
    return fftvol
Example #9
0
def mref_ali2d_MPI(stack,
                   refim,
                   outdir,
                   maskfile=None,
                   ir=1,
                   ou=-1,
                   rs=1,
                   xrng=0,
                   yrng=0,
                   step=1,
                   center=1,
                   maxit=10,
                   CTF=False,
                   snr=1.0,
                   user_func_name="ref_ali2d",
                   rand_seed=1000):
    """
    2D multi-reference alignment using rotational ccf in polar coordinates
    and quadratic interpolation, distributed over MPI_COMM_WORLD.

    Each process aligns its share of the stack against all references,
    partial class averages are reduced onto process 0, which computes the new
    references, writes them to outdir and broadcasts them back.

    Input
        stack: input image stack; alignment parameters, the class assignment
            and the image ID are written back to its headers at the end
        refim: stack with the initial reference images
        outdir: output directory, must not exist; process 0 creates it and
            writes aqmNNN.hdf (averages) and drmNNNJJJJ.txt (FSC curves)
        maskfile: file name of a mask or a mask image; by default a circular
            mask of radius ou is used
        ir, ou, rs: first ring, last ring and ring step of the polar
            resampling; ou = -1 selects nx // 2 - 2
        xrng, yrng, step: translational search ranges and step
        center: centering method, handed to the user function
        maxit: maximum number of iterations; 0 is replaced by 10
        CTF, snr: only reported in the log; this routine does not apply any
            CTF correction
        user_func_name: key of the user function in sp_user_functions.factory
        rand_seed: seed of the random generator on process 0 (used when a
            vanished reference is replaced by a random image)
    """
    import os
    from random import seed, randint

    import sp_global_def
    import sp_user_functions
    from sp_alignment import Numrinit, ringwe, search_range
    from sp_applications import MPI_start_end
    from sp_fundamentals import rot_shift2D
    from sp_statistics import fsc
    from sp_utilities import model_circle, combine_params2, inverse_transform2, get_image, get_im
    from sp_utilities import reduce_EMData_to_root, bcast_EMData_to_all, bcast_number_to_all
    from sp_utilities import send_attr_dict, recv_attr_dict, recv_attr_dict_bdb, file_type
    from sp_utilities import get_params2D, set_params2D
    from sp_utilities import print_msg, print_begin_msg, print_end_msg
    from mpi import mpi_comm_size, mpi_comm_rank, MPI_COMM_WORLD
    from mpi import mpi_reduce, mpi_barrier, mpi_recv, mpi_send
    from mpi import MPI_SUM, MPI_FLOAT, MPI_INT

    number_of_proc = mpi_comm_size(MPI_COMM_WORLD)
    myid = mpi_comm_rank(MPI_COMM_WORLD)
    main_node = 0

    # create the output directory, if it does not exist

    if os.path.exists(outdir):
        sp_global_def.ERROR(
            'Output directory exists, please change the name and restart the program',
            "mref_ali2d_MPI ", 1, myid)
    mpi_barrier(MPI_COMM_WORLD)

    if myid == main_node:
        os.mkdir(outdir)
        sp_global_def.LOGFILE = os.path.join(outdir, sp_global_def.LOGFILE)
        print_begin_msg("mref_ali2d_MPI")

    nima = EMUtil.get_image_count(stack)

    image_start, image_end = MPI_start_end(nima, number_of_proc, myid)

    ima = EMData()
    ima.read_image(stack, image_start)

    first_ring = int(ir)
    last_ring = int(ou)
    rstep = int(rs)
    max_iter = int(maxit)

    if max_iter == 0:
        max_iter = 10

    if myid == main_node:
        print_msg("Input stack                 : %s\n" % (stack))
        print_msg("Reference stack             : %s\n" % (refim))
        print_msg("Output directory            : %s\n" % (outdir))
        print_msg("Maskfile                    : %s\n" % (maskfile))
        print_msg("Inner radius                : %i\n" % (first_ring))

    nx = ima.get_xsize()
    # default value for the last ring; integer division keeps it an int
    # under Python 3 as well
    if last_ring == -1:
        last_ring = nx // 2 - 2

    if myid == main_node:
        print_msg("Outer radius                : %i\n" % (last_ring))
        print_msg("Ring step                   : %i\n" % (rstep))
        print_msg("X search range              : %f\n" % (xrng))
        print_msg("Y search range              : %f\n" % (yrng))
        print_msg("Translational step          : %f\n" % (step))
        print_msg("Center type                 : %i\n" % (center))
        print_msg("Maximum iteration           : %i\n" % (max_iter))
        print_msg("CTF correction              : %s\n" % (CTF))
        print_msg("Signal-to-Noise Ratio       : %f\n" % (snr))
        print_msg("Random seed                 : %i\n\n" % (rand_seed))
        print_msg("User function               : %s\n" % (user_func_name))
    user_func = sp_user_functions.factory[user_func_name]

    if maskfile:
        # A file name may be str (Python 3) or bytes/native str (Python 2);
        # anything else is taken to be a mask image already.
        if isinstance(maskfile, (str, bytes)):
            mask = get_image(maskfile)
        else:
            mask = maskfile
    else:
        mask = model_circle(last_ring, nx, nx)
    #  references, do them on all processors...
    refi = []
    numref = EMUtil.get_image_count(refim)

    # IMAGES ARE SQUARES! center is in SPIDER convention
    cnx = nx // 2 + 1
    cny = cnx
    ny = nx

    mode = "F"
    # precalculate rings
    numr = Numrinit(first_ring, last_ring, rstep, mode)
    wr = ringwe(numr, mode)

    # prepare reference images on all nodes
    ima.to_zero()
    for j in range(numref):
        #  [sum over even-ID images (starts as the reference),
        #   sum over odd-ID images, number of images]
        refi.append([get_im(refim, j), ima.copy(), 0])
    #  for each node read its share of data
    data = EMData.read_images(stack, list(range(image_start, image_end)))
    for im in range(image_start, image_end):
        data[im - image_start].set_attr('ID', im)

    if myid == main_node:
        seed(rand_seed)

    #  "again" is never cleared, so the loop always runs max_iter times.
    again = True
    Iter = 0

    ref_data = [mask, center, None, None]

    while Iter < max_iter and again:
        ringref = []
        mashi = cnx - last_ring - 2
        for j in range(numref):
            refi[j][0].process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 1
            })  # normalize reference images to N(0,1)
            cimage = Util.Polar2Dm(refi[j][0], cnx, cny, numr, mode)
            Util.Frngs(cimage, numr)
            Util.Applyws(cimage, numr, wr)
            ringref.append(cimage)
            # zero refi
            refi[j][0].to_zero()
            refi[j][1].to_zero()
            refi[j][2] = 0

        assign = [[] for i in range(numref)]
        # begin MPI section
        for im in range(image_start, image_end):
            img = data[im - image_start]
            alpha, sx, sy, mirror, scale = get_params2D(img)
            #  Why inverse?  07/11/2015 PAP
            alphai, sxi, syi, scalei = inverse_transform2(alpha, sx, sy)
            # normalize
            img.process_inplace("normalize.mask", {
                "mask": mask,
                "no_sigma": 0
            })  # subtract average under the mask
            # If shifts are outside of the permissible range, reset them
            if abs(sxi) > mashi or abs(syi) > mashi:
                sxi = 0.0
                syi = 0.0
                set_params2D(img, [0.0, 0.0, 0.0, 0, 1.0])
            txrng = search_range(nx, last_ring, sxi, xrng, "mref_ali2d_MPI")
            txrng = [txrng[1], txrng[0]]
            tyrng = search_range(ny, last_ring, syi, yrng, "mref_ali2d_MPI")
            tyrng = [tyrng[1], tyrng[0]]
            # align current image to the reference
            [angt, sxst, syst, mirrort, xiref,
             peakt] = Util.multiref_polar_ali_2d(img, ringref, txrng, tyrng,
                                                 step, mode, numr, cnx + sxi,
                                                 cny + syi)

            iref = int(xiref)
            # combine parameters and set them to the header, ignore previous angle and mirror
            [alphan, sxn, syn,
             mn] = combine_params2(0.0, -sxi, -syi, 0, angt, sxst, syst,
                                   int(mirrort))
            set_params2D(img, [alphan, sxn, syn, int(mn), scale])
            img.set_attr('assign', iref)
            # apply current parameters and add to the even/odd half average
            aligned = rot_shift2D(img, alphan, sxn, syn, mn)
            it = im % 2
            Util.add_img(refi[iref][it], aligned)
            assign[iref].append(im)
            refi[iref][2] += 1.0
        del ringref
        # end MPI section, bring partial things together, calculate new reference images, broadcast them back

        for j in range(numref):
            reduce_EMData_to_root(refi[j][0], myid, main_node)
            reduce_EMData_to_root(refi[j][1], myid, main_node)
            refi[j][2] = mpi_reduce(refi[j][2], 1, MPI_FLOAT, MPI_SUM,
                                    main_node, MPI_COMM_WORLD)
            if myid == main_node:
                refi[j][2] = int(refi[j][2][0])
        # gather assignments: every other node sends, per reference, the
        # list length followed by the list itself
        for j in range(numref):
            if myid == main_node:
                for n in range(number_of_proc):
                    if n != main_node:
                        ln = mpi_recv(1, MPI_INT, n,
                                      sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                      MPI_COMM_WORLD)
                        lis = mpi_recv(ln[0], MPI_INT, n,
                                       sp_global_def.SPARX_MPI_TAG_UNIVERSAL,
                                       MPI_COMM_WORLD)
                        for l in range(ln[0]):
                            assign[j].append(int(lis[l]))
            else:
                mpi_send(len(assign[j]), 1, MPI_INT, main_node,
                         sp_global_def.SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)
                mpi_send(assign[j], len(assign[j]), MPI_INT, main_node,
                         sp_global_def.SPARX_MPI_TAG_UNIVERSAL, MPI_COMM_WORLD)

        if myid == main_node:
            # replace the name of the stack with reference with the current one
            refim = os.path.join(outdir, "aqm%03d.hdf" % Iter)
            ave_fsc = []
            for j in range(numref):
                if refi[j][2] < 4:
                    #  if vanished, put a random image (only from main node!) there
                    assign[j] = []
                    assign[j].append(
                        randint(image_start, image_end - 1) - image_start)
                    refi[j][0] = data[assign[j][0]].copy()
                else:
                    frsc = fsc(
                        refi[j][0], refi[j][1], 1.0,
                        os.path.join(outdir, "drm%03d%04d.txt" % (Iter, j)))
                    Util.add_img(refi[j][0], refi[j][1])
                    Util.mul_scalar(refi[j][0], 1.0 / float(refi[j][2]))

                    if not ave_fsc:
                        for i in range(len(frsc[1])):
                            ave_fsc.append(frsc[1][i])
                        c_fsc = 1
                    else:
                        for i in range(len(frsc[1])):
                            ave_fsc[i] += frsc[1][i]
                        c_fsc += 1

            #  Store the FSC averaged over the surviving references in frsc.
            if sum(ave_fsc) != 0:
                for i in range(len(ave_fsc)):
                    ave_fsc[i] /= float(c_fsc)
                    frsc[1][i] = ave_fsc[i]

            # NOTE(review): frsc is unbound if every reference vanished in
            # the first iteration - confirm whether that can happen.
            for j in range(numref):
                ref_data[2] = refi[j][0]
                ref_data[3] = frsc
                refi[j][0], cs = user_func(ref_data)

                # write the current average
                members = sorted(float(member) for member in assign[j])
                refi[j][0].set_attr_dict({
                    'ave_n': refi[j][2],
                    'members': members
                })
                del members
                refi[j][0].process_inplace("normalize.mask", {
                    "mask": mask,
                    "no_sigma": 1
                })
                refi[j][0].write_image(refim, j)

            Iter += 1
            msg = "ITERATION #%3d        %d\n\n" % (Iter, again)
            print_msg(msg)
            for j in range(numref):
                msg = "   group #%3d   number of particles = %7d\n" % (
                    j, refi[j][2])
                print_msg(msg)
        Iter = bcast_number_to_all(Iter, main_node)  # need to tell all
        if again:
            for j in range(numref):
                bcast_EMData_to_all(refi[j][0], myid, main_node)

    #  clean up
    del assign
    # write out headers  and STOP, under MPI writing has to be done sequentially (time-consuming)
    mpi_barrier(MPI_COMM_WORLD)
    #  The original reset 'ctf_applied' here based on an undefined variable
    #  (NameError when CTF was set).  No CTF is applied in this routine and
    #  'ctf_applied' is not among the attributes sent back, so nothing needs
    #  to be restored.
    par_str = ['xform.align2d', 'assign', 'ID']
    if myid == main_node:
        if file_type(stack) == "bdb":
            recv_attr_dict_bdb(main_node, stack, data, par_str, image_start,
                               image_end, number_of_proc)
        else:
            recv_attr_dict(main_node, stack, data, par_str, image_start,
                           image_end, number_of_proc)
    else:
        send_attr_dict(main_node, data, par_str, image_start, image_end)
    if myid == main_node:
        print_end_msg("mref_ali2d_MPI")