예제 #1
0
def mode_meridien(reconfile,
                  classavgstack,
                  classdocs,
                  partangles,
                  selectdoc,
                  maxshift,
                  outerrad,
                  outanglesdoc,
                  outaligndoc,
                  interpolation_method=1,
                  outliers=None,
                  goodclassparttemplate=None,
                  alignopt='apsh',
                  ringstep=1,
                  log=None,
                  verbose=False):
    """
    Derives Euler angles for class averages from per-particle MERIDIEN angles.

    For each class document, the particles' phi/theta are vector-averaged
    (optionally excluding outliers), the volume is re-projected at that
    direction, and the class average is 2D-aligned to the re-projection.
    The in-plane rotation (and mirror, if any) is then folded into the
    output Euler angles.

    Arguments:
        reconfile : Volume to re-project
        classavgstack : Class-average stack; image N is paired with the class document numbered N
        classdocs : Glob pattern matching per-class particle lists whose file stems end in a three-digit class number
        partangles : Per-particle alignment parameter doc (phi and theta in the first two columns)
        selectdoc : Optional substack selection file, passed through to average_angles
        maxshift : Maximum shift allowed during 2D alignment
        outerrad : Outer alignment radius
        outanglesdoc : Output doc with one row of combined parameters per class, in class-number order
        outaligndoc : Output doc with one row of 2D alignment parameters per class, in class-number order
        interpolation_method : Interpolation code passed to prep_vol
        outliers : Optional angular threshold (degrees) for excluding particles from the average
        goodclassparttemplate : str.format template (given the class number) for the retained-particle doc;
            only used together with outliers
        alignopt : 2D alignment method: 'scf', 'align2d', 'direct3', or anything else for APSH (default)
        ringstep : Alignment radius step size (APSH only)
        log : Logger object
        verbose : (boolean) Whether to write additional information to screen
    """

    def class_number(docname):
        # The class number is the last three characters of the file stem, e.g. docclass007.txt -> 7
        return int(os.path.splitext(docname)[0][-3:])

    # Resample reference
    recondata = EMAN2.EMData(reconfile)
    idim = recondata['nx']
    reconprep = prep_vol(recondata,
                         npad=2,
                         interpolation_method=interpolation_method)

    # Initialize output angles
    outangleslist = []
    outalignlist = []

    # Read class lists.
    # glob() returns files in arbitrary (directory) order, but the output rows are
    # later paired with class averages by position, so process classes in
    # ascending class-number order.
    # NOTE(review): if a class number is missing from the pattern matches, later
    # rows will be shifted relative to the class averages.
    classdoclist = sorted(glob.glob(classdocs), key=class_number)
    partangleslist = read_text_row(partangles)

    # Loop through class lists
    for classdoc in classdoclist:
        classnum = class_number(classdoc)

        # Initial average
        [avg_phi_init, avg_theta_init] = average_angles(partangleslist,
                                                        classdoc,
                                                        selectdoc=selectdoc)

        # Look for outliers
        if outliers:
            # Without a template there is nowhere to write retained particles
            if goodclassparttemplate:
                goodpartdoc = goodclassparttemplate.format(classnum)
            else:
                goodpartdoc = None

            [avg_phi_final, avg_theta_final] = average_angles(
                partangleslist,
                classdoc,
                selectdoc=selectdoc,
                init_angles=[avg_phi_init, avg_theta_init],
                threshold=outliers,
                goodpartdoc=goodpartdoc,
                log=log,
                verbose=verbose)
        else:
            [avg_phi_final, avg_theta_final] = [avg_phi_init, avg_theta_init]

        # Compute re-projection
        # NOTE(review): always projects with interpolation_method=1, regardless of the
        # method passed to prep_vol above; presumably because prgl does not accept
        # gridding (-1) — confirm.
        refprjreal = prgl(reconprep, [avg_phi_final, avg_theta_final, 0, 0, 0],
                          interpolation_method=1,
                          return_real=True)

        # Align to class average
        classavg = get_im(classavgstack, classnum)

        # Alignment using self-correlation function
        if alignopt == 'scf':
            ang_align2d, sxs, sys, mirrorflag, peak = align2d_scf(classavg,
                                                                  refprjreal,
                                                                  maxshift,
                                                                  maxshift,
                                                                  ou=outerrad)

        # Weird results
        elif alignopt == 'align2d':
            # Set search range
            currshift = 0
            txrng = tyrng = search_range(idim, outerrad, currshift, maxshift)

            # Perform alignment
            ang_align2d, sxs, sys, mirrorflag, peak = align2d(
                classavg, refprjreal, txrng, tyrng, last_ring=outerrad)

        # Direct3 (angles seemed to be quantized)
        elif alignopt == 'direct3':
            [[ang_align2d, sxs, sys, mirrorflag,
              peak]] = align2d_direct3([classavg],
                                       refprjreal,
                                       maxshift,
                                       maxshift,
                                       ou=outerrad)

        # APSH-like alignment (default)
        else:
            [[ang_align2d, sxs, sys, mirrorflag,
              scale]] = apsh(refprjreal,
                             classavg,
                             outerradius=outerrad,
                             maxshift=maxshift,
                             ringstep=ringstep)

        outalignlist.append([ang_align2d, sxs, sys, mirrorflag, 1])
        msg = "Particle list %s: ang_align2d=%s sx=%s sy=%s mirror=%s\n" % (
            classdoc, ang_align2d, sxs, sys, mirrorflag)
        print_log_msg(msg, log, verbose)

        # Fold the 2D alignment into the Euler angles (mirror first, if needed)
        if mirrorflag == 1:
            tempeulers = list(
                compose_transform3(avg_phi_final, avg_theta_final, 0, 0, 0, 0,
                                   1, 0, 180, 0, 0, 0, 0, 1))
            combinedparams = list(
                compose_transform3(tempeulers[0], tempeulers[1], tempeulers[2],
                                   tempeulers[3], tempeulers[4], 0, 1, 0, 0,
                                   -ang_align2d, 0, 0, 0, 1))
        else:
            combinedparams = list(
                compose_transform3(avg_phi_final, avg_theta_final, 0, 0, 0, 0,
                                   1, 0, 0, -ang_align2d, 0, 0, 0, 1))
        # compose_transform3: returns phi,theta,psi, tx,ty,tz, scale

        outangleslist.append(combinedparams)
    # End class-loop

    write_text_row(outangleslist, outanglesdoc)
    write_text_row(outalignlist, outaligndoc)
    print_log_msg(
        'Wrote alignment parameters to %s and %s\n' %
        (outanglesdoc, outaligndoc), log, verbose)
예제 #2
0
def main_proj_compare(classavgstack,
                      reconfile,
                      outdir,
                      options,
                      mode='viper',
                      prjmethod='trilinear',
                      classangles=None,
                      partangles=None,
                      selectdoc=None,
                      verbose=False,
                      displayYN=False):
    """
    Main function overseeing various projection-comparison modes.

    Arguments:
        classavgstack : Input image stack
        reconfile : Map of which to generate projections (and optionally perform alignment)
        outdir : Output directory
        options : Command-line options object (attributes such as delta, symmetry, align, matchrad,
            matchshift, classdocs, outliers, ...); run 'sxproj_compare.py -h' for an exhaustive list
        mode : Mode: 'viper' (pre-existing angles for each input image), 'projmatch' (angles from
            internal projection-matching), or 'meridien' (angles derived from per-particle parameters)
        prjmethod : Interpolation method: 'trilinear' (default), 'gridding', or 'nn'
        classangles : Angles and shifts for each input class average; if None, written to
            <outdir>/docangles.txt (required as input in mode viper)
        partangles : Angles and shifts for each particle (mode meridien)
        selectdoc : Selection file for included images
        verbose : (boolean) Whether to write additional information to screen
        displayYN : (boolean) Whether to automatically open montage
    """

    # Expand path for outputs
    refprojstack = os.path.join(outdir, 'refproj.hdf')
    refanglesdoc = os.path.join(outdir, 'refangles.txt')
    outaligndoc = os.path.join(outdir, 'docalign2d.txt')

    # If not an input, will create an output, in modes projmatch and meridien
    if classangles is None:
        classangles = os.path.join(outdir, 'docangles.txt')

        # You need either input angles (mode viper) or to calculate them on the fly (mode projmatch)
        if mode == 'viper':
            sp_global_def.ERROR(
                "\nERROR!! Input alignment parameters not specified.",
                __file__, 1)
            sxprint('Type %s --help to see available options\n' %
                    os.path.basename(__file__))
            exit()

    # Check if inputs exist
    check(classavgstack, verbose=verbose)
    check(reconfile, verbose=verbose)
    if verbose:
        sxprint('')

    # Check that dimensions of images and volume agree (maybe rescale volume)
    voldim = EMAN2.EMData(reconfile).get_xsize()
    imgdim = EMAN2.EMData(classavgstack, 0).get_xsize()
    if voldim != imgdim:
        sp_global_def.ERROR(
            "\nERROR!! Dimension of input volume doesn't match that of image stack: %s vs. %s"
            % (voldim, imgdim), __file__, 1)

        # only approximate, since full-sized particle radius is arbitrary
        scale = float(imgdim) / voldim
        msg = 'The command to resize the volume will be of the form:\n'
        msg += 'e2proc3d.py %s resized_vol.hdf --scale=%1.5f --clip=%s,%s,%s\n' % (
            reconfile, scale, imgdim, imgdim, imgdim)
        msg += 'Check the file in the ISAC directory named "README_shrink_ratio.txt" for confirmation.\n'
        sxprint(msg)
        exit()

    # Map the projection-method name to the interpolation code passed to
    # mode_meridien and compare_projs.
    #  Here if you want to be fancy, there should be an option to chose the projection method,
    #  the mechanism can be copied from sxproject3d.py  PAP
    if prjmethod == 'trilinear':
        method_num = 1
    elif prjmethod == 'gridding':
        method_num = -1
    elif prjmethod == 'nn':
        method_num = 0
    else:
        sp_global_def.ERROR(
            "\nERROR!! Valid projection methods are: trilinear (default), gridding, and nn (nearest neighbor).",
            __file__, 1)
        sxprint('Usage:\n%s' % USAGE)
        exit()

    # Set output directory and log file name
    log, verbose = prepare_outdir_log(outdir, verbose)

    # In case class averages include discarded images, apply selection file
    if mode == 'viper':
        if selectdoc:
            goodavgs, extension = os.path.splitext(
                os.path.basename(classavgstack))
            newclasses = os.path.join(outdir, goodavgs + "_kept" + extension)

            # e2proc2d appends to existing files, so rename existing output
            if os.path.exists(newclasses):
                renamefile = newclasses + '.bak'
                print_log_msg(
                    "Selected-classes stack %s exists, renaming to %s" %
                    (newclasses, renamefile), log, verbose)
                print_log_msg("mv %s %s\n" % (newclasses, renamefile), log,
                              verbose)
                os.rename(newclasses, renamefile)

            print_log_msg(
                'Creating subset of %s to %s based on selection list %s' %
                (classavgstack, newclasses, selectdoc), log, verbose)
            cmd = "e2proc2d.py %s %s --list=%s" % (classavgstack, newclasses,
                                                   selectdoc)
            print_log_msg(cmd, log, verbose)
            os.system(cmd)
            sxprint('')

            # Update class-averages
            classavgstack = newclasses

    # align de novo to reference map
    if mode == 'projmatch':
        # Generate reference projections
        print_log_msg(
            'Projecting %s to output %s using an increment of %s degrees using %s symmetry'
            % (reconfile, refprojstack, options.delta, options.symmetry), log,
            verbose)
        cmd = 'sxproject3d.py %s %s --delta=%s --method=S --phiEqpsi=Minus --symmetry=%s' % (
            reconfile, refprojstack, options.delta, options.symmetry)
        # Log the flag for the method actually selected by the prjmethod argument
        if prjmethod == 'trilinear':
            cmd += ' --trilinear'
        cmd += '\n'
        print_log_msg(cmd, log, verbose)
        project3d(reconfile,
                  refprojstack,
                  delta=options.delta,
                  symmetry=options.symmetry)

        # Export projection angles (the logged equivalent command uses --export to match fexport)
        print_log_msg(
            "Exporting projection angles from %s to %s" %
            (refprojstack, refanglesdoc), log, verbose)
        cmd = "sp_header.py %s --params=xform.projection --export=%s\n" % (
            refprojstack, refanglesdoc)
        print_log_msg(cmd, log, verbose)
        header(refprojstack, 'xform.projection', fexport=refanglesdoc)

        # Perform multi-reference alignment
        if options.align == 'ali2d':
            projdir = os.path.join(
                outdir, 'Projdir')  # used if input angles no provided
            if os.path.isdir(projdir):
                print_log_msg('Removing pre-existing directory %s' % projdir,
                              log, verbose)
                print_log_msg('rm -r %s\n' % projdir, log, verbose)
                shutil.rmtree(
                    projdir)  # os.rmdir only removes empty directories

            # Zero out alignment parameters in header
            print_log_msg(
                'Zeroing out alignment parameters in header of %s' %
                classavgstack, log, verbose)
            cmd = 'sxheader.py %s --params xform.align2d --zero\n' % classavgstack
            print_log_msg(cmd, log, verbose)
            header(classavgstack, 'xform.align2d', zero=True)

            # Perform multi-reference alignment
            msg = 'Aligning images in %s to projections %s with a radius of %s and a maximum allowed shift of %s' % (
                classavgstack, refprojstack, options.matchrad,
                options.matchshift)
            print_log_msg(msg, log, verbose)
            cmd = 'sxmref_ali2d.py %s %s %s --ou=%s --xr=%s --yr=%s\n' % (
                classavgstack, refprojstack, projdir, options.matchrad,
                options.matchshift, options.matchshift)
            print_log_msg(cmd, log, verbose)
            mref_ali2d(classavgstack,
                       refprojstack,
                       projdir,
                       ou=options.matchrad,
                       xrng=options.matchshift,
                       yrng=options.matchshift)

            # Export alignment parameters
            print_log_msg(
                'Exporting angles from %s into %s' %
                (classavgstack, classangles), log, verbose)
            cmd = "sp_header.py %s --params=xform.align2d --export=%s\n" % (
                classavgstack, classangles)
            print_log_msg(cmd, log, verbose)
            header(classavgstack, 'xform.align2d', fexport=classangles)

        # By default, use AP SH
        else:
            apsh(refprojstack,
                 classavgstack,
                 outangles=classangles,
                 refanglesdoc=refanglesdoc,
                 outaligndoc=outaligndoc,
                 outerradius=options.matchrad,
                 maxshift=options.matchshift,
                 ringstep=options.matchstep,
                 log=log,
                 verbose=verbose)

        # Diagnostic
        alignlist = read_text_row(
            classangles)  # contain 2D alignment parameters
        nimg1 = EMAN2.EMUtil.get_image_count(classavgstack)
        assert len(alignlist) == nimg1, "MRK_DEBUG"

    # Get alignment parameters from MERIDIEN
    if mode == 'meridien':
        if not partangles:
            sp_global_def.ERROR(
                "\nERROR!! Input alignment parameters not provided.", __file__,
                1)
            sxprint('Type %s --help to see available options\n' %
                    os.path.basename(__file__))
            exit()

        # Must be bound even when the Byclass block below is skipped
        # (class docs supplied and no outlier threshold); otherwise the
        # mode_meridien call raises NameError.
        goodclassparttemplate = None

        if not options.classdocs or options.outliers:
            classdir = os.path.join(outdir, 'Byclass')
            if not os.path.isdir(classdir):
                os.makedirs(classdir)

            if options.outliers:
                goodclassparttemplate = os.path.join(
                    classdir, 'goodpartsclass{0:03d}.txt')

            if not options.classdocs:
                classmap = os.path.join(classdir, 'classmap.txt')
                classdoc = os.path.join(classdir, 'docclass{0:03d}.txt')
                options.classdocs = os.path.join(classdir, 'docclass*.txt')

                # Separate particles by class
                vomq(classavgstack,
                     classmap,
                     classdoc,
                     log=log,
                     verbose=verbose)

        mode_meridien(reconfile,
                      classavgstack,
                      options.classdocs,
                      partangles,
                      selectdoc,
                      options.refineshift,
                      options.refinerad,
                      classangles,
                      outaligndoc,
                      interpolation_method=method_num,
                      outliers=options.outliers,
                      goodclassparttemplate=goodclassparttemplate,
                      alignopt=options.align,
                      ringstep=options.refinestep,
                      log=log,
                      verbose=verbose)

    # Import Euler angles
    print_log_msg(
        "Importing parameter information into %s from %s" %
        (classavgstack, classangles), log, verbose)
    cmd = "sp_header.py %s --params=xform.projection --import=%s\n" % (
        classavgstack, classangles)
    print_log_msg(cmd, log, verbose)
    header(classavgstack, 'xform.projection', fimport=classangles)

    # Make comparison stack: one two-column montage (re-projection, class average) per class average
    compstack = compare_projs(reconfile,
                              classavgstack,
                              classangles,
                              outdir,
                              interpolation_method=method_num,
                              log=log,
                              verbose=verbose)

    # Optionally pop up e2display
    if displayYN:
        sxprint('Opening montage')
        cmd = "e2display.py %s\n" % compstack
        sxprint(cmd)
        os.system(cmd)

    sxprint("Done!")
예제 #3
0
def apsh(refimgs,
         imgs2align,
         outangles=None,
         refanglesdoc=None,
         outaligndoc=None,
         outerradius=-1,
         maxshift=0,
         ringstep=1,
         mode="F",
         log=None,
         verbose=False):
    """
    Aligns images to a set of references by multi-reference polar (APSH-like) alignment.

    The references are converted to polar representation, then each image is
    aligned against all of them. The best reference's Euler angles (taken from
    refanglesdoc, or zeros if not given) are combined with the in-plane rotation
    and mirror found by the alignment.

    Arguments:
        refimgs : Reference images (filename or EMData object), passed to mref2polar
        imgs2align : Image stack to be aligned (filename) or a single EMData object
        outangles : Optional output doc of combined Euler angles, one row per image
        refanglesdoc : Optional input doc of Euler angles, one row per reference
        outaligndoc : Optional output doc of 2D alignment parameters, one row per image
        outerradius : Outer alignment radius; values <= 0 use nx//2 - 2 for the shift search
            range (mref2polar receives the value as given)
        maxshift : Maximum shift allowed
        ringstep : Alignment radius step size
        mode : Mode, full circle ("F") vs. half circle ("H")
        log : Logger object
        verbose : (boolean) Whether to write additional information to screen
    Returns:
        List with one [angle, sx, sy, mirror, scale] entry per image (scale is always 1).
    Side effects:
        Sets 'xform.align2d' and 'xform.projection' on each aligned image.
    """

    # Generate polar representation(s) of reference(s)
    # NOTE(review): mref2polar is called before the outerradius default below is
    # applied; presumably it handles outerradius <= 0 itself — confirm.
    alignrings, polarrefs = mref2polar(refimgs,
                                       outerradius=outerradius,
                                       ringstep=ringstep,
                                       log=log,
                                       verbose=verbose)

    # Read image stack (as a filename or already an EMData object)
    if isinstance(imgs2align, str):
        imagestacklist = EMData.read_images(imgs2align)
    else:
        imagestacklist = [imgs2align]

    # Get image dimensions (assuming square, and that images and references have the same dimension)
    idim = imagestacklist[0]['nx']

    # Image center. Floor division keeps the integer center that Python 2's
    # int '/' produced; true division would give x.5 for odd sizes.
    halfdim = idim // 2 + 1

    # Set constants
    currshift = 0
    scale = 1

    # Initialize output angles
    outangleslist = []
    outalignlist = []

    if outerradius <= 0:
        outerradius = halfdim - 3

    # Set search range
    txrng = tyrng = search_range(idim, outerradius, currshift, maxshift)

    print_log_msg(
        'Running multireference alignment allowing a maximum shift of %s\n' %
        maxshift, log, verbose)

    # The reference angles do not depend on the image: read them once
    refangleslist = read_text_row(refanglesdoc) if refanglesdoc else None

    # Loop through images
    for currimg in imagestacklist:
        # Perform multi-reference alignment (adapted from alignment.mref_ali2d)
        angt, sxst, syst, mirrorfloat, bestreffloat, peakt = Util.multiref_polar_ali_2d(
            currimg, polarrefs, txrng, tyrng, ringstep, mode, alignrings,
            halfdim, halfdim)
        bestref = int(bestreffloat)
        mirrorflag = int(mirrorfloat)

        # Store parameters
        params2dlist = [angt, sxst, syst, mirrorflag, scale]
        outalignlist.append(params2dlist)

        if refangleslist is not None:
            besteulers = refangleslist[bestref]
        else:
            besteulers = [0] * 5

        # Check for mirroring
        if mirrorflag == 1:
            tempeulers = list(
                compose_transform3(besteulers[0], besteulers[1], besteulers[2],
                                   besteulers[3], besteulers[4], 0, 1, 0, 180,
                                   0, 0, 0, 0, 1))
            combinedparams = list(
                compose_transform3(tempeulers[0], tempeulers[1], tempeulers[2],
                                   tempeulers[3], tempeulers[4], 0, 1, 0, 0,
                                   -angt, 0, 0, 0, 1))
        else:
            combinedparams = list(
                compose_transform3(besteulers[0], besteulers[1], besteulers[2],
                                   besteulers[3], besteulers[4], 0, 1, 0, 0,
                                   -angt, 0, 0, 0, 1))
        # compose_transform3: returns phi,theta,psi, tx,ty,tz, scale

        outangleslist.append(combinedparams)

        # Set transformations as image attribute
        set_params2D(currimg, params2dlist, xform="xform.align2d"
                     )  # sometimes I get a vector error with sxheader
        set_params_proj(currimg, besteulers,
                        xform="xform.projection")  # use shifts

    # Log each written file once (previously the first message was repeated)
    if outangles:
        write_text_row(outangleslist, outangles)
        print_log_msg('Wrote alignment angles to %s\n' % outangles, log,
                      verbose)
    if outaligndoc:
        write_text_row(outalignlist, outaligndoc)
        print_log_msg('Wrote 2D alignment parameters to %s\n' % outaligndoc,
                      log, verbose)

    return outalignlist
예제 #4
0
def compare_projs(reconfile,
                  classavgstack,
                  inputanglesdoc,
                  outdir,
                  interpolation_method=1,
                  log=None,
                  verbose=False):
    """
    Makes a comparison stack pairing each class average with a re-projection of a volume.

    Each class average is compared against the volume re-projected at its row of
    inputanglesdoc. The re-projection's rotationally averaged power spectrum is
    matched to the (masked) class average, and the masked cross-correlation is computed.
    Output image N of the stack is a two-column montage built from
    [re-projection, class average N]. Each montage gets the stack-wide mean
    cross-correlation in its 'mean-cross-corr' header attribute.

    Arguments:
        reconfile : Input volume from which to generate re-projections
        classavgstack : Input image stack
        inputanglesdoc : Input Euler angles doc, one row per image of classavgstack
        outdir : Output directory (also receives the mask as maskalign.hdf)
        interpolation_method : Interpolation method for prep_vol: nearest neighbor (nn, 0), trilinear (1, default), gridding (-1)
        log : Logger object
        verbose : (boolean) Whether to write additional information to screen
    Returns:
        compstack : Path of the comparison stack (<outdir>/comp-proj-reproj.hdf)
    """

    recondata = EMAN2.EMData(reconfile)
    nx = recondata.get_xsize()

    # Resample reference
    reconprep = prep_vol(recondata,
                         npad=2,
                         interpolation_method=interpolation_method)

    ccclist = []

    #  Here you need actual radius to compute proper ccc's, but if you do, you have to deal with translations, PAP
    mask = model_circle(nx // 2 - 2, nx, nx)
    mask.write_image(os.path.join(outdir, 'maskalign.hdf'))
    compstack = os.path.join(outdir, 'comp-proj-reproj.hdf')

    # Number of images may have changed
    nimg1 = EMAN2.EMUtil.get_image_count(classavgstack)
    angleslist = read_text_row(inputanglesdoc)

    for imgnum in range(nimg1):
        # Get class average
        classimg = get_im(classavgstack, imgnum)

        # Compute re-projection (Fourier-space result)
        # NOTE(review): always uses interpolation_method=1 regardless of the argument
        # given to prep_vol; presumably prgl does not accept gridding (-1) — confirm.
        prjimg = prgl(reconprep,
                      angleslist[imgnum],
                      interpolation_method=1,
                      return_real=False)

        # Calculate 1D power spectra
        rops_dst = rops_table(classimg * mask)
        rops_src = rops_table(prjimg)

        #  Set power spectrum of reprojection to the data.
        #  Since data has an envelope, it would make more sense to set data to reconstruction,
        #  but to do it one would have to know the actual resolution of the data.
        #  you can check sxprocess.py --adjpw to see how this is done properly  PAP
        # A shell with zero re-projection power stays zero whatever the factor,
        # so use 0.0 there instead of dividing by zero.
        table = [
            sqrt(dst / src) if src > 0 else 0.0
            for dst, src in zip(rops_dst, rops_src)
        ]
        prjimg = fft(filt_table(
            prjimg,
            table))  # match FFT amplitudes of re-projection and class average

        cccoeff = ccc(prjimg, classimg, mask)
        classimg.set_attr_dict({'cross-corr': cccoeff})
        prjimg.set_attr_dict({'cross-corr': cccoeff})

        comparison_pair = montage2([prjimg, classimg], ncol=2, marginwidth=1)
        comparison_pair.write_image(compstack, imgnum)

        ccclist.append(cccoeff)

    meanccc = sum(ccclist) / nimg1
    print_log_msg("Average CCC is %s\n" % meanccc, log, verbose)

    # Store the stack-wide mean CCC in every montage header
    nimg2 = EMAN2.EMUtil.get_image_count(compstack)
    for imgnum in range(nimg2):
        compimg = get_im(compstack, imgnum)
        compimg.set_attr_dict({'mean-cross-corr': meanccc})
        write_header(compstack, compimg, imgnum)

    return compstack
예제 #5
0
def average_angles(alignlist,
                   partdoc,
                   selectdoc=None,
                   init_angles=None,
                   threshold=None,
                   goodpartdoc=None,
                   log=None,
                   verbose=False):
    """
    Computes a vector average of a set of particles' Euler angles phi and theta.

    Each angle is averaged as a unit vector (cos, sin) so that values on either
    side of the 0/360 wrap-around average correctly.

    Arguments:
        alignlist : Alignment parameter list (phi, theta in the first two columns), or the
            filename of such a doc file, i.e., from MERIDIEN refinement
        partdoc : List of particle indices whose angles should be averaged
        selectdoc : Input substack selection file if particles removed before refinement
            (e.g., Substack/isac_substack_particle_id_list.txt)
        init_angles : List (2 elements) with initial phi and theta angles, for excluding outliers
        threshold : Angular threshold (degrees) beyond which particles exceeding this angular
            difference from init_angles will be excluded; requires init_angles
        goodpartdoc : Output list of retained particles if a threshold was specified
        log : Logger object
        verbose : (boolean) Whether to write additional information to screen
    Returns:
        list of 2 elements:
            avg_phi
            avg_theta
        If a threshold excludes every particle, init_angles is returned instead.
    Raises:
        ValueError : if threshold is given without init_angles, or a particle of
            partdoc is not found in selectdoc
    """

    # Read alignment parameters if a filename was passed
    # (If loading the same parameters repeatedly, better to read the file once externally and pass only the list.)
    if isinstance(alignlist, str):
        alignlist = read_text_row(alignlist)

    if threshold is not None and not init_angles:
        raise ValueError("average_angles: threshold requires init_angles")

    # Read class list
    partlist = read_text_row(partdoc)

    # Map each selection-file row to its first index (same result as
    # list.index, but O(1) per particle instead of a scan of the whole list)
    selectindex = {}
    if selectdoc:
        for rownum, row in enumerate(read_text_row(selectdoc)):
            selectindex.setdefault(tuple(row), rownum)

    sum_phi = np.array([0.0, 0.0])
    sum_theta = np.array([0.0, 0.0])
    totparts = 0
    num_outliers = 0
    goodpartlist = []

    # Loop through particles
    for totpartnum in partlist:
        if selectindex:
            goodpartnum = selectindex.get(tuple(totpartnum))
            if goodpartnum is None:
                raise ValueError("%s is not in list" % (totpartnum,))
        else:
            goodpartnum = totpartnum[0]

        try:
            phi_deg = alignlist[goodpartnum][0]
            theta_deg = alignlist[goodpartnum][1]
        except IndexError:
            msg = "\nERROR!! %s tries to access particle #%s" % (partdoc,
                                                                 goodpartnum)
            numalignparts = len(alignlist)
            msg += "\nAlignment doc file has only %s entries" % (numalignparts)
            msg += "\nMaybe try substack selection file with flag '--select <substack_select>'?"
            sp_global_def.ERROR(msg, __file__, 1)
            exit()
        phi_rad = np.deg2rad(phi_deg)
        theta_rad = np.deg2rad(theta_deg)

        totparts += 1

        # Exclude particles exceeding optional threshold
        if threshold is None:
            keep = True
        else:
            angdif = angle_diff(init_angles, [phi_deg, theta_deg])
            if angdif > 180:
                angdif = 360.0 - angdif
            keep = angdif < threshold

        if keep:
            sum_phi += (np.cos(phi_rad), np.sin(phi_rad))
            sum_theta += (np.cos(theta_rad), np.sin(theta_rad))
            goodpartlist.append(goodpartnum)
        else:
            num_outliers += 1

    # Compute final average
    avg_phi = degrees(atan2(sum_phi[1], sum_phi[0]))
    avg_theta = degrees(atan2(sum_theta[1], sum_theta[0]))

    msg = "Particle list %s: average angles (%s, %s)" % (partdoc, avg_phi,
                                                         avg_theta)
    print_log_msg(msg, log, verbose)

    if threshold is not None:
        msg = "Found %s out of %s outliers exceeding an angle difference of %s degrees from initial estimate" % (
            num_outliers, totparts, threshold)
        print_log_msg(msg, log, verbose)

        if goodpartlist:
            if goodpartdoc:
                write_text_row(goodpartlist, goodpartdoc)
                msg = "Wrote %s particles to %s" % (len(goodpartlist),
                                                    goodpartdoc)
                print_log_msg(msg, log, verbose)
        else:
            # Nothing left to average (atan2(0, 0) would give 0): fall back
            # to the initial estimate whether or not goodpartdoc was requested
            msg = "WARNING!! Kept 0 particles from class %s" % partdoc
            print_log_msg(msg, log, verbose)
            [avg_phi, avg_theta] = init_angles

    return [avg_phi, avg_theta]
예제 #6
0
def main():
    global Tracker, Blockdata
    progname = os.path.basename(sys.argv[0])
    usage = progname + " --output_dir=output_dir  --isac_dir=output_dir_of_isac "
    parser = optparse.OptionParser(usage, version=sp_global_def.SPARXVERSION)
    parser.add_option(
        "--pw_adjustment",
        type="string",
        default="analytical_model",
        help=
        "adjust power spectrum of 2-D averages to an analytic model. Other opions: no_adjustment; bfactor; a text file of 1D rotationally averaged PW",
    )
    #### Four options for --pw_adjustment:
    # 1> analytical_model(default);
    # 2> no_adjustment;
    # 3> bfactor;
    # 4> adjust_to_given_pw2(user has to provide a text file that contains 1D rotationally averaged PW)

    # options in common
    parser.add_option(
        "--isac_dir",
        type="string",
        default="",
        help="ISAC run output directory, input directory for this command",
    )
    parser.add_option(
        "--output_dir",
        type="string",
        default="",
        help="output directory where computed averages are saved",
    )
    parser.add_option(
        "--pixel_size",
        type="float",
        default=-1.0,
        help=
        "pixel_size of raw images. one can put 1.0 in case of negative stain data",
    )
    parser.add_option(
        "--fl",
        type="float",
        default=-1.0,
        help=
        "low pass filter, = -1.0, not applied; =0.0, using FH1 (initial resolution), = 1.0 using FH2 (resolution after local alignment), or user provided value in absolute freqency [0.0:0.5]",
    )
    parser.add_option("--stack",
                      type="string",
                      default="",
                      help="data stack used in ISAC")
    parser.add_option("--radius", type="int", default=-1, help="radius")
    parser.add_option("--xr",
                      type="float",
                      default=-1.0,
                      help="local alignment search range")
    # parser.add_option("--ts",                    type   ="float",          default =1.0,    help= "local alignment search step")
    parser.add_option(
        "--fh",
        type="float",
        default=-1.0,
        help="local alignment high frequencies limit",
    )
    # parser.add_option("--maxit",                 type   ="int",            default =5,      help= "local alignment iterations")
    parser.add_option("--navg",
                      type="int",
                      default=1000000,
                      help="number of aveages")
    parser.add_option(
        "--local_alignment",
        action="store_true",
        default=False,
        help="do local alignment",
    )
    parser.add_option(
        "--noctf",
        action="store_true",
        default=False,
        help=
        "no ctf correction, useful for negative stained data. always ctf for cryo data",
    )
    parser.add_option(
        "--B_start",
        type="float",
        default=45.0,
        help=
        "start frequency (Angstrom) of power spectrum for B_factor estimation",
    )
    parser.add_option(
        "--Bfactor",
        type="float",
        default=-1.0,
        help=
        "User defined bactors (e.g. 25.0[A^2]). By default, the program automatically estimates B-factor. ",
    )

    (options, args) = parser.parse_args(sys.argv[1:])

    adjust_to_analytic_model = (True if options.pw_adjustment
                                == "analytical_model" else False)
    no_adjustment = True if options.pw_adjustment == "no_adjustment" else False
    B_enhance = True if options.pw_adjustment == "bfactor" else False
    adjust_to_given_pw2 = (
        True if not (adjust_to_analytic_model or no_adjustment or B_enhance)
        else False)

    # mpi
    nproc = mpi.mpi_comm_size(mpi.MPI_COMM_WORLD)
    myid = mpi.mpi_comm_rank(mpi.MPI_COMM_WORLD)

    Blockdata = {}
    Blockdata["nproc"] = nproc
    Blockdata["myid"] = myid
    Blockdata["main_node"] = 0
    Blockdata["shared_comm"] = mpi.mpi_comm_split_type(
        mpi.MPI_COMM_WORLD, mpi.MPI_COMM_TYPE_SHARED, 0, mpi.MPI_INFO_NULL)
    Blockdata["myid_on_node"] = mpi.mpi_comm_rank(Blockdata["shared_comm"])
    Blockdata["no_of_processes_per_group"] = mpi.mpi_comm_size(
        Blockdata["shared_comm"])
    masters_from_groups_vs_everything_else_comm = mpi.mpi_comm_split(
        mpi.MPI_COMM_WORLD,
        Blockdata["main_node"] == Blockdata["myid_on_node"],
        Blockdata["myid_on_node"],
    )
    Blockdata["color"], Blockdata[
        "no_of_groups"], balanced_processor_load_on_nodes = sp_utilities.get_colors_and_subsets(
            Blockdata["main_node"],
            mpi.MPI_COMM_WORLD,
            Blockdata["myid"],
            Blockdata["shared_comm"],
            Blockdata["myid_on_node"],
            masters_from_groups_vs_everything_else_comm,
        )
    #  We need two nodes for processing of volumes
    Blockdata["node_volume"] = [
        Blockdata["no_of_groups"] - 3,
        Blockdata["no_of_groups"] - 2,
        Blockdata["no_of_groups"] - 1,
    ]  # For 3D stuff take three last nodes
    #  We need two CPUs for processing of volumes, they are taken to be main CPUs on each volume
    #  We have to send the two myids to all nodes so we can identify main nodes on two selected groups.
    Blockdata["nodes"] = [
        Blockdata["node_volume"][0] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][1] * Blockdata["no_of_processes_per_group"],
        Blockdata["node_volume"][2] * Blockdata["no_of_processes_per_group"],
    ]
    # End of Blockdata: sorting requires at least three nodes, and the used number of nodes be integer times of three
    sp_global_def.BATCH = True
    sp_global_def.MPI = True

    if adjust_to_given_pw2:
        checking_flag = 0
        if Blockdata["myid"] == Blockdata["main_node"]:
            if not os.path.exists(options.pw_adjustment):
                checking_flag = 1
        checking_flag = sp_utilities.bcast_number_to_all(
            checking_flag, Blockdata["main_node"], mpi.MPI_COMM_WORLD)

        if checking_flag == 1:
            sp_global_def.ERROR("User provided power spectrum does not exist",
                                myid=Blockdata["myid"])

    Tracker = {}
    Constants = {}
    Constants["isac_dir"] = options.isac_dir
    Constants["masterdir"] = options.output_dir
    Constants["pixel_size"] = options.pixel_size
    Constants["orgstack"] = options.stack
    Constants["radius"] = options.radius
    Constants["xrange"] = options.xr
    Constants["FH"] = options.fh
    Constants["low_pass_filter"] = options.fl
    # Constants["maxit"]                        = options.maxit
    Constants["navg"] = options.navg
    Constants["B_start"] = options.B_start
    Constants["Bfactor"] = options.Bfactor

    if adjust_to_given_pw2:
        Constants["modelpw"] = options.pw_adjustment
    Tracker["constants"] = Constants
    # -------------------------------------------------------------
    #
    # Create and initialize Tracker dictionary with input options  # State Variables

    # <<<---------------------->>>imported functions<<<---------------------------------------------

    # x_range = max(Tracker["constants"]["xrange"], int(1./Tracker["ini_shrink"])+1)
    # y_range =  x_range

    ####-----------------------------------------------------------
    # Create Master directory and associated subdirectories
    line = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>"
    if Tracker["constants"]["masterdir"] == Tracker["constants"]["isac_dir"]:
        masterdir = os.path.join(Tracker["constants"]["isac_dir"], "sharpen")
    else:
        masterdir = Tracker["constants"]["masterdir"]

    if Blockdata["myid"] == Blockdata["main_node"]:
        msg = "Postprocessing ISAC 2D averages starts"
        sp_global_def.sxprint(line, "Postprocessing ISAC 2D averages starts")
        if not masterdir:
            timestring = time.strftime("_%d_%b_%Y_%H_%M_%S", time.localtime())
            masterdir = "sharpen_" + Tracker["constants"]["isac_dir"]
            os.makedirs(masterdir)
        else:
            if os.path.exists(masterdir):
                sp_global_def.sxprint("%s already exists" % masterdir)
            else:
                os.makedirs(masterdir)
        sp_global_def.write_command(masterdir)
        subdir_path = os.path.join(masterdir, "ali2d_local_params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        subdir_path = os.path.join(masterdir, "params_avg")
        if not os.path.exists(subdir_path):
            os.mkdir(subdir_path)
        li = len(masterdir)
    else:
        li = 0
    li = mpi.mpi_bcast(li, 1, mpi.MPI_INT, Blockdata["main_node"],
                       mpi.MPI_COMM_WORLD)[0]
    masterdir = mpi.mpi_bcast(masterdir, li, mpi.MPI_CHAR,
                              Blockdata["main_node"], mpi.MPI_COMM_WORLD)
    masterdir = b"".join(masterdir).decode('latin1')
    Tracker["constants"]["masterdir"] = masterdir
    log_main = sp_logger.Logger(sp_logger.BaseLogger_Files())
    log_main.prefix = Tracker["constants"]["masterdir"] + "/"

    while not os.path.exists(Tracker["constants"]["masterdir"]):
        sp_global_def.sxprint(
            "Node ",
            Blockdata["myid"],
            "  waiting...",
            Tracker["constants"]["masterdir"],
        )
        time.sleep(1)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if Blockdata["myid"] == Blockdata["main_node"]:
        init_dict = {}
        sp_global_def.sxprint(Tracker["constants"]["isac_dir"])
        Tracker["directory"] = os.path.join(Tracker["constants"]["isac_dir"],
                                            "2dalignment")
        core = sp_utilities.read_text_row(
            os.path.join(Tracker["directory"], "initial2Dparams.txt"))
        for im in range(len(core)):
            init_dict[im] = core[im]
        del core
    else:
        init_dict = 0
    init_dict = sp_utilities.wrap_mpi_bcast(init_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    ###
    do_ctf = True
    if options.noctf:
        do_ctf = False
    if Blockdata["myid"] == Blockdata["main_node"]:
        if do_ctf:
            sp_global_def.sxprint("CTF correction is on")
        else:
            sp_global_def.sxprint("CTF correction is off")
        if options.local_alignment:
            sp_global_def.sxprint("local refinement is on")
        else:
            sp_global_def.sxprint("local refinement is off")
        if B_enhance:
            sp_global_def.sxprint("Bfactor is to be applied on averages")
        elif adjust_to_given_pw2:
            sp_global_def.sxprint(
                "PW of averages is adjusted to a given 1D PW curve")
        elif adjust_to_analytic_model:
            sp_global_def.sxprint(
                "PW of averages is adjusted to analytical model")
        else:
            sp_global_def.sxprint("PW of averages is not adjusted")
        # Tracker["constants"]["orgstack"] = "bdb:"+ os.path.join(Tracker["constants"]["isac_dir"],"../","sparx_stack")
        image = sp_utilities.get_im(Tracker["constants"]["orgstack"], 0)
        Tracker["constants"]["nnxo"] = image.get_xsize()
        if Tracker["constants"]["pixel_size"] == -1.0:
            sp_global_def.sxprint(
                "Pixel size value is not provided by user. extracting it from ctf header entry of the original stack."
            )
            try:
                ctf_params = image.get_attr("ctf")
                Tracker["constants"]["pixel_size"] = ctf_params.apix
            except:
                sp_global_def.ERROR(
                    "Pixel size could not be extracted from the original stack.",
                    myid=Blockdata["myid"],
                )
        ## Now fill in low-pass filter

        isac_shrink_path = os.path.join(Tracker["constants"]["isac_dir"],
                                        "README_shrink_ratio.txt")

        if not os.path.exists(isac_shrink_path):
            sp_global_def.ERROR(
                "%s does not exist in the specified ISAC run output directory"
                % (isac_shrink_path),
                myid=Blockdata["myid"],
            )

        isac_shrink_file = open(isac_shrink_path, "r")
        isac_shrink_lines = isac_shrink_file.readlines()
        isac_shrink_ratio = float(
            isac_shrink_lines[5]
        )  # 6th line: shrink ratio (= [target particle radius]/[particle radius]) used in the ISAC run
        isac_radius = float(
            isac_shrink_lines[6]
        )  # 7th line: particle radius at original pixel size used in the ISAC run
        isac_shrink_file.close()
        print("Extracted parameter values")
        print("ISAC shrink ratio    : {0}".format(isac_shrink_ratio))
        print("ISAC particle radius : {0}".format(isac_radius))
        Tracker["ini_shrink"] = isac_shrink_ratio
    else:
        Tracker["ini_shrink"] = 0.0
    Tracker = sp_utilities.wrap_mpi_bcast(Tracker,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)

    # print(Tracker["constants"]["pixel_size"], "pixel_size")
    x_range = max(
        Tracker["constants"]["xrange"],
        int(old_div(1.0, Tracker["ini_shrink"]) + 0.99999),
    )
    a_range = y_range = x_range

    if Blockdata["myid"] == Blockdata["main_node"]:
        parameters = sp_utilities.read_text_row(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "all_parameters.txt"))
    else:
        parameters = 0
    parameters = sp_utilities.wrap_mpi_bcast(parameters,
                                             Blockdata["main_node"],
                                             communicator=mpi.MPI_COMM_WORLD)
    params_dict = {}
    list_dict = {}
    # parepare params_dict

    # navg = min(Tracker["constants"]["navg"]*Blockdata["nproc"], EMUtil.get_image_count(os.path.join(Tracker["constants"]["isac_dir"], "class_averages.hdf")))
    navg = min(
        Tracker["constants"]["navg"],
        EMAN2_cppwrap.EMUtil.get_image_count(
            os.path.join(Tracker["constants"]["isac_dir"],
                         "class_averages.hdf")),
    )
    global_dict = {}
    ptl_list = []
    memlist = []
    if Blockdata["myid"] == Blockdata["main_node"]:
        sp_global_def.sxprint("Number of averages computed in this run is %d" %
                              navg)
        for iavg in range(navg):
            params_of_this_average = []
            image = sp_utilities.get_im(
                os.path.join(Tracker["constants"]["isac_dir"],
                             "class_averages.hdf"),
                iavg,
            )
            members = sorted(image.get_attr("members"))
            memlist.append(members)
            for im in range(len(members)):
                abs_id = members[im]
                global_dict[abs_id] = [iavg, im]
                P = sp_utilities.combine_params2(
                    init_dict[abs_id][0],
                    init_dict[abs_id][1],
                    init_dict[abs_id][2],
                    init_dict[abs_id][3],
                    parameters[abs_id][0],
                    old_div(parameters[abs_id][1], Tracker["ini_shrink"]),
                    old_div(parameters[abs_id][2], Tracker["ini_shrink"]),
                    parameters[abs_id][3],
                )
                if parameters[abs_id][3] == -1:
                    sp_global_def.sxprint(
                        "WARNING: Image #{0} is an unaccounted particle with invalid 2D alignment parameters and should not be the member of any classes. Please check the consitency of input dataset."
                        .format(abs_id)
                    )  # How to check what is wrong about mirror = -1 (Toshio 2018/01/11)
                params_of_this_average.append([P[0], P[1], P[2], P[3], 1.0])
                ptl_list.append(abs_id)
            params_dict[iavg] = params_of_this_average
            list_dict[iavg] = members
            sp_utilities.write_text_row(
                params_of_this_average,
                os.path.join(
                    Tracker["constants"]["masterdir"],
                    "params_avg",
                    "params_avg_%03d.txt" % iavg,
                ),
            )
        ptl_list.sort()
        init_params = [None for im in range(len(ptl_list))]
        for im in range(len(ptl_list)):
            init_params[im] = [ptl_list[im]] + params_dict[global_dict[
                ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
        sp_utilities.write_text_row(
            init_params,
            os.path.join(Tracker["constants"]["masterdir"],
                         "init_isac_params.txt"),
        )
    else:
        params_dict = 0
        list_dict = 0
        memlist = 0
    params_dict = sp_utilities.wrap_mpi_bcast(params_dict,
                                              Blockdata["main_node"],
                                              communicator=mpi.MPI_COMM_WORLD)
    list_dict = sp_utilities.wrap_mpi_bcast(list_dict,
                                            Blockdata["main_node"],
                                            communicator=mpi.MPI_COMM_WORLD)
    memlist = sp_utilities.wrap_mpi_bcast(memlist,
                                          Blockdata["main_node"],
                                          communicator=mpi.MPI_COMM_WORLD)
    # Now computing!
    del init_dict
    tag_sharpen_avg = 1000
    ## always apply low pass filter to B_enhanced images to suppress noise in high frequencies
    enforced_to_H1 = False
    if B_enhance:
        if Tracker["constants"]["low_pass_filter"] == -1.0:
            enforced_to_H1 = True

    # distribute workload among mpi processes
    image_start, image_end = sp_applications.MPI_start_end(
        navg, Blockdata["nproc"], Blockdata["myid"])

    if Blockdata["myid"] == Blockdata["main_node"]:
        cpu_dict = {}
        for iproc in range(Blockdata["nproc"]):
            local_image_start, local_image_end = sp_applications.MPI_start_end(
                navg, Blockdata["nproc"], iproc)
            for im in range(local_image_start, local_image_end):
                cpu_dict[im] = iproc
    else:
        cpu_dict = 0

    cpu_dict = sp_utilities.wrap_mpi_bcast(cpu_dict,
                                           Blockdata["main_node"],
                                           communicator=mpi.MPI_COMM_WORLD)

    slist = [None for im in range(navg)]
    ini_list = [None for im in range(navg)]
    avg1_list = [None for im in range(navg)]
    avg2_list = [None for im in range(navg)]
    data_list = [None for im in range(navg)]
    plist_dict = {}

    if Blockdata["myid"] == Blockdata["main_node"]:
        if B_enhance:
            sp_global_def.sxprint(
                "Avg ID   B-factor  FH1(Res before ali) FH2(Res after ali)")
        else:
            sp_global_def.sxprint(
                "Avg ID   FH1(Res before ali)  FH2(Res after ali)")

    FH_list = [[0, 0.0, 0.0] for im in range(navg)]
    for iavg in range(image_start, image_end):

        mlist = EMAN2_cppwrap.EMData.read_images(
            Tracker["constants"]["orgstack"], list_dict[iavg])

        for im in range(len(mlist)):
            sp_utilities.set_params2D(mlist[im],
                                      params_dict[iavg][im],
                                      xform="xform.align2d")

        if options.local_alignment:
            new_avg, plist, FH2 = sp_applications.refinement_2d_local(
                mlist,
                Tracker["constants"]["radius"],
                a_range,
                x_range,
                y_range,
                CTF=do_ctf,
                SNR=1.0e10,
            )
            plist_dict[iavg] = plist
            FH1 = -1.0

        else:
            new_avg, frc, plist = compute_average(
                mlist, Tracker["constants"]["radius"], do_ctf)
            FH1 = get_optimistic_res(frc)
            FH2 = -1.0

        FH_list[iavg] = [iavg, FH1, FH2]

        if B_enhance:
            new_avg, gb = apply_enhancement(
                new_avg,
                Tracker["constants"]["B_start"],
                Tracker["constants"]["pixel_size"],
                Tracker["constants"]["Bfactor"],
            )
            sp_global_def.sxprint("  %6d      %6.3f  %4.3f  %4.3f" %
                                  (iavg, gb, FH1, FH2))

        elif adjust_to_given_pw2:
            roo = sp_utilities.read_text_file(Tracker["constants"]["modelpw"],
                                              -1)
            roo = roo[0]  # always on the first column
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         roo)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f  " %
                                  (iavg, FH1, FH2))

        elif adjust_to_analytic_model:
            new_avg = adjust_pw_to_model(new_avg,
                                         Tracker["constants"]["pixel_size"],
                                         None)
            sp_global_def.sxprint("  %6d      %4.3f  %4.3f   " %
                                  (iavg, FH1, FH2))

        elif no_adjustment:
            pass

        if Tracker["constants"]["low_pass_filter"] != -1.0:
            if Tracker["constants"]["low_pass_filter"] == 0.0:
                low_pass_filter = FH1
            elif Tracker["constants"]["low_pass_filter"] == 1.0:
                low_pass_filter = FH2
                if not options.local_alignment:
                    low_pass_filter = FH1
            else:
                low_pass_filter = Tracker["constants"]["low_pass_filter"]
                if low_pass_filter >= 0.45:
                    low_pass_filter = 0.45
            new_avg = sp_filter.filt_tanl(new_avg, low_pass_filter, 0.02)
        else:  # No low pass filter but if enforced
            if enforced_to_H1:
                new_avg = sp_filter.filt_tanl(new_avg, FH1, 0.02)
        if B_enhance:
            new_avg = sp_fundamentals.fft(new_avg)

        new_avg.set_attr("members", list_dict[iavg])
        new_avg.set_attr("n_objects", len(list_dict[iavg]))
        slist[iavg] = new_avg
        sp_global_def.sxprint(
            time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + " =>",
            "Refined average %7d" % iavg,
        )

    ## send to main node to write
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    for im in range(navg):
        # avg
        if (cpu_dict[im] == Blockdata["myid"]
                and Blockdata["myid"] != Blockdata["main_node"]):
            sp_utilities.send_EMData(slist[im], Blockdata["main_node"],
                                     tag_sharpen_avg)

        elif (cpu_dict[im] == Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            slist[im].set_attr("members", memlist[im])
            slist[im].set_attr("n_objects", len(memlist[im]))
            slist[im].write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        elif (cpu_dict[im] != Blockdata["myid"]
              and Blockdata["myid"] == Blockdata["main_node"]):
            new_avg_other_cpu = sp_utilities.recv_EMData(
                cpu_dict[im], tag_sharpen_avg)
            new_avg_other_cpu.set_attr("members", memlist[im])
            new_avg_other_cpu.set_attr("n_objects", len(memlist[im]))
            new_avg_other_cpu.write_image(
                os.path.join(Tracker["constants"]["masterdir"],
                             "class_averages.hdf"),
                im,
            )

        if options.local_alignment:
            if cpu_dict[im] == Blockdata["myid"]:
                sp_utilities.write_text_row(
                    plist_dict[im],
                    os.path.join(
                        Tracker["constants"]["masterdir"],
                        "ali2d_local_params_avg",
                        "ali2d_local_params_avg_%03d.txt" % im,
                    ),
                )

            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(plist_dict[im],
                                           Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                plist_dict[im] = dummy
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]
        else:
            if (cpu_dict[im] == Blockdata["myid"]
                    and cpu_dict[im] != Blockdata["main_node"]):
                sp_utilities.wrap_mpi_send(FH_list, Blockdata["main_node"],
                                           mpi.MPI_COMM_WORLD)

            elif (cpu_dict[im] != Blockdata["main_node"]
                  and Blockdata["myid"] == Blockdata["main_node"]):
                dummy = sp_utilities.wrap_mpi_recv(cpu_dict[im],
                                                   mpi.MPI_COMM_WORLD)
                FH_list[im] = dummy[im]

        mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)

    if options.local_alignment:
        if Blockdata["myid"] == Blockdata["main_node"]:
            ali3d_local_params = [None for im in range(len(ptl_list))]
            for im in range(len(ptl_list)):
                ali3d_local_params[im] = [ptl_list[im]] + plist_dict[
                    global_dict[ptl_list[im]][0]][global_dict[ptl_list[im]][1]]
            sp_utilities.write_text_row(
                ali3d_local_params,
                os.path.join(Tracker["constants"]["masterdir"],
                             "ali2d_local_params.txt"),
            )
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))
    else:
        if Blockdata["myid"] == Blockdata["main_node"]:
            sp_utilities.write_text_row(
                FH_list,
                os.path.join(Tracker["constants"]["masterdir"], "FH_list.txt"))

    mpi.mpi_barrier(mpi.MPI_COMM_WORLD)
    target_xr = 3
    target_yr = 3
    if Blockdata["myid"] == 0:
        cmd = "{} {} {} {} {} {} {} {} {} {}".format(
            "sp_chains.py",
            os.path.join(Tracker["constants"]["masterdir"],
                         "class_averages.hdf"),
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"),
            os.path.join(Tracker["constants"]["masterdir"],
                         "ordered_class_averages.hdf"),
            "--circular",
            "--radius=%d" % Tracker["constants"]["radius"],
            "--xr=%d" % (target_xr + 1),
            "--yr=%d" % (target_yr + 1),
            "--align",
            ">/dev/null",
        )
        junk = sp_utilities.cmdexecute(cmd)
        cmd = "{} {}".format(
            "rm -rf",
            os.path.join(Tracker["constants"]["masterdir"], "junk.hdf"))
        junk = sp_utilities.cmdexecute(cmd)

    return
예제 #7
0
def combine_isac_params(isac_dir,
                        classavgstack,
                        chains_params_file,
                        old_combined_parts,
                        classdoc,
                        combined_params_file,
                        log=False,
                        verbose=False):
    """
    Combines initial and all_params from ISAC.

    The initial 2-D alignment parameters (2dalignment/initial2Dparams.txt) are
    composed with ISAC's parameters (all_parameters.txt). The ISAC shifts are
    first divided by the shrink ratio found on the 6th line of
    README_shrink_ratio.txt.

    Three cases:
        1) Using class_averages.hdf
        2) Using ordered_class_averages.hdf, but chains_params.txt doesn't exist
        3) Using ordered_class_averages.hdf and chains_params.txt
    Only in case 3 is the per-average alignment from sp_chains applied as well.
    In that case each class doc (classdoc.format(class_num)) is also rewritten
    in place, with its entries renumbered to their row index in
    old_combined_parts.

    Arguments:
        isac_dir : ISAC directory
        classavgstack : Input image stack
        chains_params_file : Input alignment parameters applied to averages in sp_chains
        old_combined_parts : Text file whose rows are looked up to renumber class-doc entries (case 3 only)
        classdoc : Class-doc filename template with one format field for the class number (case 3 only)
        combined_params_file : Output combined alignment parameters
        log : instance of Logger class
        verbose : (boolean) Whether to write to screen

    Returns:
        List of combined parameter rows, as written to combined_params_file.
    """

    from sp_utilities import combine_params2, read_text_row

    # File-handling
    init_params_file = os.path.join(isac_dir, "2dalignment",
                                    "initial2Dparams.txt")
    all_params_file = os.path.join(isac_dir, "all_parameters.txt")
    init_params_list = read_text_row(init_params_file)
    all_params = read_text_row(all_params_file)

    # 6th line of the README: shrink ratio used in the ISAC run. The context
    # manager closes the handle (the previous version never closed it).
    isac_shrink_path = os.path.join(isac_dir, "README_shrink_ratio.txt")
    with open(isac_shrink_path, "r") as isac_shrink_file:
        isac_shrink_lines = isac_shrink_file.readlines()
    isac_shrink_ratio = float(isac_shrink_lines[5])

    msg = "Combining alignment parameters from %s and %s, dividing by %s, and writing to %s" % \
     (init_params_file, all_params_file, isac_shrink_ratio, combined_params_file)

    # Check if using ordered class averages and whether chains_params exists
    if os.path.basename(classavgstack) == 'ordered_class_averages.hdf':
        if not os.path.exists(chains_params_file):
            msg += "\nWARNING: '%s' does not exist.\n" % chains_params_file
            msg += "         Using '%s' but alignment parameters correspond to 'class_averages.hdf'.\n" % classavgstack
        else:
            msg = "Combining alignment parameters from %s, %s, and %s, dividing by %s, and writing to %s" % \
             (init_params_file, all_params_file, chains_params_file, isac_shrink_ratio, combined_params_file)

    print_log_msg(msg, log, verbose)

    def _combine_isac(im):
        # Compose the initial params of particle `im` with ISAC's params,
        # converting ISAC shifts back to the original pixel size.
        return combine_params2(
            init_params_list[im][0], init_params_list[im][1],
            init_params_list[im][2], init_params_list[im][3],
            all_params[im][0], all_params[im][1] / isac_shrink_ratio,
            all_params[im][2] / isac_shrink_ratio, all_params[im][3])

    def _row_key(row):
        # Hashable stand-in for a text row so equal rows share a dict key.
        return tuple(row) if isinstance(row, list) else row

    if os.path.basename(
            classavgstack) == 'ordered_class_averages.hdf' and os.path.exists(
                chains_params_file):
        chains_params_list = read_text_row(chains_params_file)
        old_combined_list = read_text_row(old_combined_parts)
        num_classes = EMUtil.get_image_count(classavgstack)
        tmp_combined = []

        # Row -> index of its FIRST occurrence (same result as list.index()).
        # Built once instead of an O(n) scan per particle.
        old_combined_index = {}
        for row_num, row in enumerate(old_combined_list):
            old_combined_index.setdefault(_row_key(row), row_num)

        # Loop through classes
        for class_num in range(num_classes):
            # Extract members
            image = get_im(classavgstack, class_num)
            members = sorted(image.get_attr("members"))
            old_class_list = read_text_row(classdoc.format(class_num))
            new_class_list = []
            # Columns 2-5 are taken as the average's (alpha, sx, sy, mirror).
            # Presumably the rows are in class order; TODO confirm against sp_chains.
            chains_row = chains_params_list[class_num]

            # Loop through particles
            for idx, im in enumerate(members):
                tmp_par = _combine_isac(im)

                # Combine with class-average parameters
                P = combine_params2(tmp_par[0], tmp_par[1], tmp_par[2],
                                    tmp_par[3], chains_row[2], chains_row[3],
                                    chains_row[4], chains_row[5])

                tmp_combined.append([im, P[0], P[1], P[2], P[3]])

                # Need to update class number in class docs
                old_part_num = old_class_list[idx]
                new_part_num = old_combined_index.get(_row_key(old_part_num))
                if new_part_num is None:
                    # The old code referenced an undefined (first miss) or
                    # stale (later miss) new_part_num here and appended it.
                    # Report the miss and skip this particle instead.
                    print("Couldn't find particle: class_num %s, old_part_num %s"
                          % (class_num, old_part_num[0]))
                    continue

                new_class_list.append(new_part_num)
            # End particle-loop

            # Overwrite pre-existing class doc
            write_text_row(new_class_list, classdoc.format(class_num))
        # End class-loop

        # Sort by particle number, then drop that leading column.
        # NOTE(review): these rows have 4 columns, while the branch below emits
        # 5 (with a trailing 1.0). Confirm which layout downstream readers expect.
        combined_params = [
            row[1:] for row in sorted(tmp_combined, key=itemgetter(0))
        ]

    # Not applying alignments of ordered_class_averages
    else:
        combined_params = [[P[0], P[1], P[2], P[3], 1.0]
                           for P in map(_combine_isac, range(len(all_params)))]

    write_text_row(combined_params, combined_params_file)
    print_log_msg(
        'Wrote %s entries to %s\n' %
        (len(combined_params), combined_params_file), log, verbose)

    return combined_params
예제 #8
0
def main():
    """Command-line entry point of the chains program.

    Parses the options and dispatches to one of three ordering modes:

    * ``--dd``: ``<input> <output> <mask>`` -- ccc of every image pair under
      the mask read from the third argument, circular order from ``tsp``,
      no orientation adjustment.
    * ``--align``: ``<input> <chain_out> <circle_out>`` -- pairwise 2-D
      alignments (computed, or read from ``--pairwiseccc``), best greedy
      chain over all seeds written to ``chain_out``, ``tsp`` order written
      to ``circle_out`` plus ``chains_params.txt`` in the same directory.
    * otherwise: ``<input> <output>`` -- greedy chain on images assumed to
      be aligned already (header 2-D parameters are applied when present),
      started from ``--initial`` or from the best seed, straight or
      ``--circular``.

    Returns None; argument problems are reported through ``ERROR``.
    """
    progname = os.path.basename(sys.argv[0])
    usage = progname + """ [options] <inputfile> <outputfile>

	Forms chains of 2D images based on their similarities.

	Functionality:


	Order a 2-D stack of image based on pair-wise similarity (computed as a cross-correlation coefficient).
		Options 1-3 require image stack to be aligned.  The program will apply orientation parameters if present in headers.
	    The ways to use the program:
	   1.  Use option initial to specify which image will be used as an initial seed to form the chain.
	        sp_chains.py input_stack.hdf output_stack.hdf --initial=23 --radius=25

	   2.  If options initial is omitted, the program will determine which image best serves as initial seed to form the chain
	        sp_chains.py input_stack.hdf output_stack.hdf --radius=25

	   3.  Use option circular to form a circular chain.
	        sp_chains.py input_stack.hdf output_stack.hdf --circular --radius=25

	   4.  New circular code based on pairwise alignments
			sp_chains.py aclf.hdf chain.hdf circle.hdf --align  --radius=25 --xr=2 --pairwiseccc=lcc.txt

	   5.  Circular ordering based on pairwise alignments
			sp_chains.py vols.hdf chain.hdf mask.hdf --dd  --radius=25


"""

    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option(
        "--dd",
        action="store_true",
        help="Circular ordering without adjustment of orientations",
        default=False)
    parser.add_option(
        "--circular",
        action="store_true",
        help=
        "Select circular ordering (first image has to be similar to the last)",
        default=False)
    parser.add_option(
        "--align",
        action="store_true",
        help=
        "Compute all pairwise alignments and from the table of image similarities find the best chain",
        default=False)
    parser.add_option(
        "--initial",
        type="int",
        default=-1,
        help=
        "Specifies which image will be used as an initial seed to form the chain. (default = -1, means the program selects the best seed)"
    )
    parser.add_option(
        "--radius",
        type="int",
        default=-1,
        help="Radius of a circular mask for similarity based ordering")
    #  import params for 2D alignment
    parser.add_option(
        "--ou",
        type="int",
        default=-1,
        help=
        "outer radius for 2D alignment < nx/2-1 (set to the radius of the particle)"
    )
    parser.add_option(
        "--xr",
        type="int",
        default=0,
        help="range for translation search in x direction, search is +/xr (0)")
    parser.add_option(
        "--yr",
        type="int",
        default=0,
        help="range for translation search in y direction, search is +/yr (0)")
    parser.add_option("--pairwiseccc",
                      type="string",
                      default=" ",
                      help="Input/output pairwise ccc file")

    (options, args) = parser.parse_args()

    sp_global_def.BATCH = True

    if options.dd:
        _chains_dd(options, args)
    elif options.align:
        _chains_align(options, args)
    else:
        _chains_linear(options, args)


def _chains_transform_to_previous(lccc, cur, prev):
    """Return the Transform taking image ``cur`` onto image ``prev``.

    ``lccc[mono(a, b)][1]`` holds the transform from the smaller index to
    the larger one, so it is inverted when ``cur > prev``.
    """
    from sp_statistics import mono

    T = lccc[mono(cur, prev)][1]
    if cur > prev:
        T = T.inverse()
    return T


def _chains_accumulate_transforms(links):
    """Compose chain link transforms into per-position transforms.

    ``links[i]`` maps chain element ``i`` onto element ``i - 1``.  Position
    0 keeps ``links[0]``; position ``i >= 1`` receives the product
    ``links[1] * links[2] * ... * links[i]``.  ``links`` itself is not
    modified (a deep copy is composed).
    """
    from copy import deepcopy

    count = len(links)
    trans = deepcopy(links)
    for k in range(count - 2, 0, -1):
        T = links[k]
        for i in range(k + 1, count):
            trans[i] = T * trans[i]
    return trans


def _chains_dd(options, args):
    """``--dd`` mode: circular ordering by pairwise ccc under a mask.

    ``args`` must be ``[input_stack, output_stack, mask_file]``.  Images are
    read one at a time with ``get_im``; no orientation is adjusted.
    """
    if len(args) != 3:
        ERROR("Must provide name of input and two output files!")
        return

    from sp_statistics import ccc
    from sp_statistics import mono

    stack = args[0]
    new_stack = args[1]
    lend = EMUtil.get_image_count(stack)
    lccc = [None] * (old_div(lend * (lend - 1), 2))

    # The argument-count check guarantees the third argument, so the mask is
    # always read from it; read it once instead of once per image.
    mask = get_im(args[2])
    for i in range(lend - 1):
        v1 = get_im(stack, i)
        for j in range(i + 1, lend):
            lccc[mono(i, j)] = [ccc(v1, get_im(stack, j), mask), 0]

    order = tsp(lccc)
    if len(order) != lend:
        ERROR("Problem with data length")
        return

    sxprint("Total sum of cccs :", TotalDistance(order, lccc))
    sxprint("ordering :", order)
    for i in range(lend):
        get_im(stack, order[i]).write_image(new_stack, i)


def _chains_align(options, args):
    """``--align`` mode: order images using all pairwise 2-D alignments.

    ``args`` must be ``[input_stack, chain_stack, circle_stack]``.  Writes the
    best greedy chain to ``chain_stack``, the ``tsp`` ordering (rotated to
    start at the selected seed) to ``circle_stack`` and the applied
    parameters to ``chains_params.txt`` next to ``circle_stack``.  Only the
    rotation and mirror of the accumulated transforms are applied; shifts
    are written as 0.
    """
    if len(args) != 3:
        ERROR("Must provide name of input and two output files!")
        return

    from sp_utilities import model_circle, read_text_row
    from sp_fundamentals import rot_shift2D
    from sp_statistics import ccc, mono
    from sp_alignment import align2d_scf

    stack = args[0]
    new_stack = args[1]

    d = EMData.read_images(stack)
    if len(d) < 6:
        ERROR(
            "Chains requires at least six images in the input stack to be executed"
        )
        return
    # Every pair is aligned below, so 2-D alignment parameters stored in the
    # headers are deliberately not applied beforehand.

    nx = d[0].get_xsize()
    ny = d[0].get_ysize()
    radius = old_div(nx, 2) - 2 if options.ou < 1 else options.ou
    mask = model_circle(radius, nx, ny)

    xrng = 0 if options.xr < 0 else options.xr
    yrng = xrng if options.yr < 0 else options.yr

    initial = max(options.initial, 0)

    lend = len(d)
    lccc = [None] * (old_div(lend * (lend - 1), 2))
    pairwise_file = options.pairwiseccc

    if pairwise_file == " " or not os.path.exists(pairwise_file):
        for i in range(lend - 1):
            for j in range(i + 1, lend):
                # j > i: entry mono(i, j) holds the transform i -> j
                # (from the smaller index to the larger one).
                alpha, sx, sy, mir, _ = align2d_scf(d[i],
                                                    d[j],
                                                    xrng,
                                                    yrng,
                                                    ou=radius)
                aligned = rot_shift2D(d[i], alpha, sx, sy, mir, 1.0)
                lccc[mono(i, j)] = [
                    ccc(d[j], aligned, mask), alpha, sx, sy, mir
                ]
        if pairwise_file != " " and not os.path.exists(pairwise_file):
            # First row stores the seed index; the remaining rows are lccc.
            write_text_row([[initial, 0, 0, 0, 0]] + lccc, pairwise_file)
    else:
        lccc = read_text_row(pairwise_file)
        initial = int(lccc[0][0] + 0.1)
        del lccc[0]

    # Replace the numeric parameters by Transform objects: [ccc, T(i -> j)].
    for idx, row in enumerate(lccc):
        lccc[idx] = [
            row[0],
            Transform({
                "type": "2D",
                "alpha": row[1],
                "tx": row[2],
                "ty": row[3],
                "mirror": int(row[4] + 0.1)
            })
        ]

    tdummy = Transform({"type": "2D"})

    # Try every image as seed and keep the greedy chain with the largest
    # summed ccc.  Starting from -inf (instead of a magic constant) ensures a
    # chain is always selected.
    maxsum = float("-inf")
    init = None
    snake = None
    for m in range(lend):
        indc = list(range(lend))
        del indc[m]
        lsnake = [[m, tdummy, 0.0]]
        lsum = 0.0
        while len(indc) > 1:
            prev = lsnake[-1][0]
            maxcit = -111.
            for cand in indc:
                cuc = lccc[mono(cand, prev)][0]
                if cuc > maxcit:
                    maxcit = cuc
                    qi = cand
            lsnake.append(
                [qi, _chains_transform_to_previous(lccc, qi, prev), maxcit])
            lsum += maxcit
            indc.remove(qi)

        last = indc[-1]
        prev = lsnake[-1][0]
        lsnake.append([
            last,
            _chains_transform_to_previous(lccc, last, prev),
            lccc[mono(last, prev)][0]
        ])
        sxprint(" initial image and lsum  ", m, lsum)
        if lsum > maxsum:
            maxsum = lsum
            init = m
            snake = lsnake

    sxprint("  Initial image selected : ", init, maxsum, "    ",
            TotalDistance([link[0] for link in snake], lccc))

    trans = _chains_accumulate_transforms([link[1] for link in snake])
    sxprint([link[0] for link in snake])
    # TODO: apply all transformations (including shifts) and do the overall
    # centering; currently only rotation and mirror are applied.
    for m in range(lend):
        prms = trans[m].get_params("2D")
        rot_shift2D(d[snake[m][0]], prms["alpha"], 0.0, 0.0,
                    prms["mirror"]).write_image(new_stack, m)

    order = tsp(lccc)
    if len(order) != lend:
        ERROR("Problem with data length")
        return

    sxprint(TotalDistance(order, lccc))
    sxprint(order)
    # Rotate the circular order so that it starts at the selected seed.
    ibeg = order.index(init)
    order = [order[(i + ibeg) % lend] for i in range(lend)]
    sxprint(TotalDistance(order, lccc))
    sxprint(order)

    links = [tdummy] + [
        _chains_transform_to_previous(lccc, order[i], order[i - 1])
        for i in range(1, lend)
    ]
    trans = _chains_accumulate_transforms(links)
    # TODO: smoothing the angles along the chain (forward and backward
    # averaging) and recentering were considered but are not implemented.

    best_params = []
    for m in range(lend):
        prms = trans[m].get_params("2D")
        rot_shift2D(d[order[m]], prms["alpha"], 0.0, 0.0,
                    prms["mirror"]).write_image(args[2], m)
        best_params.append(
            [m, order[m], prms["alpha"], 0.0, 0.0, prms["mirror"]])

    # Write alignment parameters next to the circular output stack.
    outdir = os.path.dirname(args[2])
    aligndoc = os.path.join(outdir, "chains_params.txt")
    write_text_row(best_params, aligndoc)


def _chains_best_circular(d, mask):
    """Find the best circular greedy chain over all seed images.

    The chain is filled alternately from the front and from the back, each
    time picking the unused image with the highest ccc (under ``mask``) to
    the current end image.  Returns ``(seed, summed_ccc, order)`` where
    ``order`` lists ``len(d)`` image indices.
    """
    from sp_statistics import ccc

    count = len(d)
    maxsum = float("-inf")
    init = None
    snake = None
    for m in range(count):
        indc = list(range(count))
        del indc[m]
        # Both ends start at the seed; -1 marks an unfilled slot.
        lsnake = [-1] * (count + 1)
        lsnake[0] = m
        lsnake[-1] = m
        temp = d[m].copy()
        lsum = 0.0
        direction = +1
        k = 1
        while len(indc) > 1:
            maxcit = -111.
            for cand in indc:
                cuc = ccc(d[cand], temp, mask)
                if cuc > maxcit:
                    maxcit = cuc
                    qi = cand
            lsnake[k] = qi
            lsum += maxcit
            indc.remove(qi)
            direction = -direction
            # Move to the first free slot from the front or the back, and
            # compare against the image adjacent to it.
            for i in range(1, count):
                if direction > 0:
                    if lsnake[i] == -1:
                        temp = d[lsnake[i - 1]].copy()
                        k = i
                        break
                elif lsnake[count - i] == -1:
                    temp = d[lsnake[count - i + 1]].copy()
                    k = count - i
                    break

        lsnake[lsnake.index(-1)] = indc[-1]
        if lsum > maxsum:
            maxsum = lsum
            init = m
            snake = lsnake[:count]
    return init, maxsum, snake


def _chains_best_straight(d, mask):
    """Find the best straight greedy chain over all seed images.

    From each seed, repeatedly append the unused image with the highest ccc
    (under ``mask``) to the last chain member.  Returns
    ``(seed, summed_ccc, order)`` where ``order`` lists ``len(d)`` indices.
    """
    from sp_statistics import ccc

    count = len(d)
    maxsum = float("-inf")
    init = None
    snake = None
    for m in range(count):
        indc = list(range(count))
        del indc[m]
        lsnake = [m]
        temp = d[m].copy()
        lsum = 0.0
        while len(indc) > 1:
            maxcit = -111.
            for cand in indc:
                cuc = ccc(d[cand], temp, mask)
                if cuc > maxcit:
                    maxcit = cuc
                    qi = cand
            lsnake.append(qi)
            lsum += maxcit
            temp = d[qi].copy()
            indc.remove(qi)

        lsnake.append(indc[-1])
        if lsum > maxsum:
            maxsum = lsum
            init = m
            snake = lsnake[:count]
    return init, maxsum, snake


def _chains_linear(options, args):
    """Default mode: greedy chain of (pre-aligned) images.

    ``args`` must be ``[input_stack, output_stack]``.  When the first image
    carries ``xform.params2d``, the header parameters of every image are
    applied first.  With ``--initial >= 0`` the chain grows from that image;
    otherwise the best seed is searched (straight or ``--circular``).
    """
    if len(args) != 2:
        ERROR("Must provide name of input and output file!")
        return

    from sp_utilities import get_params2D, model_circle
    from sp_fundamentals import rot_shift2D
    from sp_statistics import ccc

    stack = args[0]
    new_stack = args[1]

    d = EMData.read_images(stack)
    try:
        d[0].get_attr('xform.params2d')
    except Exception:
        # No 2-D alignment parameters in the header: use the images as read.
        has_params = False
    else:
        has_params = True
    if has_params:
        sxprint("Using 2D alignment parameters from header.")
        # If only some images carry parameters, get_params2D raises here
        # instead of leaving the stack silently half-aligned.
        for i in range(len(d)):
            alpha, sx, sy, mirror, scale = get_params2D(d[i])
            d[i] = rot_shift2D(d[i], alpha, sx, sy, mirror)

    nx = d[0].get_xsize()
    ny = d[0].get_ysize()
    radius = old_div(nx, 2) - 2 if options.radius < 1 else options.radius
    mask = model_circle(radius, nx, ny)

    init = options.initial
    if init > -1:
        sxprint("      initial image: %d" % init)
        temp = d[init].copy()
        temp.write_image(new_stack, 0)
        del d[init]
        k = 1
        lsum = 0.0
        while len(d) > 1:
            maxcit = -111.
            for i, img in enumerate(d):
                cuc = ccc(img, temp, mask)
                if cuc > maxcit:
                    maxcit = cuc
                    qi = i
            lsum += maxcit
            temp = d[qi].copy()
            del d[qi]
            temp.write_image(new_stack, k)
            k += 1
        sxprint(lsum)
        d[0].write_image(new_stack, k)
        return

    if options.circular:
        sxprint("Using options.circular, no alignment")
        init, maxsum, snake = _chains_best_circular(d, mask)
    else:
        sxprint("Straight chain, no alignment")
        init, maxsum, snake = _chains_best_straight(d, mask)
    sxprint("  Initial image selected : ", init, maxsum)
    # Print the selected ordering (the original printed the last candidate).
    sxprint(snake)
    for m in range(len(d)):
        d[snake[m]].write_image(new_stack, m)