Code Example #1
def create_mlf(indir, outfile, verbose):
    # pylint: disable-msg=E1103
    # get all transcriptions
    files = []
    files.append(glob.glob(os.path.join(indir, '*.wav')))
    files.append(glob.glob(os.path.join(indir, '*', '*.wav')))
    files.append(glob.glob(os.path.join(indir, '*', '*', '*.wav')))
    files.append(glob.glob(os.path.join(indir, '*', '*', '*', '*.wav')))
    files.append(glob.glob(os.path.join(indir, '*', '*', '*', '*', '*.wav')))
    files.append(
        glob.glob(os.path.join(indir, '*', '*', '*', '*', '*', '*.wav')))

    files = flatten(files)

    mlf = open(outfile, "w")
    mlf.write("#!MLF!#\n")
    size = 0
    for f in files:

        if verbose:
            print "Processing wav file: ", f

        wav_stats = subprocess.check_output("soxi %s" % f, shell=True)
        wav_stats = wav_stats.split('\n')

        mfc = f.replace('.wav', ".mfc")
        subprocess.check_output("HCopy -T 1 -C config -C configwav %s %s" %
                                (f, mfc),
                                shell=True)

        for l in wav_stats:
            if l.startswith('Duration'):
                l = l.split()
                time = [float(x) for x in l[2].split(':')]
                time = time[0] * 60 * 60 + time[1] * 60 + time[2]

                # convert time into HTK 100ns units
                time = int(time * 10000000)

                size += time

        f = f.replace(".wav", ".lab").replace("data/", "*/")
        mlf.write('"%s"\n' % f)
        mlf.write('0 %d sil\n' % time)
        mlf.write(".\n")

    mlf.close()

    hour = size / 10000000 / 3600.0

    print "Length of audio data in hours:", hour
Code Example #2
def create_mlf(indir, outfile, verbose):
    # pylint: disable-msg=E1103
    # get all transcriptions
    files = []
    files.append(glob.glob(os.path.join(indir, "*.wav")))
    files.append(glob.glob(os.path.join(indir, "*", "*.wav")))
    files.append(glob.glob(os.path.join(indir, "*", "*", "*.wav")))
    files.append(glob.glob(os.path.join(indir, "*", "*", "*", "*.wav")))
    files.append(glob.glob(os.path.join(indir, "*", "*", "*", "*", "*.wav")))
    files.append(glob.glob(os.path.join(indir, "*", "*", "*", "*", "*", "*.wav")))

    files = flatten(files)

    mlf = open(outfile, "w")
    mlf.write("#!MLF!#\n")
    size = 0
    for f in files:

        if verbose:
            print "Processing wav file: ", f

        wav_stats = subprocess.check_output("soxi %s" % f, shell=True)
        wav_stats = wav_stats.split("\n")

        # mfc = f.replace('.wav', ".mfc")
        # subprocess.check_output("HCopy -T 1 -C config -C configwav %s %s" % (f, mfc), shell=True)

        for l in wav_stats:
            if l.startswith("Duration"):
                l = l.split()
                time = [float(x) for x in l[2].split(":")]
                time = time[0] * 60 * 60 + time[1] * 60 + time[2]

                # convert time into HTK 100ns units
                time = int(time * 10000000)

                size += time

        f = f.replace(".wav", ".lab").replace("./data/", "*/")
        mlf.write('"%s"\n' % f)
        mlf.write("0 %d sil\n" % time)
        mlf.write(".\n")

    mlf.close()

    hour = size / 10000000 / 3600.0

    print "Length of audio data in hours:", hour
Code Example #3
File: build.py  Project: jakub-stejskal/alex
        print cmd
        exit_on_system_fail(cmd)

    if not os.path.exists(indomain_data_text_trn_norm):
        print "Generating train and dev data"
        print "-"*120
        ###############################################################################################

        files = []
        files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
        files = various.flatten(files)

        tt = []
        pt = []
        for fn in files:
#            print "Processing:", fn
            doc = xml.dom.minidom.parse(fn)
            turns = doc.getElementsByTagName("turn")
            
            for turn in turns:
                recs_list = turn.getElementsByTagName("rec")
                trans_list = turn.getElementsByTagName("asr_transcription")

                if trans_list:
                    trans = trans_list[-1]
Code Example #4
        print cmd
        os.system(cmd)

    if not os.path.exists(indomain_data_text_trn_norm):
        print "Generating train and dev data"
        print "-"*120
        ###############################################################################################

        files = []
        files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
        files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
        files = various.flatten(files)

        tt = []
        pt = []
        for fn in files:
#            print "Processing:", fn
            doc = xml.dom.minidom.parse(fn)
            turns = doc.getElementsByTagName("turn")
            
            for turn in turns:
                recs_list = turn.getElementsByTagName("rec")
                trans_list = turn.getElementsByTagName("asr_transcription")

                if trans_list:
                    trans = trans_list[-1]
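Examples #3 and #4 both take the last <asr_transcription> element of each turn but the excerpts end before showing how its text is read (the project uses a get_text_from_xml_node helper that is not shown here). A minimal sketch of such a helper using only xml.dom.minidom, as an assumption about the behaviour needed:

def get_node_text(node):
    """Concatenate the text children of a DOM node and strip whitespace."""
    parts = [child.data for child in node.childNodes
             if child.nodeType == child.TEXT_NODE]
    return ''.join(parts).strip()

# Hypothetical use in the loop above:
#   if trans_list:
#       trans = trans_list[-1]
#       tt.append(get_node_text(trans))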
Code Example #5
def main():

    global asr_log
    global num_workers

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""This program prepares data for training Alex PTIcs SLU.
      """)

    parser.add_argument('--num_workers',
                        action="store",
                        default=num_workers,
                        type=int,
                        help='number of workers used for ASR: default %d' %
                        num_workers)
    parser.add_argument('--asr_log',
                        action="store",
                        default=asr_log,
                        type=int,
                        help='use ASR results from logs: default %d' % asr_log)

    args = parser.parse_args()

    asr_log = args.asr_log
    num_workers = args.num_workers

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_nbl = 'all.nbl'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_nbl = 'train.nbl'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_nbl = 'dev.nbl'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_nbl = 'test.nbl'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-" * 120
    ###############################################################################################

    files = []
    files.append(
        glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(
        glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*',
                         'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*', '*',
                         'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*', '*', '*',
                         'asr_transcribed.xml')))
    files = various.flatten(files)

    files = files[:100000]
    asr = []
    nbl = []
    sem = []
    trn = []
    trn_hdc_sem = []

    p_process_call_logs = multiprocessing.Pool(num_workers)
    processed_cls = p_process_call_logs.imap_unordered(process_call_log, files)

    count = 0
    for pcl in processed_cls:
        count += 1
        #process_call_log(fn) # uniq utterances
        #print pcl

        print "=" * 80
        print "Processed files ", count, "/", len(files)
        print "=" * 80

        asr.extend(pcl[0])
        nbl.extend(pcl[1])
        sem.extend(pcl[2])
        trn.extend(pcl[3])
        trn_hdc_sem.extend(pcl[4])

    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem,
                  uniq_trn_hdc_sem,
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem,
                  dict(trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_all_asr, dict(asr))
    save_wavaskey(fn_all_nbl, dict(nbl))

    seed_value = 10

    random.seed(seed_value)
    random.shuffle(trn)
    random.seed(seed_value)
    random.shuffle(trn_hdc_sem)
    random.seed(seed_value)
    random.shuffle(asr)
    random.seed(seed_value)
    random.shuffle(nbl)

    # trn
    train_trn = trn[:int(0.8 * len(trn))]
    dev_trn = trn[int(0.8 * len(trn)):int(0.9 * len(trn))]
    test_trn = trn[int(0.9 * len(trn)):]

    save_wavaskey(fn_train_trn, dict(train_trn))
    save_wavaskey(fn_dev_trn, dict(dev_trn))
    save_wavaskey(fn_test_trn, dict(test_trn))

    # trn_hdc_sem
    train_trn_hdc_sem = trn_hdc_sem[:int(0.8 * len(trn_hdc_sem))]
    dev_trn_hdc_sem = trn_hdc_sem[int(0.8 *
                                      len(trn_hdc_sem)):int(0.9 *
                                                            len(trn_hdc_sem))]
    test_trn_hdc_sem = trn_hdc_sem[int(0.9 * len(trn_hdc_sem)):]

    save_wavaskey(fn_train_trn_hdc_sem,
                  dict(train_trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_dev_trn_hdc_sem,
                  dict(dev_trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_test_trn_hdc_sem,
                  dict(test_trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

    # asr
    train_asr = asr[:int(0.8 * len(asr))]
    dev_asr = asr[int(0.8 * len(asr)):int(0.9 * len(asr))]
    test_asr = asr[int(0.9 * len(asr)):]

    save_wavaskey(fn_train_asr, dict(train_asr))
    save_wavaskey(fn_dev_asr, dict(dev_asr))
    save_wavaskey(fn_test_asr, dict(test_asr))

    # n-best lists
    train_nbl = nbl[:int(0.8 * len(nbl))]
    dev_nbl = nbl[int(0.8 * len(nbl)):int(0.9 * len(nbl))]
    test_nbl = nbl[int(0.9 * len(nbl)):]

    save_wavaskey(fn_train_nbl, dict(train_nbl))
    save_wavaskey(fn_dev_nbl, dict(dev_nbl))
    save_wavaskey(fn_test_nbl, dict(test_nbl))
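The tail of this example repeats one pattern four times: reseed the random generator, shuffle a list in place, then cut it 80/10/10 into train/dev/test. Because every list is shuffled with the same seed and the lists are parallel and equal in length, entries at the same index stay aligned across trn, trn_hdc_sem, asr and nbl. A sketch of that pattern as a reusable helper (the function name is illustrative):

import random

def split_80_10_10(items, seed_value=10):
    """Shuffle a copy deterministically and return (train, dev, test) slices."""
    items = list(items)
    random.seed(seed_value)       # same seed for every list keeps them aligned
    random.shuffle(items)
    n = len(items)
    return (items[:int(0.8 * n)],
            items[int(0.8 * n):int(0.9 * n)],
            items[int(0.9 * n):])

# e.g. train_trn, dev_trn, test_trn = split_80_10_10(trn)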
Code Example #6
File: prepare_data.py  Project: zhangziliang04/alex
def main():
    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTIENSLUPreprocessing(cldb)
    slu = PTIENHDCSLU(preprocessing, cfg={'SLU': {PTIENHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoEN/data/utt2da_dict.txt")}}})
    cfg = Config.load_configs(['../kaldi.cfg',], use_default=True)
    asr_rec = asr_factory(cfg)                    

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_asr_hdc_sem = 'all.asr.hdc.sem'
    fn_all_nbl = 'all.nbl'
    fn_all_nbl_hdc_sem = 'all.nbl.hdc.sem'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_asr_hdc_sem = 'train.asr.hdc.sem'
    fn_train_nbl = 'train.nbl'
    fn_train_nbl_hdc_sem = 'train.nbl.hdc.sem'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_asr_hdc_sem = 'dev.asr.hdc.sem'
    fn_dev_nbl = 'dev.nbl'
    fn_dev_nbl_hdc_sem = 'dev.nbl.hdc.sem'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_asr_hdc_sem = 'test.asr.hdc.sem'
    fn_test_nbl = 'test.nbl'
    fn_test_nbl_hdc_sem = 'test.nbl.hdc.sem'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-"*120
    ###############################################################################################

    files = []
    files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
    files = various.flatten(files)

    sem = []
    trn = []
    trn_hdc_sem = []
    asr = []
    asr_hdc_sem = []
    nbl = []
    nbl_hdc_sem = []

    for fn in files[:100000]:
        f_dir = os.path.dirname(fn)

        print "Processing:", fn
        doc = xml.dom.minidom.parse(fn)
        turns = doc.getElementsByTagName("turn")

        for i, turn in enumerate(turns):
            if turn.getAttribute('speaker') != 'user':
                continue

            recs = turn.getElementsByTagName("rec")
            trans = turn.getElementsByTagName("asr_transcription")
            asrs = turn.getElementsByTagName("asr")

            if len(recs) != 1:
                print "Skipping a turn {turn} in file: {fn} - recs: {recs}".format(turn=i,fn=fn, recs=len(recs))
                continue

            if len(asrs) == 0 and (i + 1) < len(turns):
                next_asrs = turns[i+1].getElementsByTagName("asr")
                if len(next_asrs) != 2:
                    print "Skipping a turn {turn} in file: {fn} - asrs: {asrs} - next_asrs: {next_asrs}".format(turn=i, fn=fn, asrs=len(asrs), next_asrs=len(next_asrs))
                    continue
                print "Recovered from missing ASR output by using a delayed ASR output from the following turn of turn {turn}. File: {fn} - next_asrs: {asrs}".format(turn=i, fn=fn, asrs=len(next_asrs))
                hyps = next_asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 1:
                hyps = asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 2:
                print "Recovered from EXTRA ASR outputs by using a the last ASR output from the turn. File: {fn} - asrs: {asrs}".format(fn=fn, asrs=len(asrs))
                hyps = asrs[-1].getElementsByTagName("hypothesis")
            else:
                print "Skipping a turn {turn} in file {fn} - asrs: {asrs}".format(turn=i,fn=fn, asrs=len(asrs))
                continue

            if len(trans) == 0:
                print "Skipping a turn in {fn} - trans: {trans}".format(fn=fn, trans=len(trans))
                continue

            wav_key = recs[0].getAttribute('fname')
            wav_path = os.path.join(f_dir, wav_key)
            
            # FIXME: Check whether the last transcription is really the best! FJ
            t = various.get_text_from_xml_node(trans[-1])
            t = normalise_text(t)

            
            if '--asr-log' not in sys.argv:
                asr_rec_nbl = asr_rec.rec_wav_file(wav_path)
                a = unicode(asr_rec_nbl.get_best())
            else:  
                a = various.get_text_from_xml_node(hyps[0])
                a = normalise_semi_words(a)

            if exclude_slu(t) or 'DOM Element:' in a:
                print "Skipping transcription:", unicode(t)
                print "Skipping ASR output:   ", unicode(a)
                continue

            # The silence does not have a label in the language model.
            t = t.replace('_SIL_','')

            trn.append((wav_key, t))

            print "Parsing transcription:", unicode(t)
            print "                  ASR:", unicode(a)

            # HDC SLU on transcription
            s = slu.parse_1_best({'utt':Utterance(t)}).get_best_da()
            trn_hdc_sem.append((wav_key, s))

            if '--uniq' not in sys.argv:
                # HDC SLU on 1 best ASR
                if '--asr-log' not in sys.argv:
                    a = unicode(asr_rec_nbl.get_best())
                else:  
                    a = various.get_text_from_xml_node(hyps[0])
                    a = normalise_semi_words(a)

                asr.append((wav_key, a))

                s = slu.parse_1_best({'utt':Utterance(a)}).get_best_da()
                asr_hdc_sem.append((wav_key, s))

                # HDC SLU on N best ASR
                n = UtteranceNBList()
                if '--asr-log' not in sys.argv:
                   n = asr_rec_nbl
                   
                   print 'ASR RECOGNITION NBLIST\n',unicode(n)
                else:
                    for h in hyps:
                        txt = various.get_text_from_xml_node(h)
                        txt = normalise_semi_words(txt)

                        n.add(abs(float(h.getAttribute('p'))),Utterance(txt))

                n.merge()
                n.normalise()

                nbl.append((wav_key, n.serialise()))

                if '--fast' not in sys.argv:
                    s = slu.parse_nblist({'utt_nbl':n}).get_best_da()
                nbl_hdc_sem.append((wav_key, s))

            # there is no manual semantics in the transcriptions yet
            sem.append((wav_key, None))


    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem, uniq_trn_hdc_sem, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem, dict(trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    if '--uniq' not in sys.argv:
        save_wavaskey(fn_all_asr, dict(asr))
        save_wavaskey(fn_all_asr_hdc_sem, dict(asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        save_wavaskey(fn_all_nbl, dict(nbl))
        save_wavaskey(fn_all_nbl_hdc_sem, dict(nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))


        seed_value = 10

        random.seed(seed_value)
        random.shuffle(trn)
        random.seed(seed_value)
        random.shuffle(trn_hdc_sem)
        random.seed(seed_value)
        random.shuffle(asr)
        random.seed(seed_value)
        random.shuffle(asr_hdc_sem)
        random.seed(seed_value)
        random.shuffle(nbl)
        random.seed(seed_value)
        random.shuffle(nbl_hdc_sem)

        # trn
        train_trn = trn[:int(0.8*len(trn))]
        dev_trn = trn[int(0.8*len(trn)):int(0.9*len(trn))]
        test_trn = trn[int(0.9*len(trn)):]

        save_wavaskey(fn_train_trn, dict(train_trn))
        save_wavaskey(fn_dev_trn, dict(dev_trn))
        save_wavaskey(fn_test_trn, dict(test_trn))

        # trn_hdc_sem
        train_trn_hdc_sem = trn_hdc_sem[:int(0.8*len(trn_hdc_sem))]
        dev_trn_hdc_sem = trn_hdc_sem[int(0.8*len(trn_hdc_sem)):int(0.9*len(trn_hdc_sem))]
        test_trn_hdc_sem = trn_hdc_sem[int(0.9*len(trn_hdc_sem)):]

        save_wavaskey(fn_train_trn_hdc_sem, dict(train_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_trn_hdc_sem, dict(dev_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_trn_hdc_sem, dict(test_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # asr
        train_asr = asr[:int(0.8*len(asr))]
        dev_asr = asr[int(0.8*len(asr)):int(0.9*len(asr))]
        test_asr = asr[int(0.9*len(asr)):]

        save_wavaskey(fn_train_asr, dict(train_asr))
        save_wavaskey(fn_dev_asr, dict(dev_asr))
        save_wavaskey(fn_test_asr, dict(test_asr))

        # asr_hdc_sem
        train_asr_hdc_sem = asr_hdc_sem[:int(0.8*len(asr_hdc_sem))]
        dev_asr_hdc_sem = asr_hdc_sem[int(0.8*len(asr_hdc_sem)):int(0.9*len(asr_hdc_sem))]
        test_asr_hdc_sem = asr_hdc_sem[int(0.9*len(asr_hdc_sem)):]

        save_wavaskey(fn_train_asr_hdc_sem, dict(train_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_asr_hdc_sem, dict(dev_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_asr_hdc_sem, dict(test_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # n-best lists
        train_nbl = nbl[:int(0.8*len(nbl))]
        dev_nbl = nbl[int(0.8*len(nbl)):int(0.9*len(nbl))]
        test_nbl = nbl[int(0.9*len(nbl)):]

        save_wavaskey(fn_train_nbl, dict(train_nbl))
        save_wavaskey(fn_dev_nbl, dict(dev_nbl))
        save_wavaskey(fn_test_nbl, dict(test_nbl))

        # nbl_hdc_sem
        train_nbl_hdc_sem = nbl_hdc_sem[:int(0.8*len(nbl_hdc_sem))]
        dev_nbl_hdc_sem = nbl_hdc_sem[int(0.8*len(nbl_hdc_sem)):int(0.9*len(nbl_hdc_sem))]
        test_nbl_hdc_sem = nbl_hdc_sem[int(0.9*len(nbl_hdc_sem)):]

        save_wavaskey(fn_train_nbl_hdc_sem, dict(train_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_nbl_hdc_sem, dict(dev_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_nbl_hdc_sem, dict(test_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
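The same trans= lambda is passed to save_wavaskey throughout this example: the dialogue-act string is split on '&', its components sorted, and re-joined, so that equivalent acts serialise identically regardless of component order. A named, Python 2 version of that lambda (matching the unicode usage in the excerpt; purely illustrative):

def canonicalise_da(da):
    """Sort the '&'-separated components of a dialogue act string."""
    return '&'.join(sorted(unicode(da).split('&')))

# canonicalise_da(u'request(time)&inform(to=stop)')
# -> u'inform(to=stop)&request(time)'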
Code Example #7
File: prepare_data.py  Project: choko/alex
def main():
    global asr_log
    global num_workers

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""This program prepares data for training Alex PTIcs SLU.
      """)

    parser.add_argument('--num_workers', action="store", default=num_workers, type=int,
                        help='number of workers used for ASR: default %d' % num_workers)
    parser.add_argument('--asr_log', action="store", default=asr_log, type=int,
                        help='use ASR results from logs: default %d' % asr_log)

    args = parser.parse_args()

    asr_log = args.asr_log
    num_workers = args.num_workers

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_nbl = 'all.nbl'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_nbl = 'train.nbl'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_nbl = 'dev.nbl'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_nbl = 'test.nbl'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-"*120
    ###############################################################################################

    files = []
    files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
    files = various.flatten(files)

    files = files[:100000]
    asr = []
    nbl = []
    sem = []
    trn = []
    trn_hdc_sem = []


    p_process_call_logs = multiprocessing.Pool(num_workers)
    processed_cls = p_process_call_logs.imap_unordered(process_call_log, files)

    count = 0
    for pcl in processed_cls:
        count += 1
        #process_call_log(fn) # uniq utterances
        #print pcl

        print "="*80
        print "Processed files ", count, "/", len(files)
        print "="*80

        asr.extend(pcl[0])
        nbl.extend(pcl[1])
        sem.extend(pcl[2])
        trn.extend(pcl[3])
        trn_hdc_sem.extend(pcl[4])

    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem, uniq_trn_hdc_sem, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem, dict(trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_all_asr, dict(asr))
    save_wavaskey(fn_all_nbl, dict(nbl))

    seed_value = 10

    random.seed(seed_value)
    random.shuffle(trn)
    random.seed(seed_value)
    random.shuffle(trn_hdc_sem)
    random.seed(seed_value)
    random.shuffle(asr)
    random.seed(seed_value)
    random.shuffle(nbl)

    # trn
    train_trn = trn[:int(0.8*len(trn))]
    dev_trn = trn[int(0.8*len(trn)):int(0.9*len(trn))]
    test_trn = trn[int(0.9*len(trn)):]

    save_wavaskey(fn_train_trn, dict(train_trn))
    save_wavaskey(fn_dev_trn, dict(dev_trn))
    save_wavaskey(fn_test_trn, dict(test_trn))

    # trn_hdc_sem
    train_trn_hdc_sem = trn_hdc_sem[:int(0.8*len(trn_hdc_sem))]
    dev_trn_hdc_sem = trn_hdc_sem[int(0.8*len(trn_hdc_sem)):int(0.9*len(trn_hdc_sem))]
    test_trn_hdc_sem = trn_hdc_sem[int(0.9*len(trn_hdc_sem)):]

    save_wavaskey(fn_train_trn_hdc_sem, dict(train_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_dev_trn_hdc_sem, dict(dev_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_test_trn_hdc_sem, dict(test_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    # asr
    train_asr = asr[:int(0.8*len(asr))]
    dev_asr = asr[int(0.8*len(asr)):int(0.9*len(asr))]
    test_asr = asr[int(0.9*len(asr)):]

    save_wavaskey(fn_train_asr, dict(train_asr))
    save_wavaskey(fn_dev_asr, dict(dev_asr))
    save_wavaskey(fn_test_asr, dict(test_asr))

    # n-best lists
    train_nbl = nbl[:int(0.8*len(nbl))]
    dev_nbl = nbl[int(0.8*len(nbl)):int(0.9*len(nbl))]
    test_nbl = nbl[int(0.9*len(nbl)):]

    save_wavaskey(fn_train_nbl, dict(train_nbl))
    save_wavaskey(fn_dev_nbl, dict(dev_nbl))
    save_wavaskey(fn_test_nbl, dict(test_nbl))
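Code Examples #5 and #7 farm the per-file work out to a process_call_log worker (not shown in these excerpts) via multiprocessing.Pool and collect its 5-tuples with imap_unordered, which yields results as workers finish, in no particular order. A minimal, self-contained sketch of that pattern with a hypothetical stub worker:

import multiprocessing

def process_call_log(fn):
    """Stub standing in for the real worker: returns (asr, nbl, sem, trn, trn_hdc_sem) lists."""
    return ([], [], [], [], [])

def run_pool(files, num_workers=4):
    pool = multiprocessing.Pool(num_workers)
    asr, nbl, sem, trn, trn_hdc_sem = [], [], [], [], []
    # Results arrive as each worker finishes, not in input order.
    for pcl in pool.imap_unordered(process_call_log, files):
        asr.extend(pcl[0])
        nbl.extend(pcl[1])
        sem.extend(pcl[2])
        trn.extend(pcl[3])
        trn_hdc_sem.extend(pcl[4])
    pool.close()
    pool.join()
    return asr, nbl, sem, trn, trn_hdc_sem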