コード例 #1
0
    def test2(self):
        # Two trajectories go in: the first is already clean, the second
        # visits state 5 twice in a row. The expected output merges the
        # repeated 5 into one observation whose duration is the sum (0.2).
        expected = [
            [(2, 3), (1.6, 2.7)],
            [(4, 5, 6), (0.1, 0.2, 0.1)],
        ]
        raw = [
            ([2, 3], [1.6, 2.7]),
            ([4, 5, 5, 6], [.1, .1, .1, .1]),
        ]

        # Compare flattened numeric values with a float tolerance.
        npt.assert_allclose(flatten(ctmc.datacorrection(raw)),
                            flatten(expected))
コード例 #2
0
    def test6(self):
        # Zero-duration observations appear in both trajectories. The
        # expected output keeps only the first trajectory, with its
        # zero-duration state (5) removed. The second trajectory, which
        # would be left with a single state, is expected to vanish entirely.
        raw = [([4, 5, 6], [.1, .0, .1]),
               ([7, 8, 9], [.1, .0, .0])]
        expected = [[(4, 6), (0.1, 0.1)]]

        corrected = ctmc.datacorrection(raw)
        npt.assert_allclose(flatten(corrected), flatten(expected))
コード例 #3
0
    def test1(self):
        """A trajectory with a single observation is expected to be dropped.

        ``([1], [0.5])`` contains only one state, so the expected result of
        ``ctmc.datacorrection`` is just the second trajectory, unchanged.
        """
        sol = [[(2, 3), (1.6, 2.7)]]

        test = [([1], [0.5]), ([2, 3], [1.6, 2.7])]
        res = ctmc.datacorrection(test)

        # Compare flattened values, as the sibling tests in this file do.
        npt.assert_allclose(flatten(res), flatten(sol))
コード例 #4
0
ファイル: panelctmc_func.py プロジェクト: kmedian/panelctmc
def panelctmc(
    paneldata: np.ndarray,
    mapping: list,
    lastdate: datetime = None,
    transintv: float = 1.0,
    toltime: float = 1e-8,
    debug: bool = True
) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray, list):
    """Estimate a continuous-time Markov chain from panel data.

    Pipeline: validate and copy ``paneldata``, parse date strings, encode the
    state labels, convert the panel to a ctmc datalist, auto-correct that
    datalist, and fit the CTMC on it.

    Parameters
    ----------
    paneldata : np.ndarray
        2-D panel. Column 1 holds observation dates, either as ``datetime``
        objects or as ``"%Y-%m-%d"`` strings (each cell is handled
        individually, so the two may be mixed). Column 2 holds the raw state
        labels. Column 0 is presumably the entity/ID column consumed by
        ``panel_to_datalist`` -- TODO confirm. The caller's array is not
        modified; a ``dtype=object`` copy is processed instead.
    mapping : list
        Label grouping passed to ``grouplabelencode``. ``len(mapping) + 1``
        is passed to ``ctmc`` as the state count (the extra state presumably
        corresponds to ``nastate=True`` -- verify).
    lastdate : datetime or None
        Forwarded unchanged to ``panel_to_datalist``.
    transintv, toltime, debug
        Forwarded unchanged to ``ctmc``.

    Returns
    -------
    tuple
        ``(transmat, genmat, transcount, statetime, datalist)``: the four
        results of ``ctmc`` followed by the corrected datalist it was fitted
        on.

    Raises
    ------
    TypeError
        If ``paneldata`` is not a numpy array. (``TypeError`` subclasses
        ``Exception``, so existing ``except Exception`` handlers still work.)
    """
    if not isinstance(paneldata, np.ndarray):
        raise TypeError("'paneldata' is not a numpy array")

    # astype() returns a new array by default, so the caller's data is never
    # mutated below; dtype=object lets column 1 hold datetime objects.
    paneldata = paneldata.astype(dtype=object)

    # Parse date strings cell by cell. Checking only the first row (as
    # before) raised IndexError on an empty panel and broke on columns that
    # mix str and datetime values.
    paneldata[:, 1] = [
        datetime.strptime(p, "%Y-%m-%d") if isinstance(p, str) else p
        for p in paneldata[:, 1]
    ]

    # encode state labels
    paneldata[:, 2] = grouplabelencode(paneldata[:, 2], mapping, nastate=True)

    # convert panel data to ctmc-datalist object
    datalist = panel_to_datalist(paneldata, lastdate=lastdate)

    # auto correct datalist
    datalist = datacorrection(datalist)

    # compute transition matrix, generator matrix, counts and state times
    transmat, genmat, transcount, statetime = ctmc(datalist,
                                                   len(mapping) + 1,
                                                   transintv=transintv,
                                                   toltime=toltime,
                                                   debug=debug)

    return transmat, genmat, transcount, statetime, datalist