def check_restarted_in_middle(restart_outdir, restart_point=20):
    """Verify that a restarted run resumed at ``restart_point``.

    In ``restart_outdir`` the solution file numbered ``restart_point - 1``
    must be absent, and the one numbered ``restart_point`` must exist.

    Returns True when both conditions hold, False otherwise. An ok/bad
    message is printed through ``mm`` in either case.
    """
    before_restart = pjoin(restart_outdir, "soln%i.dat" % (restart_point - 1))
    at_restart = pjoin(restart_outdir, "soln%i.dat" % restart_point)

    # A file from before the restart point means the run did not restart
    # there; only when it is missing does the restart-point file matter.
    if os.path.isfile(before_restart):
        ok = False
    else:
        ok = os.path.isfile(at_restart)

    if not ok:
        mm.badprint("Failed or restarted at wrong point in", restart_outdir)
    else:
        mm.okprint("Restarted at correct point in", restart_outdir)

    return ok
def check_ndt_less_than(data, max_ndt, identifier=None):
    """Check that the run took strictly fewer than ``max_ndt`` time steps.

    The step count is ``len(data['dts'])``. ``identifier`` labels the printed
    messages and defaults to ``_default_label(data)``.

    Returns the result of ``step_count < max_ndt``.
    """
    label = identifier if identifier is not None else _default_label(data)
    step_count = len(data['dts'])

    within_limit = step_count < max_ndt

    if within_limit:
        mm.okprint("n steps", step_count, "ok in", label)
    else:
        mm.badprint("FAILED in ",  label)
        mm.badprint("n steps is", step_count, "which is more than", max_ndt)

    return within_limit
def check_m_length(data, tol=1e-8, identifier=None):
    """Check that abs(|m| - 1) < tol.

    Uses the largest absolute value in ``data['m_length_error_means']``.
    ``identifier`` labels the printed messages and defaults to
    ``_default_label(data)``.

    Returns the result of ``worst_error < tol``.
    """
    label = identifier if identifier is not None else _default_label(data)

    worst_error = max(abs(err) for err in data['m_length_error_means'])
    within_tol = worst_error < tol

    if within_tol:
        mm.okprint("|m| max of", worst_error, "ok in", label)
    else:
        mm.badprint("FAILED in ", label,
                    "with |m| error of", worst_error)

    return within_tol
def check_error_norm(data, tol=1e-5, identifier=None):
    """Check that max error norm < tol.

    The maximum is taken over the absolute values of ``data['error_norms']``.
    ``identifier`` labels the printed messages and defaults to
    ``_default_label(data)``.

    Returns the result of ``largest_norm < tol``.
    """
    label = identifier if identifier is not None else _default_label(data)

    largest_norm = max(abs(norm) for norm in data['error_norms'])
    within_tol = largest_norm < tol

    if within_tol:
        mm.okprint("max error norm of", largest_norm, "ok in", label)
    else:
        mm.badprint("FAILED in", label)
        mm.badprint("with max error norm of", largest_norm)

    return within_tol
def check_solns_match(main_dir="Validation", compare_dir="validata", **kwargs):
    """Compare solution files in ``main_dir`` with those in ``compare_dir``.

    Every entry of ``main_dir`` whose name contains both "soln" and ".dat" is
    compared with the same-named file in ``compare_dir`` via ``fpdiff``. Extra
    keyword arguments are forwarded to ``fpdiff``, and its first return value
    is taken as the match flag.

    Returns False (with a message) if no such files were found. Otherwise
    returns True only if every comparison matched.
    """
    results = []
    for name in os.listdir(main_dir):
        if not ("soln" in name and ".dat" in name):
            continue
        same, _, _ = fpdiff(pjoin(main_dir, name),
                            pjoin(compare_dir, name), **kwargs)
        results.append(same)

    if not results:
        mm.badprint("No files in", main_dir)
        return False

    ok = all(results)
    if not ok:
        mm.badprint("Files don't match in", main_dir, compare_dir)
    else:
        mm.okprint("Files match in", main_dir, compare_dir)

    return ok
def try_run(args, dt, dir_naming_args, root_outdir,
             maxallowederror = 0.01, maxallowedangle = sp.pi/4):
    """Run ``mm._run`` with time step ``dt`` and judge whether it succeeded.

    Parameters
    ----------
    args : dict
        Arguments passed to ``mm._run``. The caller's dict is not modified;
        a copy with ``'-dt'`` set to ``dt`` is used instead.
    dt : number
        Time step stored under ``'-dt'``.
    dir_naming_args, root_outdir :
        Forwarded unchanged to ``mm._run``.
    maxallowederror : float
        Upper bound (exclusive) on ``max(data['m_length_error_means'])``.
    maxallowedangle : float
        Upper bound (exclusive) on ``max(data['max_angle_errors'])``.

    Returns
    -------
    (success, data)
        ``success`` is True when the run returned error code 0 and both
        error measures are below their bounds. ``data`` is whatever
        ``mm.parse_run`` returned (``None`` if parsing produced nothing).
    """
    # Work on a copy so the caller's args dict is not silently modified.
    run_args = dict(args)
    run_args['-dt'] = dt

    # Try to run it
    err_code, outdir = mm._run(run_args, root_outdir, dir_naming_args,
                                quiet=True)

    # Parse output files
    data = mm.parse_run(outdir)
    if data is not None:
        errs = data['m_length_error_means']
        assert errs[0] >= 0
        maxerror = max(errs)
        maxangle = max(data['max_angle_errors'])
    else:
        # Nothing parsed: treat as infinitely bad so the run counts as failed.
        maxerror = float('inf')
        maxangle = float('inf')

    success = (err_code == 0 and maxerror < maxallowederror
                and maxangle < maxallowedangle)

    # data is None when parsing failed, so fall back to the directory
    # reported by mm._run rather than indexing None.
    report_dir = data['-outdir'] if data is not None else outdir

    if success:
        mm.okprint("Succeeded in", report_dir)
    else:
        mm.badprint(pjoin(report_dir, "stdout:1:1:"), "FAILED")

    return success, data
def check_error_norm_relative(data, tol=1e-5, identifier=None):
    """Check that max error norm < tol * maximum_trace_value.

    The maximum error norm is taken over the absolute values of
    ``data['error_norms']``. The scale is the largest entry across all rows
    of ``data['trace_values']``, but never less than 1.0. ``identifier``
    labels the printed messages and defaults to ``_default_label(data)``.

    Returns the result of the comparison.
    """
    label = identifier if identifier is not None else _default_label(data)

    largest_norm = max(abs(norm) for norm in data['error_norms'])

    # Flatten the list of lists of trace values and take the biggest entry.
    largest_trace = max(it.chain.from_iterable(data['trace_values']))

    # Values below 1 would tighten the tolerance, so clamp the scale at 1.0.
    within_tol = largest_norm < tol * max(largest_trace, 1.0)

    if within_tol:
        mm.okprint("max error norm of", largest_norm, "ok in", label)
    else:
        mm.badprint("FAILED in", label)
        mm.badprint("with max error norm of", largest_norm)

    return within_tol
def check_convergence(datasets, expected_rate, tol=0.2):
    """Check that the convergence rate of a set of runs is as expected.

    The rate as dt and h -> 0 comes from
    ``mm.convergence_rate(datasets, error_norm_norm=max)``. dt and h are
    linked, so testing against dt alone is enough.

    Returns True if the rate lies strictly within ``tol`` of
    ``expected_rate``. Messages are labelled by the first dataset only,
    for simplicity.
    """
    observed = mm.convergence_rate(datasets, error_norm_norm=max)
    ok = abs(observed - expected_rate) < tol

    label = _default_label(datasets[0])
    if not ok:
        mm.badprint("convergence rate", observed,
                    "but expected", expected_rate, "+/-", tol,
                    "in", label)
    else:
        mm.okprint("convergence rate", observed, "ok in", label)

    # TODO(??ds): also check goodness of fit?
    return ok
def check_mean_m_matches(data1, data2, tol=1e-4):
    """Check that the mean m components of two datasets agree to within ``tol``.

    Compares ``mean_mxs``, then ``mean_mys``, then ``mean_mzs`` element by
    element. Pairs are formed with ``zip``, so only the overlapping prefix is
    compared. At the first difference larger than ``tol`` a message is
    printed and False is returned; otherwise True.
    """
    id1 = _default_label(data1)
    id2 = _default_label(data2)

    # (data key, failure message) for each component, checked in this order.
    components = (
        ('mean_mxs', "Failed mx comparison in"),
        ('mean_mys', "Failed my comparison in"),
        ('mean_mzs', "Failed mz comparison in"),
    )

    for key, failure_msg in components:
        for first, second in zip(data1[key], data2[key]):
            gap = abs(first - second)
            if gap > tol:
                mm.badprint(failure_msg, id1, id2, "error of", gap)
                return False

    mm.okprint("m comparison ok in", id1, id2)
    return True
def check_dicts_match(data, exact, keys, **kwargs):
    """Given two dicts and a list of keys check that those values match.

    ??ds actually temp hack: use the second value of the array given by
    data['k']...
    """

    def fpcheck(v1, v2, rtol=1e-8, zero=1e-14):
        """Check two floats are the same: if both small then true, otherwise check
        relative difference.
        """
        if v1 < zero and v2 < zero:
            return True
        else:
            return (abs((v1 - v2)/v1) < rtol)


    dataid = _default_label(data)

    ok = []
    for k in keys:
        try:
            a = fpcheck(exact[k], data[k][1], **kwargs)
            ok.append(a)

            # Print messages
            if a:
                mm.okprint(k, "comparison ok in", dataid)
            else:
                mm.badprint("Failed", k, "comparison with values of",
                            data[k][1], exact[k], "in", dataid)

        except KeyError:
            mm.badprint("Failed: missing key", k, "in", dataid)
            ok.append(False)

    return all(ok)