Exemple #1
0
def multi_align_stability(ali_params,
                          mir_stab_thld=0.0,
                          grp_err_thld=10000.0,
                          err_thld=1.732,
                          print_individual=False,
                          d=64):
    """Find the particles whose 2D alignment is stable across several runs.

    The analysis proceeds in stages:

    1. Keep only particles whose mirror flag is consistent over all runs
       (as decided by ``sp_statistics.k_means_stab_bbenum``).
    2. Estimate, for every run, the transformation relating it to the last
       run (``align_diff_params``) and compute the per-particle squared
       pixel error; drop particles whose error exceeds ``3 * err_thld``.
    3. Refine the run-to-run transformations with
       ``EMAN2_cppwrap.Util.multi_align_error`` and keep the particles whose
       final pixel error is below ``err_thld``.

    Args:
        ali_params: sequence of alignment runs.  Each run is a flat sequence
            holding four values per particle, read here as
            ``[angle_in_degrees, x_shift, y_shift, mirror]``.
        mir_stab_thld: the function gives up when the fraction of mirror
            stable particles is not strictly greater than this value.
        grp_err_thld: the function gives up when the optimised group pixel
            error exceeds this value.
        err_thld: per-particle pixel error threshold.
        print_individual: when true, report the status of every particle
            through ``sp_global_def.sxprint``.
        d: length scale used to convert angular spread into a pixel error
            (presumably the particle diameter in pixels - TODO confirm).

    Returns:
        A tuple ``(stable_set, mir_stab_rate, pixel_error)``.
        ``stable_set`` is a list of ``[error, particle_index,
        [alpha, tx, ty, mirror]]`` entries (empty when the group is
        rejected).  ``pixel_error`` is ``-1.0`` when the mirror stability
        test fails, the optimised group error when ``grp_err_thld`` is
        exceeded, and otherwise the average pixel error measured before
        outliers were pruned.
    """

    def sqerr(a):
        """Return the population variance of the values in ``a``."""
        n = len(a)
        total = sum(a)
        sum_sq = 0.0
        for value in a:
            sum_sq += value**2
        return (sum_sq - total * total / n) / n

    def func(args, data, return_avg_pixel_error=True):
        """Evaluate the squared pixel error for a set of run transformations.

        ``args`` holds ``(alpha, sx, sy)`` for every run except the last one
        (the last run is the reference and keeps the identity transform);
        ``data`` is ``[alignment_parameters, d]``.

        The values returned are SQUARED pixel errors; callers must take the
        square root themselves.  Returns the mean squared error when
        ``return_avg_pixel_error`` is true, otherwise the tuple
        ``(squared_errors_per_particle, average_parameters_per_particle)``.
        """
        ali_params = data[0]
        d = data[1]
        L = len(ali_params)
        # Floor division: the count is used as a list size and range bound,
        # which must be an int under Python 3.
        N = len(ali_params[0]) // 4

        # Three parameters per run; the last run keeps (0, 0, 0).
        args_list = [0.0] * (L * 3)
        for i in range(L * 3 - 3):
            args_list[i] = args[i]

        cosa = [0.0] * L
        sina = [0.0] * L
        for j in range(L):
            run_angle = numpy.radians(args_list[j * 3])
            cosa[j] = numpy.cos(run_angle)
            sina[j] = numpy.sin(run_angle)

        sqr_pixel_error = [0.0] * N
        ave_params = []
        for i in range(N):
            base = i * 4
            sum_cosa = 0.0
            sum_sina = 0.0
            sx = [0.0] * L
            sy = [0.0] * L
            for j in range(L):
                run_alpha = args_list[j * 3]
                run_tx = args_list[j * 3 + 1]
                run_ty = args_list[j * 3 + 2]
                p_alpha = ali_params[j][base]
                p_tx = ali_params[j][base + 1]
                p_ty = ali_params[j][base + 2]
                if int(ali_params[j][base + 3]) == 0:
                    angle = numpy.radians(run_alpha + p_alpha)
                    sx[j] = run_tx + p_tx * cosa[j] + p_ty * sina[j]
                    sy[j] = run_ty - p_tx * sina[j] + p_ty * cosa[j]
                else:
                    # Mirrored particle: rotation and x shift change sign.
                    angle = numpy.radians(-run_alpha + p_alpha)
                    sx[j] = -run_tx + p_tx * cosa[j] - p_ty * sina[j]
                    sy[j] = run_ty + p_tx * sina[j] + p_ty * cosa[j]
                sum_cosa += numpy.cos(angle)
                sum_sina += numpy.sin(angle)

            # Length of the resultant of the L unit vectors; equals L when
            # all runs agree on the angle.
            sqrtP = numpy.sqrt(sum_cosa**2 + sum_sina**2)
            sqr_pixel_error[i] = max(
                0.0, d * d / 4. * (1 - sqrtP / L) + sqerr(sx) + sqerr(sy))

            # Average transformation over all runs.
            H = EMAN2_cppwrap.Transform({"type": "2D"})
            H.set_matrix([
                sum_cosa / sqrtP, sum_sina / sqrtP, 0.0, sum(sx) / L,
                -sum_sina / sqrtP, sum_cosa / sqrtP, 0.0, sum(sy) / L,
                0.0, 0.0, 1.0, 0.0
            ])
            dd = H.get_params("2D")
            # The mirror flag is taken from the LAST run.
            H = EMAN2_cppwrap.Transform({
                "type": "2D",
                "alpha": dd["alpha"],
                "tx": dd["tx"],
                "ty": dd["ty"],
                "mirror": int(ali_params[L - 1][base + 3]),
                "scale": 1.0
            })
            dd = H.get_params("2D")
            ave_params.append([dd["alpha"], dd["tx"], dd["ty"], dd["mirror"]])

        if return_avg_pixel_error:
            return sum(sqr_pixel_error) / N
        return sqr_pixel_error, ave_params

    # ------------------------------------------------------------------
    # Stage 1: find the subset which is mirror stable over all runs.
    # ------------------------------------------------------------------
    num_ali = len(ali_params)
    # Floor division: nima is used as a range bound (int required).
    nima = len(ali_params[0]) // 4
    all_part = []
    for i in range(num_ali):
        mirror0 = []
        mirror1 = []
        for j in range(nima):
            if ali_params[i][j * 4 + 3] == 0:
                mirror0.append(j)
            else:
                mirror1.append(j)
        all_part.append(
            [numpy.array(mirror0, 'int32'),
             numpy.array(mirror1, 'int32')])
    match, stab_part, CT_s, CT_t, ST, st = sp_statistics.k_means_stab_bbenum(
        all_part, T=0, nguesses=1)
    mir_stab_part = stab_part[0] + stab_part[1]

    mir_stab_rate = len(mir_stab_part) / float(nima)
    if mir_stab_rate <= mir_stab_thld:
        return [], mir_stab_rate, -1.0
    mir_stab_part.sort()

    # Release the intermediate matching results.
    del all_part, match, stab_part, CT_s, CT_t, ST, st

    nima2 = len(mir_stab_part)

    # Keep the alignment parameters of mirror stable particles only.
    ali_params_mir_stab = [[] for _ in range(num_ali)]
    for i in range(num_ali):
        for j in mir_stab_part:
            ali_params_mir_stab[i].extend(ali_params[i][j * 4:j * 4 + 4])

    # Initial guess: transformation of every run against the last one.
    args = []
    for i in range(num_ali - 1):
        alpha, sx, sy, _ = align_diff_params(
            ali_params_mir_stab[i], ali_params_mir_stab[num_ali - 1])
        args.extend([alpha, sx, sy])

    # ------------------------------------------------------------------
    # Stage 2: purge outliers, i.e. particles whose pixel error is larger
    # than three times the threshold.
    # ------------------------------------------------------------------
    data = [ali_params_mir_stab, d]
    pixel_error_before, ave_params = func(numpy.array(args),
                                          data,
                                          return_avg_pixel_error=False)
    # mir_stab_rate and the pixel error before cleaning are returned even
    # if the group does not survive, so that the caller (ISAC) can report
    # overall statistics on the data quality.  PAP 01/25/2015
    ali_params_cleaned = [[] for _ in range(num_ali)]
    cleaned_part = []
    for j in range(nima2):
        # Clamp so that sqrt is never applied to a negative value.
        pixel_error_before[j] = max(0.0, pixel_error_before[j])
        if numpy.sqrt(pixel_error_before[j]) > 3 * err_thld:
            continue
        cleaned_part.append(mir_stab_part[j])
        for i in range(num_ali):
            ali_params_cleaned[i].extend(
                ali_params_mir_stab[i][j * 4:j * 4 + 4])
    nima3 = len(cleaned_part)
    prever = numpy.sqrt(sum(pixel_error_before) / nima2)
    if nima3 <= 1:
        return [], mir_stab_rate, prever

    # ------------------------------------------------------------------
    # Stage 3: minimise the sum of pixel errors (C++ implementation, used
    # instead of scipy's L-BFGS-B to avoid the extra dependency).
    # ------------------------------------------------------------------
    data = [ali_params_cleaned, d]
    ali_params_cleaned_list = []
    for params in ali_params_cleaned:
        ali_params_cleaned_list.extend(params)
    results = EMAN2_cppwrap.Util.multi_align_error(args,
                                                   ali_params_cleaned_list, d)
    ps_lp = results[:-1]

    # A slightly negative value (about 1e-13) can occur in rare cases due
    # to rounding errors, hence the clamp.
    val = max(0.0, results[-1])

    del ali_params_cleaned_list

    if numpy.sqrt(val) > grp_err_thld:
        return [], mir_stab_rate, numpy.sqrt(val)

    pixel_error_after, ave_params = func(ps_lp,
                                         data,
                                         return_avg_pixel_error=False)

    # Map particle id -> position in cleaned_part (first occurrence, as
    # list.index would give) to avoid a linear scan for every particle.
    cleaned_index = {}
    for pos, particle in enumerate(cleaned_part):
        cleaned_index.setdefault(particle, pos)

    stable_set = []
    for i in range(nima):
        j = cleaned_index.get(i)
        if j is None:
            if print_individual:
                sp_global_def.sxprint("Particle %4d :  Mirror unstable" % i)
            continue
        err = numpy.sqrt(pixel_error_after[j])
        if err < err_thld:
            stable_set.append([err, i, ave_params[j]])
            if print_individual:
                sp_global_def.sxprint(
                    "Particle %4d :  pixel error = %18.4f" % (i, err))
        elif print_individual:
            sp_global_def.sxprint(
                "Particle %4d :  pixel error = %18.4f  unstable" % (i, err))

    # Return the average pixel error before pruning as it is more
    # informative than the error of the surviving particles.
    return stable_set, mir_stab_rate, prever
Exemple #2
0
def main():
    """Match the groups stored in several stacks of class averages.

    Every positional argument is an image stack readable by
    ``EMData.read_images``; each image must carry a ``members`` attribute
    listing the particle ids assigned to that average.  The partitions are
    matched with ``k_means_stab_bbenum`` and a summary of the matched groups
    is printed with ``sxprint``.
    """
    from time import time

    progname = os.path.basename(sys.argv[0])
    # List only options that are actually defined below.
    usage = (progname + " averages1 averages2 [--T=threshold] [--J=J]"
             " [--max_branching=N] [--verbose] [--timing]")
    parser = OptionParser(usage, version=SPARXVERSION)
    parser.add_option("--T",
                      type="int",
                      default=0,
                      help=" Threshold for matching")
    parser.add_option("--J", type="int", default=50, help=" J")
    parser.add_option("--max_branching",
                      type="int",
                      default=40,
                      help=" maximum branching")
    parser.add_option("--verbose",
                      action="store_true",
                      default=False,
                      help=" Print the detailed matching results")
    parser.add_option("--timing",
                      action="store_true",
                      default=False,
                      help=" Get the timing")

    (options, args) = parser.parse_args()

    # Without input stacks there is nothing to compare.
    if not args:
        parser.error("no input averages given")

    if sp_global_def.CACHE_DISABLE:
        from sp_utilities import disable_bdb_cache
        disable_bdb_cache()

    sp_global_def.BATCH = True
    try:
        from numpy import array
        from sp_statistics import k_means_stab_bbenum

        R = len(args)
        Parts = []  # Parts[r][k]: sorted member ids of average k in stack r
        mem = [0] * R  # total number of members per stack
        avg = [0] * R  # number of averages per stack
        for r in range(R):
            data = EMData.read_images(args[r])
            avg[r] = len(data)

            part = []
            for image in data:
                lid = image.get_attr('members')
                mem[r] += len(lid)
                lid = array(lid, 'int32')
                lid.sort()
                part.append(lid.copy())
            Parts.append(part)

        time1 = time() if options.timing else None

        MATCH, STB_PART, CT_s, CT_t, ST, st = k_means_stab_bbenum(
            Parts,
            T=options.T,
            J=options.J,
            max_branching=options.max_branching,
            stmult=0.1,
            branchfunc=2)

        if options.verbose:
            sxprint(MATCH)
            sxprint(STB_PART)
            sxprint(CT_s)
            sxprint(CT_t)
            sxprint(ST)
            sxprint(st)
            sxprint(" ")

        for matched in MATCH:
            u = matched[0]  # u is the group in question in partition 1
            # Internal consistency check of the matcher's output.
            assert len(STB_PART[u]) == CT_s[u]
            sxprint("Group ", end=' ')
            for r in range(R):
                sxprint("%3d " % (matched[r]), end=' ')
            sxprint(" matches:   group size = ", end=' ')
            for r in range(R):
                sxprint(" %3d" % len(Parts[r][matched[r]]), end=' ')
            sxprint("     matched size = %3d" % (CT_s[u]), end=' ')
            if options.verbose:
                sxprint("   matched group = %s" % (STB_PART[u]))
            else:
                sxprint("")

        sxprint("\nNumber of averages = ", end=' ')
        for r in range(R):
            sxprint("%3d" % (avg[r]), end=' ')
        sxprint("\nTotal number of particles = ", end=' ')
        for r in range(R):
            sxprint("%3d" % (mem[r]), end=' ')
        sxprint("     number of matched particles = %5d" % (sum(CT_s)))

        if options.timing:
            sxprint("Elapsed time = ", time() - time1)
    finally:
        # Always leave batch mode, even when the matching fails.
        sp_global_def.BATCH = False