def test_no_overlay():
    """Fit the drift of a model built with ``OverlayNone`` from resampled data.

    A constant-drift model (drift=2) with an explicit ``OverlayNone`` is
    solved numerically and resampled; a fit with a ``Fittable`` drift in
    [0, 3] (no overlay passed to ``fit_model``) is then compared against
    the generating model via ``_verify_param_match``.
    """
    true_drift = DriftConstant(drift=2)
    true_model = Model(name="Overlay", drift=true_drift, overlay=OverlayNone())
    true_solution = true_model.solve_numerical()
    data = true_solution.resample(10000)

    # Only the drift is left free in the fit.
    free_drift = DriftConstant(drift=Fittable(minval=0, maxval=3))
    fitted_model = fit_model(data, drift=free_drift)

    fitted_solution = fitted_model.solve_numerical()
    plot_compare_solutions(true_solution, fitted_solution)
    print(fitted_model)
    _verify_param_match("drift", "drift", true_model, fitted_model)
def test_uniform_overlay():
    """Fit drift and uniform-mixture coefficient from resampled data.

    The generating model uses drift=2 and
    ``OverlayUniformMixture(umixturecoef=.1)``.  Both the drift (range
    [0, 3]) and the mixture coefficient (range [.001, .5]) are left
    ``Fittable``, and each is compared against the generating model via
    ``_verify_param_match``.
    """
    true_drift = DriftConstant(drift=2)
    true_overlay = OverlayUniformMixture(umixturecoef=.1)
    true_model = Model(name="Overlay", drift=true_drift, overlay=true_overlay)
    true_solution = true_model.solve_numerical()
    data = true_solution.resample(10000)

    # Free parameters for the fit: drift first, then the overlay coefficient.
    free_drift = DriftConstant(drift=Fittable(minval=0, maxval=3))
    free_overlay = OverlayUniformMixture(
        umixturecoef=Fittable(minval=.001, maxval=.5))
    fitted_model = fit_model(data, drift=free_drift, overlay=free_overlay)

    fitted_solution = fitted_model.solve_numerical()
    plot_compare_solutions(true_solution, fitted_solution)
    print(fitted_model)
    _verify_param_match("drift", "drift", true_model, fitted_model)
    _verify_param_match("overlay", "umixturecoef", true_model, fitted_model)
def test_fit_simple_ddm():
    """Fit the drift of a simple DDM from data resampled from its solution.

    The generating model has drift=2, noise=1, bound B=1 and dt=.01.  The
    fit leaves only the drift free, in [0, 10].  When ``SHOW_PLOTS`` is
    truthy the fitted model is renamed, solved and plotted next to the
    generating solution before the parameter check runs.
    """
    true_drift = DriftConstant(drift=2)
    true_noise = NoiseConstant(noise=1)
    true_bound = BoundConstant(B=1)
    true_model = Model(name="DDM", dt=.01,
                       drift=true_drift, noise=true_noise, bound=true_bound)
    true_solution = true_model.solve()
    data = true_solution.resample(10000)

    free_drift = DriftConstant(drift=Fittable(minval=0, maxval=10))
    fitted_model = fit_model(data, drift=free_drift)

    if SHOW_PLOTS:
        fitted_model.name = "Fitted solution"
        fitted_solution = fitted_model.solve()
        plot_compare_solutions(true_solution, fitted_solution)
        plt.show()

    # NOTE(review): the original comment said "Within 10%"; presumably that
    # is the tolerance applied by _verify_param_match -- confirm.
    _verify_param_match("drift", "drift", true_model, fitted_model)
def test_overlay_chain_distribution_integrates_to_1():
    """Check that a chained-overlay solution keeps total probability near 1.

    A model with drift=2 and an ``OverlayChain`` of a Poisson mixture
    followed by a uniform mixture is solved numerically; the sum of
    ``prob_correct()`` and ``prob_error()`` must lie strictly between .98
    and 1.0001.
    """
    drift = DriftConstant(drift=2)
    chained = OverlayChain(overlays=[
        OverlayPoissonMixture(pmixturecoef=.2, rate=2),
        OverlayUniformMixture(umixturecoef=.2),
    ])
    model = Model(name="Overlay_test", drift=drift, overlay=chained)
    solution = model.solve_numerical()

    p_correct = solution.prob_correct()
    p_error = solution.prob_error()
    total = p_correct + p_error
    assert .98 < total < 1.0001, "Distribution doesn't sum to 1"
def test_verify_ddm_analytic_close_to_numeric_params4():
    """Run the numerical-vs-analytical helper on parameter set 4.

    Parameter set: dx=.005, dt=.01, T_dur=2, drift=.1, noise=1, B=.6.
    The model is passed to ``_modeltest_numerical_vs_analytical`` with
    ``max_diff=1``.
    """
    weak_drift = DriftConstant(drift=.1)
    unit_noise = NoiseConstant(noise=1)
    narrow_bound = BoundConstant(B=.6)
    model = Model(dx=.005, dt=.01, T_dur=2,
                  drift=weak_drift, noise=unit_noise, bound=narrow_bound)
    _modeltest_numerical_vs_analytical(model, max_diff=1)