예제 #1
0
    def evaluateError(cls, w, coeff_field, pde, f, zeta, gamma, ceta, cQ, newmi_add_maxm, maxh=0.1, quadrature_degree= -1, projection_degree_increase=1, refine_projection_mesh=1):
        """Evaluate EGSZ Error (7.5).

        Runs three estimators, each timed:
          * ``ResidualEstimator.evaluateResidualEstimator`` -> (resind, reserror)
          * ``ResidualEstimator.evaluateProjectionError``   -> (projind, projerror)
          * ``ResidualEstimator.evaluateInactiveMIProjectionError`` -> mierror

        With eta = sum of squared residual errors, delta = sum of squared
        projection errors and delta_inactive = sum of the squared second
        entries of ``mierror``, the estimator is

            xi = (A1*sqrt(eta) + A2*sqrt(delta + delta_inactive) + A3)**2 + A4

        where A1 = ceta/sqrt(1-gamma), A2 = cQ/sqrt(1-gamma),
        A3 = cQ*sqrt(zeta/(1-gamma)) and A4 = zeta/(1-gamma).

        Returns
        ``(xi, resind, projind, mierror, (est1, est2, est3, est4), (eta, delta, zeta), timing_stats)``.
        The returned ``delta`` does not include the inactive-multiindex part.
        ``timing_stats`` maps "TIME-RESIDUAL", "TIME-PROJECTION" and
        "TIME-INACTIVE-MI" to the values ``timing`` hands to its store callback.
        """
        logger.debug("starting evaluateError")

        timing_stats = {}

        def _recorder(key):
            # builds the store callback for ``timing``: files the value under ``key``
            def _store(val):
                timing_stats[key] = val
            return _store

        with timing(msg="ResidualEstimator.evaluateResidualEstimator", logfunc=logger.info, store_func=_recorder("TIME-RESIDUAL")):
            resind, reserror = ResidualEstimator.evaluateResidualEstimator(w, coeff_field, pde, f, quadrature_degree)

        logger.debug("starting evaluateProjectionEstimator")
        with timing(msg="ResidualEstimator.evaluateProjectionError", logfunc=logger.info, store_func=_recorder("TIME-PROJECTION")):
            projind, projerror = ResidualEstimator.evaluateProjectionError(w, coeff_field, pde, maxh, True, projection_degree_increase, refine_projection_mesh)

        logger.debug("starting evaluateInactiveProjectionError")
        with timing(msg="ResidualEstimator.evaluateInactiveMIProjectionError", logfunc=logger.info, store_func=_recorder("TIME-INACTIVE-MI")):
            mierror = ResidualEstimator.evaluateInactiveMIProjectionError(w, coeff_field, pde, maxh, newmi_add_maxm)

        # squared error contributions
        eta = sum(reserror[key] ** 2 for key in reserror)
        delta = sum(projerror[key] ** 2 for key in projerror)
        delta_inactive_mi = sum(entry[1] ** 2 for entry in mierror)

        # constant factors A1..A4 of (7.5)
        root = sqrt(1 - gamma)
        factor1 = ceta / root
        factor2 = cQ / root
        factor3 = cQ * sqrt(zeta / (1 - gamma))
        factor4 = zeta / (1 - gamma)

        est1 = factor1 * sqrt(eta)
        est2 = factor2 * sqrt(delta + delta_inactive_mi)
        est3 = factor3
        est4 = factor4
        xi = (est1 + est2 + est3) ** 2 + est4

        logger.info("Total Residual ERROR Factors: A1=%s  A2=%s  A3=%s  A4=%s", factor1, factor2, factor3, factor4)
        return (xi, resind, projind, mierror, (est1, est2, est3, est4), (eta, delta, zeta), timing_stats)
예제 #2
0
파일: sampling.py 프로젝트: SpuqTeam/spuq
def get_projected_solution(w, mu, proj_basis):
    """Return the component ``w[mu]`` projected onto ``proj_basis``.

    The projection is delegated to ``w.project`` and timed. The timing
    message contains ``mu`` and the dimension of ``w[mu]`` and is sent to
    ``logger.debug``.
    """
    # TODO: the projection is reached through the MultiVector ``w`` although
    # ``project`` is not an obvious part of the MultiVector interface; this
    # should be separated more clearly.
    label = "get_projected_solution (%s --- %i)" % (str(mu), w[mu].dim)
    with timing(msg=label, logfunc=logger.debug):
        projected = w.project(w[mu], proj_basis)
    return projected
예제 #3
0
파일: sampling.py 프로젝트: SpuqTeam/spuq
def compute_direct_sample_solution(pde, RV_samples, coeff_field, A, maxm, proj_basis, cache=None):
    """Solve the PDE directly for one realisation of the random variables.

    Builds ``A_0 + sum_{m < maxm} RV_samples[m] * A_m`` on ``proj_basis``.
    A_0 and the right-hand side b are assembled with
    ``coeff_field.mean_func``; A_m is assembled with ``coeff_field[m][0]``.
    Dirichlet BCs are then applied, ``A X = b`` is solved, and X is
    returned as a ``FEniCSVector``.

    Parameters:
        pde: provides ``assemble_lhs``, ``assemble_rhs`` and ``apply_dirichlet_bc``.
        RV_samples: indexable sample values; the first ``maxm`` are used.
        coeff_field: coefficient field (``mean_func`` and ``[m][0]`` are used).
        A: not used; the system matrix is rebuilt here. The parameter is kept
            for interface compatibility.
        maxm: number of expansion terms added to A_0.
        proj_basis: basis for assembly; ``proj_basis._fefs`` is the function space.
        cache: optional object. A_0, b and the list of A_m are stored on it as
            attributes ``A``, ``b`` and ``A_m`` and reused on later calls.
    """
    try:
        A0 = cache.A
        A_m = cache.A_m
        b = cache.b
        logger.debug("compute_direct_sample_solution: CACHE USED")
    except AttributeError:
        # no cache (None has no attribute ``A``) or cache not filled yet
        with timing(msg="direct_sample_sol: compute A_0, b", logfunc=logger.info):
            a = coeff_field.mean_func
            A0 = pde.assemble_lhs(basis=proj_basis, coeff=a, withDirichletBC=False)
            b = pde.assemble_rhs(basis=proj_basis, coeff=a, withDirichletBC=False)
            A_m = [None] * maxm
            logger.debug("compute_direct_sample_solution: CACHE NOT USED")
        if cache is not None:
            cache.A = A0
            cache.A_m = A_m
            cache.b = b

    # A cached A_m list may come from a call with a smaller maxm (e.g. maxm=0).
    # Grow it in place so that the cache sees the new slots, instead of
    # failing with an IndexError below.
    if len(A_m) < maxm:
        A_m.extend([None] * (maxm - len(A_m)))

    with timing(msg="direct_sample_sol: compute A_m", logfunc=logger.info):
        # copy so that the cached A_0 is not modified by the in-place sum
        A = A0.copy()
        for m in range(maxm):
            if A_m[m] is None:
                # assemble lazily; the list may be shared with the cache
                a_m = coeff_field[m][0]
                A_m[m] = pde.assemble_lhs(basis=proj_basis, coeff=a_m, withDirichletBC=False)
            A += RV_samples[m] * A_m[m]

    # NOTE(review): ``b`` may be the cached vector. If apply_dirichlet_bc
    # modifies it in place, the cached rhs changes on every call; this is
    # harmless only if applying the BC is idempotent -- confirm.
    with timing(msg="direct_sample_sol: apply BCs", logfunc=logger.info):
        A, b = pde.apply_dirichlet_bc(proj_basis._fefs, A, b)

    with timing(msg="direct_sample_sol: solve linear system", logfunc=logger.info):
        X = 0 * b
        logger.info("compute_direct_sample_solution with %i dofs", b.size())
        solve(A, X, b)
    return FEniCSVector(Function(proj_basis._fefs, X))
예제 #4
0
파일: sampling.py 프로젝트: SpuqTeam/spuq
def compute_parametric_sample_solution(RV_samples, coeff_field, w, proj_basis, cache=None):
    """Evaluate the parametric solution ``w`` at one sample of the random variables.

    ``coeff_field.sample_realization`` maps every active multiindex of ``w``
    to a scalar weight for ``RV_samples``. The result is the sum of the
    components of ``w`` projected onto ``proj_basis``, each scaled by its
    weight.

    If ``cache`` is given, the projected components are kept in the dict
    ``cache.projected_sol`` (keyed by multiindex) and reused. Multiindices
    missing from that dict are projected and added, so the cache stays valid
    when later calls see additional active indices.
    """
    with timing(msg="parametric_sample_sol", logfunc=logger.info):
        Lambda = w.active_indices()
        sample_map, _ = coeff_field.sample_realization(Lambda, RV_samples)
        # sum up (stochastic) solution vector on reference function space wrt samples

        if cache is None:
            sample_sol = sum(get_projected_solution(w, mu, proj_basis) * sample_map[mu] for mu in Lambda)
        else:
            projected_sol = getattr(cache, "projected_sol", None)
            if projected_sol is None:
                projected_sol = {}
                cache.projected_sol = projected_sol
            # project only what is not cached yet (previously a KeyError)
            for mu in Lambda:
                if mu not in projected_sol:
                    projected_sol[mu] = get_projected_solution(w, mu, proj_basis)
            sample_sol = sum(projected_sol[mu] * sample_map[mu] for mu in Lambda)
    return sample_sol
예제 #5
0
def AdaptiveSolver(A, coeff_field, pde,
                    mis, w0, mesh0, degree,
                    gamma=0.9,
                    cQ=1.0,
                    ceta=6.0,
                    # marking parameters
                    theta_eta=0.4, # residual marking bulk parameter
                    theta_zeta=0.1, # projection marking threshold factor
                    min_zeta=1e-8, # minimal projection error to be considered
                    maxh=0.1, # maximal mesh width for projection maximum norm evaluation
                    newmi_add_maxm=20, # maximal search length for new multiindices (to be added to max order of solution w)
                    theta_delta=10.0, # number new multiindex activation bound
                    max_Lambda_frac=0.1, # max fraction of |Lambda| for new multiindices (``1 / 10`` was 0 under Python 2 integer division)
                    marking_strategy="SEPARATE with CELLPROJECTION", # separate (as initially in EGSZ) or relative marking wrt overall error, projection refinement based on cell or mesh errors
                    # residual error
                    quadrature_degree=-1,
                    # projection error
                    projection_degree_increase=1,
                    refine_projection_mesh=1,
                    # pcg solver
                    pcg_eps=1e-6,
                    pcg_maxiter=100,
                    # adaptive algorithm threshold
                    error_eps=1e-2,
                    # refinements
                    max_refinements=5,
                    max_dof=1e10,
                    do_refinement=None, # dict with bool entries "RES", "PROJ", "MI"; None means {"RES": True, "PROJ": True, "MI": False}
                    do_uniform_refinement=False,
                    w_history=None,
                    sim_stats=None):
    """Run the adaptive EGSZ loop (residual, projection and multiindex refinement).

    Starts from ``w0``. If ``sim_stats`` already holds entries, the loop
    resumes at iteration ``len(sim_stats) - 1``. Each iteration
    ``refinement = start_iteration, ..., max_refinements`` does the following:

    1. Solves with ``pcg_solve``, which returns the new ``w`` and ``zeta``
       (used as the pcg error term).
    2. Evaluates ``xi`` and the error indicators with
       ``ResidualEstimator.evaluateError`` and records a statistics dict.
       The dict is appended to ``sim_stats`` unless the iteration is the one
       being resumed.
    3. Stops if the recorded "DOFS" reach ``max_dof`` or ``xi <= error_eps``.
    4. Otherwise, for ``refinement < max_refinements``, marks cells and new
       multiindices. It uses ``Marking.mark``, or all cells when
       ``do_uniform_refinement`` is set. It then calls ``Marking.refine(w, ...)``,
       whose return value is not used (presumably it refines ``w`` in place).

    ``do_refinement`` enables or disables residual ("RES"), projection ("PROJ")
    and new-multiindex ("MI") refinement. If ``w_history`` is a list, each
    newly solved ``w`` is appended to it.

    Returns the tuple ``(w, sim_stats)``.
    """
    from functools import partial

    if do_refinement is None:
        do_refinement = {"RES": True, "PROJ": True, "MI": False}

    def _store_stats(val, key, stats):
        # store callback handed to ``timing``
        stats[key] = val

    # estimator data handed to Marking.mark
    EstimatorData = namedtuple('EstimatorData', ['xi', 'gamma', 'cQ', 'ceta'])

    # get rhs
    f = pde.f

    # setup w and statistics
    w = w0
    if sim_stats is None:
        assert w_history is None or len(w_history) == 0
        sim_stats = []

    # continue a previous experiment at its last recorded iteration
    try:
        start_iteration = max(len(sim_stats) - 1, 0)
    except TypeError:
        start_iteration = 0
    logger.info("START/CONTINUE EXPERIMENT at iteration %i", start_iteration)

    # data collection
    import resource
    refinement = None
    for refinement in range(start_iteration, max_refinements + 1):
        logger.info("************* REFINEMENT LOOP iteration %i (of %i or max_dof %i) *************", refinement, max_refinements, max_dof)
        # memory usage info
        logger.info("\n======================================\nMEMORY USED: " + str(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) + "\n======================================\n")

        # pcg solve
        # ---------
        stats = {}
        with timing(msg="pcg_solve", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-PCG", stats=stats)):
            w, zeta = pcg_solve(A, w, coeff_field, pde, stats, pcg_eps, pcg_maxiter)

        logger.info("DIM of w = %s", w.dim)
        # the iteration that is resumed has already been recorded
        record = refinement == 0 or start_iteration < refinement
        if w_history is not None and record:
            w_history.append(w)

        # error evaluation
        # ----------------
        # residual and projection errors
        logger.debug("evaluating ResidualEstimator.evaluateError")
        with timing(msg="ResidualEstimator.evaluateError", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-ESTIMATOR", stats=stats)):
            xi, resind, projind, mierror, estparts, errors, timing_stats = ResidualEstimator.evaluateError(w, coeff_field, pde, f, zeta, gamma, ceta, cQ,
                                                                                    newmi_add_maxm, maxh, quadrature_degree, projection_degree_increase,
                                                                                    refine_projection_mesh)
        reserrmu = [(mu, sqrt(sum(resind[mu].coeffs ** 2))) for mu in resind.keys()]
        projerrmu = [(mu, sqrt(sum(projind[mu].coeffs ** 2))) for mu in projind.keys()]
        res_part, proj_part, pcg_part = estparts[0], estparts[1], estparts[2]
        err_res, err_proj, err_pcg = errors[0], errors[1], errors[2]
        logger.info("Overall Estimator Error xi = %s while residual error is %s, projection error is %s, pcg error is %s", xi, res_part, proj_part, pcg_part)

        stats.update(timing_stats)
        stats["EST"] = xi
        stats["RES-PART"] = res_part
        stats["PROJ-PART"] = proj_part
        stats["PCG-PART"] = pcg_part
        stats["ERR-RES"] = err_res
        stats["ERR-PROJ"] = err_proj
        stats["ERR-PCG"] = err_pcg
        stats["ETA-ERR"] = errors[0]
        stats["DELTA-ERR"] = errors[1]
        stats["ZETA-ERR"] = errors[2]
        stats["RES-mu"] = reserrmu
        stats["PROJ-mu"] = projerrmu
        stats["PROJ-MAX-ZETA"] = 0
        stats["PROJ-MAX-INACTIVE-ZETA"] = 0
        stats["MARKING-RES"] = 0
        stats["MARKING-PROJ"] = 0
        stats["MARKING-MI"] = 0
        stats["TIME-MARKING"] = 0
        stats["MI"] = [(mu, vec.basis.dim) for mu, vec in w.iteritems()]
        if record:
            sim_stats.append(stats)

        # the format string previously ended in a bare "%" (incomplete format)
        logger.debug("squared error components: eta=%s  delta=%s  zeta=%s", errors[0], errors[1], errors[2])

        # exit when either error threshold or max_refinements or max_dof is reached
        if refinement > max_refinements:
            logger.info("SKIPPING REFINEMENT after FINAL SOLUTION in ITERATION %i", refinement)
            break
        # NOTE(review): "DOFS" is not set in this function; presumably
        # pcg_solve stores it in ``stats`` -- confirm.
        if sim_stats[refinement]["DOFS"] >= max_dof:
            logger.info("REACHED %i DOFS, EXITING refinement loop", sim_stats[refinement]["DOFS"])
            break
        if xi <= error_eps:
            logger.info("error reached requested accuracy, xi=%f", xi)
            break

        # marking
        # -------
        if refinement < max_refinements:
            if not do_uniform_refinement:
                logger.debug("starting Marking.mark")
                estimator_data = EstimatorData(xi=xi, gamma=gamma, cQ=cQ, ceta=ceta)
                mesh_markers_R, mesh_markers_P, new_multiindices, proj_zeta, new_multiindices_all = Marking.mark(resind, projind, mierror, w.max_order,
                                                                                theta_eta, theta_zeta, theta_delta,
                                                                                min_zeta, maxh, max_Lambda_frac,
                                                                                estimator_data, marking_strategy)
                sim_stats[-1]["PROJ-MAX-ZETA"] = proj_zeta[0]
                sim_stats[-1]["PROJ-MAX-INACTIVE-ZETA"] = proj_zeta[1]
                sim_stats[-1]["PROJ-INACTIVE-ZETA"] = new_multiindices_all
                logger.info("PROJECTION error values: max_zeta = %s  and  max_inactive_zeta = %s  with threshold factor theta_zeta = %s  (=%s)",
                            proj_zeta[0], proj_zeta[1], theta_zeta, theta_zeta * proj_zeta[0])
                marked_res = sum(len(cell_ids) for cell_ids in mesh_markers_R.itervalues())
                marked_proj = sum(len(cell_ids) for cell_ids in mesh_markers_P.itervalues())
                logger.info("MARKING will be carried out with %s (res) + %s (proj) cells and %s new multiindices",
                            marked_res, marked_proj, len(new_multiindices))
                stats["MARKING-RES"] = marked_res
                stats["MARKING-PROJ"] = marked_proj
                stats["MARKING-MI"] = len(new_multiindices)
                if do_refinement["RES"]:
                    mesh_markers = mesh_markers_R.copy()
                else:
                    mesh_markers = defaultdict(set)
                    logger.info("SKIP residual refinement")

                if do_refinement["PROJ"]:
                    for mu, marked in mesh_markers_P.iteritems():
                        if len(marked) > 0:
                            mesh_markers[mu] = mesh_markers[mu].union(marked)
                else:
                    logger.info("SKIP projection refinement")

                # (refinement < max_refinements holds here, so only the MI switch matters)
                if not do_refinement["MI"]:
                    new_multiindices = {}
                    logger.info("SKIP new multiindex refinement")
            else:
                logger.info("UNIFORM REFINEMENT active")
                from dolfin import cells
                # mark every cell of every component mesh
                mesh_markers = {}
                for mu, vec in w.iteritems():
                    mesh_markers[mu] = [c.index() for c in cells(vec._fefunc.function_space().mesh())]
                new_multiindices = {}

            # carry out refinement of meshes
            with timing(msg="Marking.refine", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-MARKING", stats=stats)):
                Marking.refine(w, mesh_markers, new_multiindices.keys(), partial(setup_vector, pde=pde, mesh=mesh0, degree=degree))

    # ``refinement`` is None only if the loop did not run; 0 is a valid iteration
    if refinement is not None:
        logger.info("ENDED refinement loop after %i of %i refinements with %i dofs and %i active multiindices",
                    refinement, max_refinements, sim_stats[refinement]["DOFS"], len(sim_stats[refinement]["MI"]))

    return w, sim_stats
예제 #6
0
def run_mc(err, w, pde, A, coeff_field, mesh0, ref_maxm, MC_N, MC_HMAX, param_sol_cache=None, direct_sol_cache=None, stored_rv_samples=None, quadrature_degree = -1):
    """Monte Carlo estimate of the error of the parametric solution ``w``.

    For each of ``MC_N`` samples of the random variables, the parametric
    solution (``compute_parametric_sample_solution``) is compared with a direct
    solve using ``ref_maxm`` expansion terms (``compute_direct_sample_solution``).
    Samples are reused from ``stored_rv_samples`` where available; new ones
    are appended to it. Both solutions live on a projection basis built from
    ``mesh0`` with the degree and sub-spaces of ``w[Multiindex()]``.

    The errors are measured with ``error_norm`` in "L2" and in
    ``pde.energy_norm``, and averaged over the samples. For the last sample,
    the errors against the direct solution with ``maxm=0`` are computed as well.

    Appends ``(err_L2, err_H1, L2_a0, H1_a0)`` to the list ``err``.

    If ``quadrature_degree > -1``, the form compiler quadrature degree is
    overridden while sampling and restored afterwards, also when an
    exception is raised. ``MC_HMAX`` is not used.

    Raises:
        ValueError: if ``MC_N < 1`` (the deterministic-part errors would be undefined).
    """
    if MC_N < 1:
        raise ValueError("run_mc needs at least one sample (MC_N=%s)" % MC_N)

    # create reference mesh and function space
    sub_spaces = w[Multiindex()].basis.num_sub_spaces
    degree = w[Multiindex()].basis.degree
    projection_basis = get_projection_basis(mesh0, mesh_refinements=0, degree=degree, sub_spaces=sub_spaces)
    logger.info("projection_basis dim = %i \t hmin of mi[0] = %s, reference mesh = (%s, %s)", projection_basis.dim, w[Multiindex()].basis.minh, projection_basis.minh, projection_basis.maxh)

    # accumulated mean errors
    err_L2, err_H1 = 0, 0

    # set quadrature degree (a global setting; restored in ``finally`` below)
    if quadrature_degree > -1:
        quadrature_degree_old = parameters["form_compiler"]["quadrature_degree"]
        parameters["form_compiler"]["quadrature_degree"] = quadrature_degree
        logger.debug("MC error sampling quadrature order = " + str(quadrature_degree))

    try:
        logger.info("---- MC caches %s/%s ----", param_sol_cache, direct_sol_cache)
        # main MC loop
        for i in range(MC_N):
            logger.info("---- MC Iteration %i/%i ----", i + 1, MC_N)
            # create new samples if required or reuse existing samples
            if stored_rv_samples is None or len(stored_rv_samples) <= i:
                sample_rvs = coeff_field.sample_rvs()
                RV_samples = [sample_rvs[j] for j in range(max(w.max_order, ref_maxm))]
                if stored_rv_samples is not None:
                    stored_rv_samples.append(RV_samples)
            else:
                RV_samples = stored_rv_samples[i]
            logger.info("-- RV_samples: %s", RV_samples)
            # evaluate solutions
            with timing(msg="parameteric sample solution", logfunc=logger.info):
                sample_sol_param = compute_parametric_sample_solution(RV_samples, coeff_field, w, projection_basis, param_sol_cache)
            with timing(msg="direct sample solution", logfunc=logger.info):
                sample_sol_direct = compute_direct_sample_solution(pde, RV_samples, coeff_field, A, ref_maxm, projection_basis, direct_sol_cache)
            # evaluate errors
            with timing(msg="L2_err_1", logfunc=logger.info):
                cerr_L2 = error_norm(sample_sol_param._fefunc, sample_sol_direct._fefunc, "L2")
            with timing(msg="H1A_err_1", logfunc=logger.info):
                cerr_H1 = error_norm(sample_sol_param._fefunc, sample_sol_direct._fefunc, pde.energy_norm)
            logger.debug("-- current error L2 = %s    H1A = %s", cerr_L2, cerr_H1)
            err_L2 += 1.0 / MC_N * cerr_L2
            err_H1 += 1.0 / MC_N * cerr_H1

            if i + 1 == MC_N:
                # deterministic part (mean coefficient only)
                with timing(msg="direct a0", logfunc=logger.info):
                    sample_sol_direct_a0 = compute_direct_sample_solution(pde, RV_samples, coeff_field, A, 0, projection_basis, direct_sol_cache)
                with timing(msg="L2_err_2", logfunc=logger.info):
                    L2_a0 = error_norm(sample_sol_param._fefunc, sample_sol_direct_a0._fefunc, "L2")
                with timing(msg="H1A_err_2", logfunc=logger.info):
                    H1_a0 = error_norm(sample_sol_param._fefunc, sample_sol_direct_a0._fefunc, pde.energy_norm)
                logger.debug("-- DETERMINISTIC error L2 = %s    H1A = %s", L2_a0, H1_a0)

                # stochastic part
                sample_sol_direct_am = sample_sol_direct - sample_sol_direct_a0
                logger.debug("-- STOCHASTIC norm L2 = %s    H1 = %s", sample_sol_direct_am.norm("L2"), sample_sol_direct_am.norm("H1"))
    finally:
        # restore quadrature degree even if sampling failed
        if quadrature_degree > -1:
            parameters["form_compiler"]["quadrature_degree"] = quadrature_degree_old

    logger.info("MC Error: L2: %s, H1A: %s", err_L2, err_H1)
    err.append((err_L2, err_H1, L2_a0, H1_a0))
예제 #7
0
def AdaptiveSolver(A, coeff_field, pde,
                    mis, w0, mesh0, degree,
                    # marking parameters
                    rho=1.0, # tail factor
                    theta_x=0.4, # residual marking bulk parameter
                    theta_y=0.4, # tail bound marking bulk parameter
                    maxh=0.1, # maximal mesh width for coefficient maximum norm evaluation
                    add_maxm=100, # maximal search length for new multiindices (to be added to max order of solution w)
                    # estimator
                    estimator_type="RESIDUAL",
                    quadrature_degree=-1,
                    # pcg solver
                    pcg_eps=1e-6,
                    pcg_maxiter=100,
                    # adaptive algorithm threshold
                    error_eps=1e-2,
                    # refinements
                    max_refinements=5,
                    max_dof=1e10,
                    do_refinement=None, # dict with bool entries "RES", "TAIL", "OSC"; None means all True
                    do_uniform_refinement=False,
                    refine_osc_factor=1.0,
                    w_history=None,
                    sim_stats=None):
    """Adaptive loop driven by a spatial estimator and an upper tail bound.

    Starts from ``w0``. If ``sim_stats`` already holds entries, the loop
    resumes at iteration ``len(sim_stats) - 1``. Each iteration
    ``refinement = start_iteration, ..., max_refinements`` does the following:

    1. Solves with ``pcg_solve``.
    2. Evaluates the tail bound with ``ResidualEstimator.evaluateUpperTailBound``.
    3. Evaluates the spatial estimator selected by ``estimator_type``
       ("RESIDUAL", "EQUILIBRATION_GLOBAL" or "EQUILIBRATION_LOCAL";
       case-insensitive).
    4. Sets ``xi = sqrt(global_eta**2 + global_zeta**2)`` and records statistics.
    5. Stops if the recorded "DOFS" reach ``max_dof`` or ``xi <= error_eps``.
    6. Otherwise, for ``refinement < max_refinements``, does one of the following:
       * Refines the mesh (``Marking.mark_x``/``refine_x``; all cells if
         ``do_uniform_refinement``). This happens when
         ``global_eta > rho * global_zeta`` or tail refinement is disabled.
       * Activates new multiindices (``Marking.mark_y``/``refine_y``).

       It then optionally refines for coefficient oscillations
       (``Marking.refine_osc``). That call also replaces ``maxh`` for the
       following iterations.

    ``do_refinement`` enables or disables "RES", "TAIL" and "OSC" refinement.
    If ``w_history`` is a list, each newly solved ``w`` is appended to it.

    Returns the tuple ``(w, sim_stats)``.

    Raises:
        TypeError: for an unknown ``estimator_type``.
    """
    from functools import partial

    if do_refinement is None:
        do_refinement = {"RES": True, "TAIL": True, "OSC": True}

    def _store_stats(val, key, stats):
        # store callback handed to ``timing``
        stats[key] = val

    # get rhs
    f = pde.f

    # setup w and statistics
    w = w0
    if sim_stats is None:
        assert w_history is None or len(w_history) == 0
        sim_stats = []

    # continue a previous experiment at its last recorded iteration
    try:
        start_iteration = max(len(sim_stats) - 1, 0)
    except TypeError:
        start_iteration = 0
    logger.info("START/CONTINUE EXPERIMENT at iteration %i", start_iteration)

    # data collection
    import resource
    refinement = None
    for refinement in range(start_iteration, max_refinements + 1):
        logger.info("************* REFINEMENT LOOP iteration {0} (of {1} or max dofs {2}) *************".format(refinement, max_refinements, max_dof))
        # memory usage info
        logger.info("\n======================================\nMEMORY USED: " + str(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) + "\n======================================\n")

        # ---------
        # pcg solve
        # ---------

        stats = {}
        with timing(msg="pcg_solve", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-PCG", stats=stats)):
            w, zeta = pcg_solve(A, w, coeff_field, pde, stats, pcg_eps, pcg_maxiter)

        logger.info("DIM of w = %s", w.dim)
        # the iteration that is resumed has already been recorded
        record = refinement == 0 or start_iteration < refinement
        if w_history is not None and record:
            w_history.append(w)

        # -------------------
        # evaluate estimators
        # -------------------

        # evaluate estimate_y (this rebinds ``zeta`` returned by pcg_solve)
        logger.debug("evaluating upper tail bound")
        with timing(msg="ResidualEstimator.evaluateUpperTailBound", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-TAIL", stats=stats)):
            global_zeta, zeta, zeta_bar, eval_zeta_m = ResidualEstimator.evaluateUpperTailBound(w, coeff_field, pde, maxh, add_maxm)

        # evaluate estimate_x
        est_type = estimator_type.upper()
        if est_type == "RESIDUAL":
            logger.debug("evaluating residual bound (residual)")
            with timing(msg="ResidualEstimator.evaluateResidualEstimator", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-RES", stats=stats)):
                global_eta, eta, eta_local = ResidualEstimator.evaluateResidualEstimator(w, coeff_field, pde, f, quadrature_degree)
        elif est_type == "EQUILIBRATION_GLOBAL":
            logger.debug("evaluating residual bound (global equilibration)")
            with timing(msg="GlobalEquilibrationEstimator.evaluateEstimator", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-RES", stats=stats)):
                global_eta, eta, eta_local, osc_global, osc_local = GlobalEquilibrationEstimator.evaluateEstimator\
                                                                            (w, coeff_field, pde, f, quadrature_degree)
        elif est_type == "EQUILIBRATION_LOCAL":
            # messages previously said "global" (copy-paste)
            logger.debug("evaluating residual bound (local equilibration)")
            with timing(msg="LocalEquilibrationEstimator.evaluateEstimator", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-RES", stats=stats)):
                global_eta, eta, eta_local, osc_global, osc_local = LocalEquilibrationEstimator.evaluateEstimator\
                                                                            (w, coeff_field, pde, f, quadrature_degree)
        else:
            raise TypeError("invalid estimator type %s" % est_type)

        # set overall error
        xi = sqrt(global_eta ** 2 + global_zeta ** 2)
        logger.info("Overall Estimator Error xi = %s while spatial error is %s and tail error is %s", xi, global_eta, global_zeta)

        # store simulation data
        stats["ERROR-EST"] = xi
        stats["ERROR-RES"] = global_eta
        stats["ERROR-TAIL"] = global_zeta
        stats["MARKING-RES"] = 0
        stats["MARKING-MI"] = 0
        stats["CADELTA"] = 0
        stats["TIME-MARK-RES"] = 0
        stats["TIME-REFINE-RES"] = 0
        stats["TIME-MARK-TAIL"] = 0
        stats["TIME-REFINE-TAIL"] = 0
        stats["TIME-REFINE-OSC"] = 0
        stats["MI"] = [mu for mu in w.active_indices()]
        stats["DIM"] = w.dim
        if record:
            sim_stats.append(stats)
            logger.debug("SIM_STATS: %s", sim_stats[refinement])

        # exit when either error threshold or max_refinements or max_dof is reached
        if refinement > max_refinements:
            logger.info("SKIPPING REFINEMENT after FINAL SOLUTION in ITERATION %i", refinement)
            break
        # NOTE(review): "DOFS" is not set in this function; presumably
        # pcg_solve stores it in ``stats`` -- confirm.
        if sim_stats[refinement]["DOFS"] >= max_dof:
            logger.info("REACHED %i DOFS, EXITING refinement loop", sim_stats[refinement]["DOFS"])
            break
        if xi <= error_eps:
            logger.info("SKIPPING REFINEMENT since ERROR REACHED requested ACCURACY, xi=%f", xi)
            break

        # -----------------------------------
        # mark and refine and activate new mi
        # -----------------------------------

        if refinement < max_refinements:
            logger.debug("START marking === %s", str(do_refinement))
            # === mark x ===
            # defined up front: it is logged below even when "RES" is disabled
            # (previously a NameError in that case)
            cell_ids = []
            res_marked = False
            if do_refinement["RES"]:
                if not do_uniform_refinement:
                    if global_eta > rho * global_zeta or not do_refinement["TAIL"]:
                        logger.info("REFINE RES")
                        with timing(msg="Marking.mark_x", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-MARK-RES", stats=stats)):
                            cell_ids = Marking.mark_x(global_eta, eta_local, theta_x)
                        res_marked = True
                    else:
                        logger.info("SKIP REFINE RES -> mark stochastic modes instead")
                else:
                    # uniformly refine mesh
                    logger.info("UNIFORM refinement RES")
                    cell_ids = [c.index() for c in cells(w.basis._fefs.mesh())]
                    res_marked = True
            else:
                logger.info("SKIP residual refinement")
            # refine mesh
            if res_marked:
                logger.debug("w.dim BEFORE refinement: %s", w.dim)
                with timing(msg="Marking.refine_x", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-REFINE-RES", stats=stats)):
                    w = Marking.refine_x(w, cell_ids)
                logger.debug("w.dim AFTER refinement: %s", w.dim)

            # === mark y ===
            if do_refinement["TAIL"] and not res_marked:
                logger.info("REFINE TAIL")
                with timing(msg="Marking.mark_y", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-MARK-TAIL", stats=stats)):
                    new_mi = Marking.mark_y(w.active_indices(), zeta, eval_zeta_m, theta_y, add_maxm)
                # add new multiindices
                with timing(msg="Marking.refine_y", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-REFINE-TAIL", stats=stats)):
                    Marking.refine_y(w, new_mi)
            else:
                new_mi = []
                logger.info("SKIP tail refinement")

            # === uniformly refine for coefficient function oscillations ===
            if do_refinement["OSC"]:
                logger.info("REFINE OSC")
                with timing(msg="Marking.refine_osc", logfunc=logger.info, store_func=partial(_store_stats, key="TIME-REFINE-OSC", stats=stats)):
                    # note: rebinds ``maxh`` for the following iterations
                    w, maxh, Cadelta = Marking.refine_osc(w, coeff_field, refine_osc_factor)
                    logger.info("coefficient oscillations require maxh %f with current mesh maxh %f and Cadelta %f" % (maxh, w.basis.basis.mesh.hmax(), Cadelta))
                    stats["CADELTA"] = Cadelta
            else:
                logger.info("SKIP oscillation refinement")

            logger.info("MARKING was carried out with %s (res) cells and %s (mi) new multiindices", len(cell_ids), len(new_mi))
            stats["MARKING-RES"] = len(cell_ids)
            stats["MARKING-MI"] = len(new_mi)

    # ``refinement`` is None only if the loop did not run; 0 is a valid iteration
    if refinement is not None:
        logger.info("ENDED refinement loop after %i of (max) %i refinements with %i dofs and %i active multiindices",
                    refinement, max_refinements, sim_stats[refinement]["DOFS"], len(sim_stats[refinement]["MI"]))

    return w, sim_stats