コード例 #1
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_minimal():
    output = mc3.fit(data, uncert, quad, np.copy(params), indparams=[x])
    np.testing.assert_allclose(output['best_log_post'], -54.43381306220858)
    np.testing.assert_equal(-2 * output['best_log_post'], output['best_chisq'])
    np.testing.assert_allclose(output['bestp'],
                               np.array([4.28263253, -2.40781859, 0.49534411]),
                               rtol=1e-7)
コード例 #2
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_trf():
    output = mc3.fit(data,
                     uncert,
                     quad,
                     np.copy(params),
                     indparams=[x],
                     leastsq='trf')
    np.testing.assert_allclose(output['best_log_post'], -54.43381306220856)
    np.testing.assert_allclose(output['bestp'],
                               np.array([4.28263252, -2.40781858, 0.49534411]),
                               rtol=1e-7)
コード例 #3
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_leastsq_error(capsys):
    with pytest.raises(SystemExit):
        output = mc3.fit(data,
                         uncert,
                         quad,
                         np.copy(params),
                         indparams=[x],
                         leastsq='invalid')
        captured = capsys.readouterr()
        assert "Invalid 'leastsq' input (invalid). Must select from " \
               "['lm', 'trf']." in captured.out
コード例 #4
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_bounds():
    output = mc3.fit(data,
                     uncert,
                     quad, [4.5, -2.5, 0.5],
                     indparams=[x],
                     pmin=[4.4, -3.0, 0.4],
                     pmax=[5.0, -2.0, 0.6],
                     leastsq='trf')
    np.testing.assert_allclose(output['best_log_post'], -54.45536109795812)
    np.testing.assert_allclose(output['bestp'],
                               np.array([4.4, -2.46545897, 0.5009366]),
                               rtol=1e-7)
コード例 #5
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_shared():
    output = mc3.fit(data1,
                     uncert1,
                     quad,
                     np.copy(params),
                     indparams=[x],
                     pstep=[1.0, -1, 1.0])
    assert output['bestp'][1] == output['bestp'][0]
    np.testing.assert_allclose(output['best_log_post'], -51.037667264657)
    np.testing.assert_allclose(output['bestp'],
                               np.array([4.58657213, 4.58657213, 0.43347714]),
                               rtol=1e-7)
コード例 #6
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_fixed():
    pars = np.copy(params)
    pars[0] = p0[0]
    output = mc3.fit(data,
                     uncert,
                     quad,
                     pars,
                     indparams=[x],
                     pstep=[0.0, 1.0, 1.0])
    assert output['bestp'][0] == pars[0]
    np.testing.assert_allclose(output['best_log_post'], -54.507722717665466)
    np.testing.assert_allclose(output['bestp'],
                               np.array([4.5, -2.51456999, 0.50570154]),
                               rtol=1e-7)
コード例 #7
0
ファイル: test_fit.py プロジェクト: mdhimes/mc3
def test_fit_priors():
    prior = np.array([4.5, 0.0, 0.0])
    priorlow = np.array([0.1, 0.0, 0.0])
    priorup = np.array([0.1, 0.0, 0.0])
    output = mc3.fit(data,
                     uncert,
                     quad,
                     np.copy(params),
                     indparams=[x],
                     prior=prior,
                     priorlow=priorlow,
                     priorup=priorup)
    np.testing.assert_allclose(output['best_log_post'], -54.50548056991611)
    # First parameter is closer to 4.5 than without a prior:
    np.testing.assert_allclose(output['bestp'],
                               np.array([4.49340587, -2.51133157, 0.50538734]),
                               rtol=1e-7)
コード例 #8
0
names = [None] * ntargets
flat_chisq = np.zeros(ntargets)
flat_rchisq = np.zeros(ntargets)
flat_bic = np.zeros(ntargets)

for i in range(ntargets):
    cfile = [f for f in os.listdir(targets[i]) if f.endswith(".cfg")][0]
    with cd(targets[i]):
        pyrat = pb.run(cfile, run_step='dry', no_logfile=True)

    data = pyrat.obs.data
    uncert = pyrat.obs.uncert
    indparams = [data]
    params = np.array([np.mean(data)])
    pstep = np.array([1.0])
    mc3_fit = mc3.fit(data, uncert, sfit, params, indparams, pstep)

    dof = len(data) - len(params)
    names[i] = cfile.split('_')[1] + ' ' + cfile.split('_')[2]
    flat_chisq[i] = mc3_fit['best_chisq']
    flat_rchisq[i] = mc3_fit['best_chisq'] / dof
    flat_bic[i] = mc3_fit['best_chisq'] + len(params) * np.log(len(data))

with open("stats/flat_fit.dat", "w") as f:
    f.write("Flat-curve fit to transmission data (Nfree = 1):\n\n"
            "Planet                chi-square  red-chisq      BIC\n"
            "----------------------------------------------------\n")
    for i in range(ntargets):
        f.write(f"{names[i]:20s}  {flat_chisq[i]:10.3f}  "
                f"{flat_rchisq[i]:9.3f}  {flat_bic[i]:7.3f}\n")