def em_hmm(scores, ages, ref_lifespans, ref_ages, ref_scores, ref_states, age_sigma, score_sigma=None, lifespan_sigma=None, dead_sigma=None, em_iters=5, update_p_obs=True, update_p_trans=True):
    """Decode per-series HMM states, optionally refining parameters by hard (Viterbi) EM.

    Observation, initial-state and transition probabilities are first taken from an
    HMMParameterEstimator built on the reference data. Each iteration runs Viterbi on
    every series. If the decoded states are identical to the previous iteration's, the
    loop stops. Otherwise the decoded states are turned into lifespans and a new
    reference set (via make_ref). A fresh estimator built from them then refreshes the
    observation probabilities (if update_p_obs) and/or the initial and transition
    probabilities (if update_p_trans).

    Parameters:
        scores: score series to decode, passed to estimator.p_obs (shape not visible
            here; presumably (n_series, n_timepoints) -- confirm).
        ages: age/time value per timepoint; ages[0] seeds the initial distribution.
        ref_lifespans, ref_ages, ref_scores, ref_states: reference data for the initial
            estimator.
        age_sigma, score_sigma, lifespan_sigma, dead_sigma: smoothing widths forwarded
            to HMMParameterEstimator, both initially and on every re-estimation.
            (None presumably selects the estimator's own default -- confirm.)
        em_iters: maximum number of decoding passes; at least one is always done.
        update_p_obs, update_p_trans: which parameters to re-estimate. If both are
            False, only a single decoding pass is performed.

    Returns:
        (states, p_initial, p_transition). states is a numpy array stacking one Viterbi
        path per series. If the loop stops because em_iters is exhausted rather than by
        convergence, and update_p_trans is True, the returned p_initial/p_transition
        were re-estimated from the final states rather than used to decode them.
    """
    estimator = HMMParameterEstimator(ref_lifespans, ref_ages, ref_scores, ref_states, age_sigma, score_sigma, lifespan_sigma, dead_sigma)
    p_obses = estimator.p_obs(scores, ages)
    p_initial = estimator.p_initial(ages[0])
    p_transition = estimator.p_transition(ages)
    if not (update_p_trans or update_p_obs):
        # Nothing would be re-estimated, so a second pass would decode identically.
        em_iters = 1
    # At least one pass is required; otherwise `states` is unbound at the return.
    em_iters = max(em_iters, 1)
    for i in range(em_iters):
        states = numpy.array([hmm.viterbi(p_obs, p_transition, p_initial) for p_obs in p_obses])
        if i > 0:
            diffs = (states != prev_states).sum()
            if diffs == 0:
                break
        prev_states = states
        if update_p_obs or update_p_trans:
            lifespans = cleanup_lifespans(states_to_lifespans(states, ages), ages)
            ref_ages, ref_scores, ref_states = make_ref(ages, scores, states)
            # Forward the caller's sigmas; previously these were hard-coded to None here,
            # silently discarding the caller's settings after the first iteration.
            estimator = HMMParameterEstimator(lifespans, ref_ages, ref_scores, ref_states, age_sigma, score_sigma, lifespan_sigma, dead_sigma)
            if update_p_obs:
                p_obses = estimator.p_obs(scores, ages)
            if update_p_trans:
                p_initial = estimate_p_initial(states)
                p_transition = estimator.p_transition(ages)
    return states, p_initial, p_transition
def simple_hmm(scores, ages, ref_lifespans, ref_ages, ref_scores, ref_states, lifespan_sigma=None):
    """Decode each score series with a single Viterbi pass (no parameter refinement).

    Observation probabilities come from an hmm.ObservationProbabilityEstimator. It is
    fitted on the reference scores split by reference state: state 0 first, then state 1.
    Initial and transition probabilities come from
    SimpleHMMEstimator(ref_lifespans, lifespan_sigma).

    Note: ref_ages is accepted but not used by this function (presumably kept for
    signature parity with em_hmm -- confirm).

    Returns:
        (states, p_initial, p_transition), where states is a numpy array stacking one
        Viterbi path per series in `scores`.
    """
    live_mask = ref_states == 1
    dead_mask = ref_states == 0
    obs_model = hmm.ObservationProbabilityEstimator([ref_scores[dead_mask], ref_scores[live_mask]])
    obs_probs = [obs_model(series) for series in scores]

    trans_model = SimpleHMMEstimator(ref_lifespans, lifespan_sigma)
    p_initial = trans_model.p_initial(ages[0])
    p_transition = trans_model.p_transition(ages)

    paths = [hmm.viterbi(probs, p_transition, p_initial) for probs in obs_probs]
    return numpy.array(paths), p_initial, p_transition
def annotate_dead(scores, scores_live, scores_dead, transition_time_sigma=3, max_iters=5):
    """Iteratively decode dead (state 0) / live (state 1) labels for each score series.

    Observation probabilities come from an hmm.ObservationProbabilityEstimator fitted on
    example dead and live scores, in that order.

    The first pass uses a fixed prior:
        p_initial = [0.1, 0.9]
        p_transition = [[1, 0], [0.5, 0.5]]
    Later passes re-estimate both from the previous pass's states with
    hmm.estimate_hmm_params (pseudocount=1, time_sigma=transition_time_sigma). They then
    force the from-dead transitions to [1, 0], so dead series stay dead. Iteration stops
    when the decoded states stop changing, or after max_iters passes.

    Parameters:
        scores: score series to decode; must expose `.shape` (e.g. a numpy array)
            because the initial comparison array is sized from it.
        scores_live, scores_dead: example scores for each state.
        transition_time_sigma: time smoothing passed to hmm.estimate_hmm_params.
        max_iters: maximum number of decoding passes; at least one is always done.

    Returns:
        (states, p_initial, p_transition): the decoded states (numpy array, one
        Viterbi path per series) and the parameters used in the final pass. After a
        single pass p_transition is the nested-list prior; otherwise it is whatever
        hmm.estimate_hmm_params returned, with the dead row overwritten.
    """
    estimator = hmm.ObservationProbabilityEstimator([scores_dead, scores_live])
    p_obses = [estimator(score_series) for score_series in scores]
    # All-dead starting point; if the first pass also decodes all-dead, we stop there.
    prev_states = numpy.zeros(scores.shape, dtype=numpy.uint16)
    # At least one pass is required; otherwise `states` is unbound at the return.
    for i in range(max(max_iters, 1)):
        if i == 0:
            # state 0 = dead, state 1 = live
            p_transition = [[1, 0], [0.5, 0.5]]
            p_initial = [0.1, 0.9]
        else:
            p_initial, p_transition = hmm.estimate_hmm_params(prev_states, pseudocount=1, time_sigma=transition_time_sigma)
            # Sternly enforce that the dead stay dead.
            # NOTE(review): assumes p_transition is indexed [time, from_state, to_state],
            # so [:, 0] is the from-dead row at every timepoint -- confirm against hmm.
            p_transition[:, 0] = [1, 0]
        states = numpy.array([hmm.viterbi(p_obs, p_transition, p_initial) for p_obs in p_obses])
        diffs = (states != prev_states).sum()
        prev_states = states
        if diffs == 0:
            break
    return states, p_initial, p_transition