def extract_false_positives(num_mix_params,first_pass_fnrs):
    for num_mix in num_mix_params:
        affinities = np.load('data/%d_affinities.npy' % (num_mix))
        for fnr in first_pass_fnrs:
            in_file = open('data/false_pos_times_%d_%d.pkl' % (num_mix,fnr),'rb')
            false_pos_times=pickle.load(in_file)
            in_file.close()
            false_positives = rf.get_false_positives(false_pos_times,
                                                     S_config=sp,
                                                     E_config=ep,
                                                     offset=0,
                                                     waveform_offset=7)
            example_mat = gtrd.recover_example_map(false_positives)
            (false_pos_assigned_phns,false_pos_phn_contexts,
             utt_paths,file_idx,
             start_ends) = gtrd.recover_assigned_phns(false_positives,example_mat)
            np.savez('data/false_pos_phns_assigned_contexts_%d_%d.npz' % (num_mix,fnr),
                     example_mat=example_mat,
                     assigned_phns=false_pos_assigned_phns,
                     phn_contexts=false_pos_phn_contexts,
                     utt_paths=utt_paths,
                     file_idx=file_idx,
                     start_ends=start_ends)
            np.save('data/false_positives_example_mat_%d_%d.npy' % (num_mix,fnr),example_mat)
            lengths,waveforms  = gtrd.recover_waveforms(false_positives,example_mat)
            np.savez('data/false_positives_waveforms_lengths_%d_%d.npz' % (num_mix,fnr),
                     lengths=lengths,
                     waveforms=waveforms)
            Slengths,Ss  = gtrd.recover_specs(false_positives,example_mat)
            np.savez('data/false_positives_Ss_lengths_%d_%d.npz' % (num_mix,fnr),
                     lengths=Slengths,
                     Ss=Ss)
            Elengths,Es  = gtrd.recover_edgemaps(false_positives,example_mat)
            np.savez('data/false_positives_Es_lengths_%d_%d.npz' % (num_mix,fnr),
                     lengths=Elengths,
                     Es=Es)
            templates = et.recover_different_length_templates(affinities,
                                                              Es,
                                                              Elengths)
            spec_templates = et.recover_different_length_templates(affinities,
                                                                   Ss,
                                                                   Slengths)
            cluster_counts = np.zeros(num_mix,dtype=int)
            for mix_id in xrange(num_mix):
                os.system("mkdir -p data/%d_%d_%d" % (num_mix,fnr,mix_id))
            for example_id in xrange(waveforms.shape[0]):
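                # only keep near-hard-assigned examples: affinity > .999
                # means the posterior puts essentially all of its mass on a
                # single component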
                if affinities[example_id].max() > .999:
                    cluster_id = np.argmax(affinities[example_id])
                    wavfile.write('data/%d_%d_%d/%d.wav' % (num_mix,fnr,cluster_id,
                                                            cluster_counts[cluster_id]),
                                  16000,
                                  ((2**15-1)*waveforms[example_id]).astype(np.int16))
                    cluster_counts[cluster_id] += 1
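
# hedged example invocation (the false-negative-rate settings below are
# illustrative assumptions, not values taken from this experiment):
# extract_false_positives([2,3,4,5,6,7,8,9],[5,10,20,40])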

np.save('data/aar1_padded_examplesE.npy',Es)



import matplotlib.pyplot as plt

num_mix_params = [2,3,4,5,6,7,8,9]
for num_mix in num_mix_params:
    print num_mix
    bem = bm.BernoulliMixture(num_mix,Es)
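    # the argument to run_EM is assumed to be the convergence tolerance:
    # EM stops once the log-likelihood improvement falls below it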
    bem.run_EM(.000001)
    templates = et.recover_different_length_templates(bem.affinities,
                                                      Es,
                                                      Elengths)
    spec_templates = et.recover_different_length_templates(bem.affinities,
                                                           Ss,
                                                           Slengths)
    num_plots = len(spec_templates)
    num_cols = 2
    num_rows = num_plots/num_cols + 1
    for i in xrange(num_plots):
        plt.subplot(num_rows,num_cols,i+1)
        plt.imshow(spec_templates[i].T[::-1],interpolation='nearest')
    plt.savefig('aar1_spec_templates_%d.png' % num_mix)
    np.savez('aar1_templates_%d.npz' % num_mix, *templates)
    np.save('aar1_affinities_%d.npy' % num_mix, bem.affinities)

#
# we are now going to do the clustering procedure on these
#

aar_examples = np.load(tmp_data_path+'aar_examples.npy')
aar_lengths = np.load(tmp_data_path + 'aar_lengths.npy')
clipped_bgd = np.load(tmp_data_path+'clipped_train_bgd.npy')

aar_bgd_padded = et.pad_examples_bgd_samples(aar_examples,aar_lengths,clipped_bgd)
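# pad_examples_bgd_samples is assumed to right-pad each variable-length
# example out to the maximum length using samples drawn from the clipped
# background model, giving one rectangular array for the Bernoulli mixture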


bem2 = bm.BernoulliMixture(2,aar_bgd_padded)
bem2.run_EM(.000001)

aar2 = et.recover_different_length_templates(bem2.affinities,aar_examples,aar_lengths)

for i in xrange(2):
    np.save(tmp_data_path + 'aar2_%d.npy' % i,aar2[i])

aar_mixture = tuple(
    np.load(tmp_data_path+'aar2_%d.npy' % i)
    for i in xrange(2))

# now we will get the detection array associated with these two

test_example_lengths = np.load(tmp_data_path+'test_example_lengths.npy')

detection_array = np.zeros((test_example_lengths.shape[0],
                            int(test_example_lengths.max()/float(log_part_blocks.shape[1]) + .5) + 2),dtype=np.float32)
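# detection_array: one score row per test utterance; the width is the
# longest utterance length divided by the part-block frame step
# (log_part_blocks.shape[1]), rounded to the nearest frame, plus 2 frames
# of slack at the edge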
def perform_phn_template_estimation(phn,utterances_path,
                                    file_indices,sp,ep,
                                    num_mix_params,
                                    phn_mapping=None,
                                    waveform_offset=15,
                                    chunk_length=1000):
    phn_tuple = (phn,)
    print phn
    phn_features,avg_bgd=gtrd.get_syllable_features_directory(utterances_path,file_indices,phn_tuple,
                                                              S_config=sp,E_config=ep,offset=0,
                                                              E_verbose=False,return_avg_bgd=True,
                                                              waveform_offset=waveform_offset,
                                                              phn_mapping=phn_mapping)
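    # clip the averaged background edge probabilities away from 0 and 1 so
    # that later log-likelihood computations stay finite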
    bgd = np.clip(avg_bgd.E,.01,.99)
    np.save('data/bgd.npy',bgd)
    example_mat = gtrd.recover_example_map(phn_features)
    lengths,waveforms  = gtrd.recover_waveforms(phn_features,example_mat)
    np.savez('data/waveforms_lengths.npz',waveforms=waveforms,
             lengths=lengths,
         example_mat=example_mat)
    Slengths,Ss  = gtrd.recover_specs(phn_features,example_mat)
    Ss = Ss.astype(np.float32)
    np.savez('data/Ss_lengths.npz' ,Ss=Ss,Slengths=Slengths,example_mat=example_mat)
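    # edge maps are binary indicators, so uint8 storage suffices and keeps
    # the examples-by-frames-by-features array compact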
    Elengths,Es  = gtrd.recover_edgemaps(phn_features,example_mat,bgd=bgd)
    Es = Es.astype(np.uint8)
    np.savez('data/Es_lengths.npz' ,Es=Es,Elengths=Elengths,example_mat=example_mat)
    # the Es are padded from recover_edgemaps
    f = open('data/mixture_estimation_stats_%s.data' % phn,'w')
    for num_mix in num_mix_params:
        print num_mix
        if num_mix == 1:
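            # single-component case: every example has affinity 1 to the
            # lone component, and the template is just the elementwise mean
            # truncated to the (rounded) average example length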
            affinities = np.ones((Es.shape[0],1),dtype=np.float64)
            mean_length = int(np.mean(Elengths) + .5)
            templates = (np.mean(Es,0)[:mean_length],)
            spec_templates = (np.mean(Ss,0)[:mean_length],)
            np.save('data/%d_affinities.npy' % (num_mix),
                    affinities)
            np.save('data/%d_templates.npy' % (num_mix),
                    templates)
            np.save('data/%d_spec_templates.npy' % (num_mix),
                    spec_templates)
            np.save('data/%d_templates_%s.npy' % (num_mix,phn),
                    templates)
            np.save('data/%d_spec_templates_%s.npy' % (num_mix,phn),
                    spec_templates)
            #
            # write the data to the mixture file for checking purposes
            # format is:
            #   num_components total c0 c1 c2 ... ck
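            #   e.g. a hypothetical line for num_mix=2 over 1500 examples:
            #   "2 1500 733.2 766.8"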
            f.write('%d %d %g\n' % (num_mix,
                                  len(affinities),np.sum(affinities[:,0])))
        else:
            if len(Es) > chunk_length:
                # more examples than chunk_length: warm-start EM on the
                # first chunk, then continue EM chunk by chunk so every
                # example is visited
                bem = bm.BernoulliMixture(num_mix,Es[:chunk_length])
                bem.run_EM(.000001)
                start_idx = 0
                end_idx = chunk_length
                for i in xrange(1,len(Es)/chunk_length):
                    start_idx = i*chunk_length
                    end_idx = start_idx + chunk_length
                    bem.data_mat = Es[start_idx:end_idx].reshape(
                        chunk_length,bem.data_length)
                    bem.run_EM(.000001)
                if end_idx < len(Es):
                    # final partial chunk: back start_idx up so the block
                    # is still a full chunk_length wide
                    end_idx = len(Es)
                    start_idx = len(Es)-chunk_length
                    bem.data_mat = Es[start_idx:end_idx].reshape(
                        chunk_length,bem.data_length)
                    bem.run_EM(.000001)
            else:
                bem = bm.BernoulliMixture(num_mix,Es)
                bem.run_EM(.000001)
                start_idx = 0
                end_idx = len(Es)
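            # bem.affinities corresponds to the last block that EM was run
            # on, so the matching slice of the examples is used to recover
            # the templates below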
            templates = et.recover_different_length_templates(bem.affinities,
                                                              Es[start_idx:end_idx],
                                                              Elengths[start_idx:end_idx])
            spec_templates = et.recover_different_length_templates(bem.affinities,
                                                                   Ss[start_idx:end_idx],
                                                                   Slengths[start_idx:end_idx])
            np.save('data/%d_affinities.npy' % (num_mix),
                    bem.affinities)
            np.savez('data/%d_templates.npz' % (num_mix),
                    *templates)
            np.savez('data/%d_spec_templates.npz' % (num_mix),
                    *spec_templates)
            np.savez('data/%d_templates_%s.npz' % (num_mix,phn),
                    *templates)
            np.savez('data/%d_spec_templates_%s.npz' % (num_mix,phn),
                    *spec_templates)
            f.write('%d %d ' % (num_mix,len(bem.affinities))
                    + ' '.join(str(np.sum(bem.affinities[:,i]))
                               for i in xrange(bem.affinities.shape[1]))
                    + '\n')
    f.close()
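
# hedged example call (the phone label, directory layout, and mixture sizes
# here are illustrative assumptions):
# perform_phn_template_estimation('aa',data_path+'Train/',file_indices,
#                                 sp,ep,[1,2,3,4,5,6,7,8,9])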
del backgrounds


# estimate mixture models
#
import template_speech_rec.bernoulli_mixture as bm

mixture_models_syllable = {}
for syllable, examples in padded_examples_syllable_dict.items():
    mixture_models_syllable[syllable] = bm.BernoulliMixture(2,examples)
    mixture_models_syllable[syllable].run_EM(.000001)
    print syllable

template_tuples_syllable = {}
for syllable, mm in mixture_models_syllable.items():
    template_tuples_syllable[syllable] = et.recover_different_length_templates(
        mm.affinities,
        padded_examples_syllable_dict[syllable],
        np.array([e.shape[0] for e in syllable_examples[syllable]]))
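    # the lengths passed here are the original (pre-padding) example
    # lengths, so each recovered template is truncated back to real frames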


for syllable, template_tuples in template_tuples_syllable.items():
    for i, template in enumerate(template_tuples):
        np.save(tmp_data_path+'template_%s_%s__%d_%d.npy' % (syllable[0],
                                                             syllable[1],
                                                             len(template_tuples),
                                                             i),
                template)


test_example_lengths = gtrd.get_detect_lengths(data_path+'Test/')

detection_array = np.zeros((test_example_lengths.shape[0],
                            int(test_example_lengths.max()/float(log_part_blocks.shape[1]) + .5) + 2),dtype=np.float32)