def em_hmm(scores, ages, ref_lifespans, ref_ages, ref_scores, ref_states, age_sigma,
           score_sigma=None, lifespan_sigma=None, dead_sigma=None, em_iters=5,
           update_p_obs=True, update_p_trans=True):
    """Decode state paths with Viterbi, optionally refining the HMM by hard EM.

    An initial HMMParameterEstimator is built from the reference data. Each
    iteration decodes every series in `scores` with hmm.viterbi. If neither
    update flag is set, only one decoding pass is made. Otherwise, after each
    pass:
      - the decoded states become new reference data (lifespans via
        states_to_lifespans/cleanup_lifespans, observations via make_ref);
      - a new estimator is built from that data;
      - the requested parameters are re-estimated.
    Iteration stops early once a pass reproduces the previous pass's states
    exactly.

    Parameters:
        scores: observation series; each entry of p_obses derived from it is
            decoded separately.
        ages: time points of the observations; ages[0] is used for the
            initial state probabilities.
        ref_lifespans, ref_ages, ref_scores, ref_states: reference data for
            the first estimator.
        age_sigma, score_sigma, lifespan_sigma, dead_sigma: passed to the
            first HMMParameterEstimator (only age_sigma is reused afterwards).
        em_iters: maximum number of decode/re-estimate iterations; must be
            >= 1 unless both update flags are False (then forced to 1).
        update_p_obs: re-estimate observation probabilities between passes.
        update_p_trans: re-estimate initial and transition probabilities
            between passes.

    Returns:
        (states, p_initial, p_transition): states is a numpy array holding one
        decoded path per series. If the loop stops on convergence, the
        parameters are those that produced `states`. If em_iters is exhausted
        first, they are the estimates computed after the final pass.

    Raises:
        ValueError: if em_iters < 1 while an update flag is set.
    """
    if not (update_p_trans or update_p_obs):
        # Nothing is re-estimated, so further passes would just repeat the first.
        em_iters = 1
    if em_iters < 1:
        # At least one Viterbi pass is needed to have states to return.
        raise ValueError(f"em_iters must be >= 1, got {em_iters}")

    estimator = HMMParameterEstimator(ref_lifespans, ref_ages, ref_scores, ref_states,
                                      age_sigma, score_sigma, lifespan_sigma, dead_sigma)
    p_obses = estimator.p_obs(scores, ages)
    p_initial = estimator.p_initial(ages[0])
    p_transition = estimator.p_transition(ages)

    prev_states = None
    for _ in range(em_iters):
        states = numpy.array([hmm.viterbi(p_obs, p_transition, p_initial) for p_obs in p_obses])
        if prev_states is not None and (states != prev_states).sum() == 0:
            break  # converged: this pass reproduced the previous one exactly
        prev_states = states
        if update_p_obs or update_p_trans:
            lifespans = cleanup_lifespans(states_to_lifespans(states, ages), ages)
            new_ref_ages, new_ref_scores, new_ref_states = make_ref(ages, scores, states)
            # NOTE(review): the caller's score_sigma / lifespan_sigma / dead_sigma are
            # used only for the first estimator; re-estimation always passes None.
            # Confirm this is intentional.
            estimator = HMMParameterEstimator(lifespans, new_ref_ages, new_ref_scores,
                                              new_ref_states, age_sigma, score_sigma=None,
                                              lifespan_sigma=None, dead_sigma=None)
            if update_p_obs:
                p_obses = estimator.p_obs(scores, ages)
            if update_p_trans:
                p_initial = estimate_p_initial(states)
                p_transition = estimator.p_transition(ages)
    return states, p_initial, p_transition
def simple_hmm(scores, ages, ref_lifespans, ref_ages, ref_scores, ref_states, lifespan_sigma=None):
    """Decode each score series with a single Viterbi pass (no re-estimation).

    Observation probabilities are fitted from the reference scores split by
    reference state (0 = dead, 1 = live). Initial and transition probabilities
    come from a SimpleHMMEstimator built on the reference lifespans.

    Parameters:
        scores: iterable of score series; each is decoded independently.
        ages: time points of the observations; ages[0] sets the initial
            state probabilities.
        ref_lifespans: reference lifespans for SimpleHMMEstimator.
        ref_ages: not used by this function (the signature parallels em_hmm).
        ref_scores, ref_states: reference observations and their states;
            ref_scores is indexed with boolean masks built from ref_states.
        lifespan_sigma: passed through to SimpleHMMEstimator.

    Returns:
        (states, p_initial, p_transition): a numpy array of decoded state
        paths (one per series) and the HMM parameters used to decode them.
    """
    live_mask = ref_states == 1
    dead_mask = ref_states == 0
    # The list order fixes the state indices: dead examples first, then live.
    observation_model = hmm.ObservationProbabilityEstimator(
        [ref_scores[dead_mask], ref_scores[live_mask]]
    )
    p_obses = list(map(observation_model, scores))

    lifespan_model = SimpleHMMEstimator(ref_lifespans, lifespan_sigma)
    p_initial = lifespan_model.p_initial(ages[0])
    p_transition = lifespan_model.p_transition(ages)

    decoded = [hmm.viterbi(p_obs, p_transition, p_initial) for p_obs in p_obses]
    return numpy.array(decoded), p_initial, p_transition
def annotate_dead(scores, scores_live, scores_dead, transition_time_sigma=3, max_iters=5):
    """Label each time point of each score series as dead (0) or live (1).

    Observation probabilities come from hmm.ObservationProbabilityEstimator
    fitted on example dead and live scores, and they stay fixed. The first
    pass decodes with hand-set initial/transition probabilities. Each later
    pass re-estimates them from the previous pass's states with
    hmm.estimate_hmm_params, then forces the dead state to be absorbing.
    Iteration stops as soon as a pass reproduces the previous pass's states.

    Parameters:
        scores: iterable of score series; each is decoded independently.
        scores_live: example scores from live individuals.
        scores_dead: example scores from dead individuals.
        transition_time_sigma: time_sigma passed to hmm.estimate_hmm_params.
        max_iters: maximum number of decoding passes; must be >= 1.

    Returns:
        (states, p_initial, p_transition): a numpy array with one decoded
        state path per series (0 = dead, 1 = live), plus the initial and
        transition probabilities that produced it.

    Raises:
        ValueError: if max_iters < 1.
    """
    if max_iters < 1:
        # At least one pass is needed to have states to return.
        raise ValueError(f"max_iters must be >= 1, got {max_iters}")

    # State 0 = dead, state 1 = live (the order of the example lists below).
    estimator = hmm.ObservationProbabilityEstimator([scores_dead, scores_live])
    p_obses = [estimator(score_series) for score_series in scores]

    prev_states = None  # no previous pass yet, so the first pass cannot "converge"
    for i in range(max_iters):
        if i == 0:
            # Rows are the "from" state: dead -> dead with certainty,
            # live -> dead/live 50/50.
            p_transition = [[1, 0], [0.5, 0.5]]
            p_initial = [0.1, 0.9]
        else:
            p_initial, p_transition = hmm.estimate_hmm_params(
                prev_states, pseudocount=1, time_sigma=transition_time_sigma)
            # Sternly enforce that the dead stay dead: the from-dead row becomes [1, 0].
            # For a time-indexed (T, 2, 2) array this matches the former
            # `p_transition[:, 0] = [1, 0]`. The ellipsis also handles a plain
            # (2, 2) matrix, where `[:, 0]` would instead have forbidden live -> dead.
            # NOTE(review): assumes the layout is [..., from, to], as in the i == 0
            # matrix; confirm against hmm.estimate_hmm_params.
            p_transition[..., 0, :] = [1, 0]
        states = numpy.array([hmm.viterbi(p_obs, p_transition, p_initial) for p_obs in p_obses])
        converged = prev_states is not None and numpy.array_equal(states, prev_states)
        prev_states = states
        if converged:
            break
    return states, p_initial, p_transition