Пример #1
0
def calculate_DER(meeting):
    """Score one saved HMM run against its ground-truth speaker labels.

    Args:
        meeting: path to a pickle of a fitted HMM states object; the loaded
            object must expose a ``stateseq`` attribute. ``get_meeting_base``
            maps this path to the name of the label CSV inside ``labels_dir``.

    Returns:
        The diarization error produced by ``hdphmm_utils.find_error_rate``.

    As a side effect, the best-aligned predicted sequence is handed to
    ``hdphmm_utils.plot_pred_labels``.
    """
    base_name = get_meeting_base(meeting)

    # Ground truth lives in the 'combined' column of the per-meeting CSV.
    label_path = os.path.join(labels_dir, base_name)
    reference = pd.read_csv(label_path)['combined']

    # NOTE(review): unpickling can execute arbitrary code; only load files
    # produced by this pipeline.
    with open(meeting, 'rb') as pickled:
        saved_states = cPickle.load(pickled)
    predicted = saved_states.stateseq

    error_rate, aligned_prediction, _unused = hdphmm_utils.find_error_rate(
        predicted, reference)

    hdphmm_utils.plot_pred_labels(base_name, aligned_prediction, labels_dir)
    return error_rate
Пример #2
0
def find_states(features_file):
    """Fit a sticky HDP-HMM to one meeting's features and report its DER.

    Steps:
        1. Read a comma-separated feature matrix from ``features_file``.
        2. Draw the sticky HDP-HMM concentration parameters from fixed priors.
        3. Run ``num_iterations`` resampling sweeps of
           ``pyhsmm.models.WeakLimitStickyHDPHMM``.
        4. Pickle the resulting states object into ``output_dir``.
        5. Score it against the ground-truth labels in ``labels_dir``.

    Relies on names defined outside this function: ``labels_dir``,
    ``output_dir``, ``plot_dir``, ``Nmax`` (number of candidate states /
    emission distributions), ``num_iterations``, ``progprint_xrange``, and
    ``gamma`` / ``beta`` samplers (presumably ``numpy.random.gamma`` /
    ``numpy.random.beta`` with (shape, scale) / (a, b) arguments -- confirm).

    Args:
        features_file: path to a CSV feature file. Its basename is also used
            as the label CSV name and as the stem of the pickled output.

    Returns:
        The diarization error rate as a float, rounded to three decimals.
        This is the same value that is printed and passed to the plot helper.
    """
    meeting_base = os.path.basename(features_file)
    print(meeting_base)

    data = np.genfromtxt(features_file, delimiter=',')
    # Requires a 2-D (frames x features) matrix; a 1-D result raises here.
    obs_dim = data.shape[1]

    labels = pd.read_csv(os.path.join(labels_dir, meeting_base))
    true_seq = labels['combined']

    ##########################
    #     Sticky-HDP-HMM     #
    ##########################

    # Concentration parameters drawn from priors taken from Fox 2012.
    gamma_draw = gamma(12, 2)
    alpha_plus_kappa_draw = gamma(6, 1)
    rho_draw = beta(500, 5)

    # With rho = kappa / (alpha + kappa), kappa and alpha follow
    # deterministically from the draws for alpha + kappa and rho.
    kappa = rho_draw * alpha_plus_kappa_draw
    alpha = (1 - rho_draw) * alpha_plus_kappa_draw
    print('kappa: {}, alpha: {}, gamma: {}'.format(kappa, alpha, gamma_draw))

    # One Gaussian emission distribution per candidate state, all sharing a
    # zero-mean, identity-scale prior.
    obs_hypparams = {'mu_0': np.zeros(obs_dim),
                     'sigma_0': np.eye(obs_dim),
                     'kappa_0': 0.3,
                     'nu_0': obs_dim + 5}
    obs_distns = [pyhsmm.distributions.Gaussian(**obs_hypparams)
                  for _ in range(Nmax)]

    # The concentrations are passed as concrete values, not as hyperprior
    # parameters (alpha_a_0 / gamma_a_0 style arguments).
    posteriormodel = pyhsmm.models.WeakLimitStickyHDPHMM(
        kappa=kappa, alpha=alpha, gamma=gamma_draw,
        init_state_concentration=6.,
        obs_distns=obs_distns)
    posteriormodel.add_data(data)

    num_cpu = 0  # forwarded unchanged as num_procs to resample_model
    for _ in progprint_xrange(num_iterations):
        posteriormodel.resample_model(num_procs=num_cpu)

    # Persist the state sequence of the single data sequence added above.
    hmm_states = posteriormodel.states_list[0]

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    with open(os.path.join(output_dir, meeting_base + '.pickle'), 'wb') as outf:
        cPickle.dump(hmm_states, outf)

    num_states = len(set(hmm_states.stateseq_norep))
    # NOTE(review): /10.0 presumably converts frame counts to seconds
    # (10 frames per second) -- confirm against the feature extraction.
    average_duration = np.average(hmm_states.durations) / 10.0
    print('Num States: {}'.format(num_states))
    print('Average Duration: {}'.format('{0:.2f}'.format(average_duration)))

    diarization_error, best_seq, _ = hdphmm_utils.find_error_rate(
        hmm_states.stateseq, true_seq)
    # The printout, the plot helper and the return value all use the same
    # 3-decimal rendering of the error rate.
    der_str = '{0:.3f}'.format(diarization_error)
    print('DER: {}'.format(der_str))

    hdphmm_utils.plot_pred_labels(meeting_base, best_seq, labels_dir,
                                  plot_dir, der_str)
    print('')

    return float(der_str)