コード例 #1
0
def hdc_slu(fn_input, constructor, fn_output):
    """
    Use for transcription a HDC SLU model.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "="*120
    print "HDC SLU: ", fn_input, fn_output
    print "-"*120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(preprocessing, cfg = {'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

    test_utterances = load_wavaskey(fn_input, constructor, limit=1000000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    save_wavaskey(fn_output, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
コード例 #2
0
def main():
    import autopath

    files = []
    for i in range(1, len(sys.argv)):

        k = load_wavaskey(sys.argv[i], unicode)
        print sys.argv[i], len(k)
        files.append(k)

    keys = set()
    keys.update(set(files[0].keys()))
    ukeys = set()
    for f in files:
        keys = keys.intersection(set(f.keys()))
        ukeys = ukeys.union(set(f.keys()))

    print len(keys), len(ukeys), len(ukeys - keys)

    for f in files:
        rk = set(f.keys()) - keys
        for k in rk:
            if k in f:
                del f[k]

    for i in range(1, len(sys.argv)):
        save_wavaskey(sys.argv[i]+'.pruned',files[i-1])
コード例 #3
0
def main():

    files = []
    for i in range(1, len(sys.argv)):

        k = load_wavaskey(sys.argv[i], unicode)
        print sys.argv[i], len(k)
        files.append(k)

    keys = set()
    keys.update(set(files[0].keys()))
    ukeys = set()
    for f in files:
        keys = keys.intersection(set(f.keys()))
        ukeys = ukeys.union(set(f.keys()))

    print len(keys), len(ukeys), len(ukeys - keys)

    for f in files:
        rk = set(f.keys()) - keys
        for k in rk:
            if k in f:
                del f[k]

    for i in range(1, len(sys.argv)):
        save_wavaskey(sys.argv[i] + '.pruned', files[i - 1])
コード例 #4
0
ファイル: gen_uniq.py プロジェクト: UFAL-DSG/alex
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(preprocessing, cfg={'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

    output_alignment = False
    output_utterance = True
    output_abstraction = False
    output_da = True

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]
    fn_uniq_trn_sem = fn_uniq_trn + '.sem.tmp'

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_alignment:
            norm_utterance = slu.preprocessing.normalise_utterance(Utterance(utterance))
            abutterance, _, _ = slu.abstract_utterance(norm_utterance)
            abutterance = slu.handle_false_abstractions(abutterance)
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()

            max_alignment_idx = lambda _dai: max(_dai.alignment) if _dai.alignment else len(abutterance)
            for i, dai in enumerate(sorted(da, key=max_alignment_idx)):
                if not dai.alignment:
                    print "Empty alignment:", unicode(abutterance), ";", dai

                if not dai.alignment or dai.alignment == {-1}:
                    dai_alignment_idx = len(abutterance)
                else:
                    dai_alignment_idx = max(dai.alignment) + i + 1
                abutterance.insert(dai_alignment_idx, "[{} - {}]".format(unicode(dai), list(dai.alignment if dai.alignment else [])))
            annotation += [unicode(abutterance)]
        else:
            if output_utterance:
                annotation += [utterance.rstrip()]
            if output_abstraction:
                norm_utterance = slu.preprocessing.normalise_utterance(Utterance(utterance))
                abutterance, _ = slu.abstract_utterance(norm_utterance)
                annotation += [abutterance]
            if output_da:
                da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
                annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #5
0
ファイル: gen_uniq.py プロジェクト: kangliqiang/alex
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    output_utterance = True
    output_abstraction = False
    output_da = True

    fn_uniq_trn_sem = 'uniq.trn.sem.tmp'

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_utterance:
            annotation += [utterance.rstrip()]
        if output_abstraction:
            norm_utterance = slu.preprocessing.normalise_utterance(utterance)
            abutterance, _ = slu.abstract_utterance(norm_utterance)
            annotation += [abutterance]
        if output_da:
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
            annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #6
0
ファイル: gen_uniq.py プロジェクト: beka-evature/alex
def main():
    import autopath

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(preprocessing, cfg = {'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

    output_utterance = True
    output_abstraction = False
    output_da = True

    fn_uniq_trn_sem = 'uniq.trn.sem.tmp'

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_utterance:
            annotation += [utterance.rstrip()]
        if output_abstraction:
            norm_utterance = slu.preprocessing.normalise_utterance(utterance)
            abutterance, _ = slu.abstract_utterance(norm_utterance)
            annotation += [abutterance]
        if output_da:
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
            annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #7
0
ファイル: decode_indomain.py プロジェクト: AoJ/alex
def compute_rt_factor(outdir, trn_dict, dec_dict, wavlen_dict, declen_dict, fwlen_dict):
    reference = os.path.join(outdir, 'ref_trn.txt')
    hypothesis = os.path.join(outdir, 'dec_trn.txt')
    save_wavaskey(reference, trn_dict)
    save_wavaskey(hypothesis, dec_dict)
    save_wavaskey(os.path.join(outdir, 'wavlen.txt'), wavlen_dict)
    save_wavaskey(os.path.join(outdir, 'dec_duration.txt'), declen_dict)

    rtf, latency, fw_rtf, fw_latency, d_tot, w_tot = [], [], [], [], 0, 0
    for k in declen_dict.keys():
        w, d, f = wavlen_dict[k], declen_dict[k], fwlen_dict[k]
        d_tot, w_tot = d_tot + d, w_tot + w
        rtf.append(float(d) / w)
        latency.append(float(d) - w)
        fw_rtf.append(float(f) / w)
        fw_latency.append(max(float(f) - w, 0))
    rtf_global = float(d_tot) / w_tot

    print
    print """    # waws:                  %d""" % len(rtf)
    print """    Global RTF mean:         %(rtfglob)f""" % {'rtfglob': rtf_global}

    try:
        rtf.sort()
        rm = rtf[int(len(rtf)*0.5)]
        rc95 = rtf[int(len(rtf)*0.95)]

        print """    RTF median:              %(median)f  RTF       < %(c95)f [in 95%%]""" % {'median': rm, 'c95': rc95}
    except:
        pass

    try:
        fw_rtf.sort()
        fm = fw_rtf[int(len(fw_rtf)*0.5)]
        fc95 = fw_rtf[int(len(fw_rtf)*0.95)]

        print """    Forward RTF median:      %(median)f  FWRTF     < %(c95)f [in 95%%]""" % {'median': fm, 'c95': fc95}
    except:
        pass

    try:
        latency.sort()
        lm = latency[int(len(latency)*0.5)]
        lc95 = latency[int(len(latency)*0.95)]

        print """    Latency median:          %(median)f  Latency   < %(c95)f [in 95%%]""" % {'median': lm, 'c95': lc95}
    except:
        pass

    try:
        fw_latency.sort()
        flm = fw_latency[int(len(fw_latency)*0.5)]
        flc95 = fw_latency[int(len(fw_latency)*0.95)]

        print """    Forward latency median:  %(median)f  FWLatency < %(c95)f [in 95%%]
    """ % {'median': flm, 'c95': flc95}
    except:
        pass

    try:
        print "    95%RTF={rtf:0.2f} 95%FWRTF={fwrtf:0.2f} " \
              "95%LAT={lat:0.2f} 95%FWLAT={fwlat:0.2f}".format(rtf=rc95, fwrtf=fc95, lat=lc95, fwlat=flc95)
        print
    except:
        pass
コード例 #8
0
def main():

    cldb = CategoryLabelDatabase('../data/database.py')

    f_dupl = 0
    f_subs = 0
    examples = defaultdict(list)
    
    for f in cldb.form2value2cl:
        if len(cldb.form2value2cl[f]) >= 2:
            f_dupl += 1
            
        if len(f) < 2:
            continue
            
        for w in f:
            w = (w,)
            if w in cldb.form2value2cl:
                for v in cldb.form2value2cl[w]:
                    for c in cldb.form2value2cl[w][v]:
                        cc = c
                        break
                
                print '{w},{cc} -> {f}'.format(w=w, cc=cc, f=' '.join(f))
                break
        else:
            continue
        
        f_subs += 1
        for v in cldb.form2value2cl[f]:
            for c in cldb.form2value2cl[f][v]:
                examples[(c,cc)].extend(inform(f,v,c))
                examples[(c,cc)].extend(confirm(f,v,c))

    print "There were {f} surface forms.".format(f=len(cldb.form2value2cl))
    print "There were {f_dupl} surface form duplicits.".format(f_dupl=f_dupl)
    print "There were {f_subs} surface forms with substring surface forms.".format(f_subs=f_subs)
    
    max_examples = 100
    ex = []
    for c in sorted(examples.keys()):             
        print c
        z = examples[c]
        if max_examples < len(z):
            z = random.sample(z, max_examples)
        for s, t in z:
            print ' - ', s, '<=>', t
            ex.append((s, t))
            
    examples = ex
    
    examples.sort()
    
    sem = {}
    trn = {}
    for i, e in enumerate(examples):
        key = 'bootstrap_gen_{i:06d}.wav'.format(i=i)
        sem[key] = e[0]
        trn[key] = e[1]

    save_wavaskey('bootstrap_gen.sem', sem)
    save_wavaskey('bootstrap_gen.trn', trn)
コード例 #9
0
def compute_rt_factor(outdir, trn_dict, dec_dict, wavlen_dict, declen_dict,
                      fwlen_dict):
    """
    Prints RTF statistics for decoding and (decoding + ASR extraction)

    Args:
        outdir(str): path to directory for the generated log files are saved.
        trn_dict(dict): (Wave name, transcription) dictionary
        dec_dict(dict): (Wave name, decoded transcription) dictionary
        wavlen_dict(dict): (Wave name, Wave length) dictionary
        declen_dict(dict): (Wave name, (decoding time + extraction time)) dictionary
        fwlen_dict(dict): (Wave name, decoding time) dictionary
    """

    reference = os.path.join(outdir, 'ref_trn.txt')
    hypothesis = os.path.join(outdir, 'dec_trn.txt')
    save_wavaskey(reference, trn_dict)
    save_wavaskey(hypothesis, dec_dict)
    save_wavaskey(os.path.join(outdir, 'wavlen.txt'), wavlen_dict)
    save_wavaskey(os.path.join(outdir, 'dec_duration.txt'), declen_dict)

    rtf, latency, fw_rtf, fw_latency, d_tot, w_tot = [], [], [], [], 0, 0
    for k in declen_dict.keys():
        w, d, f = wavlen_dict[k], declen_dict[k], fwlen_dict[k]
        d_tot, w_tot = d_tot + d, w_tot + w
        rtf.append(float(d) / w)
        latency.append(float(d) - w)
        fw_rtf.append(float(f) / w)
        fw_latency.append(max(float(f) - w, 0))
    rtf_global = float(d_tot) / w_tot

    print
    print """    # waws:                  %d""" % len(rtf)
    print """    Global RTF mean:         %(rtfglob)f""" % {
        'rtfglob': rtf_global
    }

    try:
        rtf.sort()
        rm = rtf[int(len(rtf) * 0.5)]
        rc95 = rtf[int(len(rtf) * 0.95)]

        print """    RTF median:              %(median)f  RTF       < %(c95)f [in 95%%]""" % {
            'median': rm,
            'c95': rc95
        }
    except:
        pass

    try:
        fw_rtf.sort()
        fm = fw_rtf[int(len(fw_rtf) * 0.5)]
        fc95 = fw_rtf[int(len(fw_rtf) * 0.95)]

        print """    Forward RTF median:      %(median)f  FWRTF     < %(c95)f [in 95%%]""" % {
            'median': fm,
            'c95': fc95
        }
    except:
        pass

    try:
        latency.sort()
        lm = latency[int(len(latency) * 0.5)]
        lc95 = latency[int(len(latency) * 0.95)]

        print """    Latency median:          %(median)f  Latency   < %(c95)f [in 95%%]""" % {
            'median': lm,
            'c95': lc95
        }
    except:
        pass

    try:
        fw_latency.sort()
        flm = fw_latency[int(len(fw_latency) * 0.5)]
        flc95 = fw_latency[int(len(fw_latency) * 0.95)]

        print """    Forward latency median:  %(median)f  FWLatency < %(c95)f [in 95%%]
    """ % {
            'median': flm,
            'c95': flc95
        }
    except:
        pass

    try:
        print "    95%RTF={rtf:0.2f} 95%FWRTF={fwrtf:0.2f} " \
              "95%LAT={lat:0.2f} 95%FWLAT={fwlat:0.2f}".format(rtf=rc95, fwrtf=fc95, lat=lc95, fwlat=flc95)
        print
    except:
        pass
コード例 #10
0
def hdc_slu_test(fn_input, constructor, fn_reference):
    """
    Tests the HDC SLU.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "=" * 120
    print "Testing HDC SLU: ", fn_input, fn_reference
    print "-" * 120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    fn_sem = os.path.basename(fn_input) + '.hdc.sem.out'

    save_wavaskey(fn_sem,
                  parsed_das,
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem) + '.score',
                    'w+',
                    encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
コード例 #11
0
        sf_train = sorted(sf[:int(train_data_size*len(sf))], key=lambda k: k[1][0])
        sf_dev = sorted(sf[int(train_data_size*len(sf)):], key=lambda k: k[1][0])

        t_train = [a for a, b in sf_train]
        pt_train = [b for a, b in sf_train]

        t_dev = [a for a, b in sf_dev]
        pt_dev = [b for a, b in sf_dev]

        with codecs.open(indomain_data_text_trn,"w", "UTF-8") as w:
            w.write('\n'.join(t_train))
        with codecs.open(indomain_data_text_dev,"w", "UTF-8") as w:
            w.write('\n'.join(t_dev))

        save_wavaskey(fn_pt_trn, dict(pt_train))
        save_wavaskey(fn_pt_dev, dict(pt_dev))

        # train data
        cmd = r"cat %s %s | iconv -f UTF-8 -t UTF-8//IGNORE | sed 's/\. /\n/g' | sed 's/[[:digit:]]/ /g; s/[^[:alnum:]_]/ /g; s/[ˇ]/ /g; s/ \+/ /g' | sed 's/[[:lower:]]*/\U&/g' | sed s/[\%s→€…│]//g > %s" % \
              (bootstrap_text,
               indomain_data_text_trn,
               "'",
               indomain_data_text_trn_norm)

        print cmd
        os.system(cmd)

        # dev data
        cmd = r"cat %s | iconv -f UTF-8 -t UTF-8//IGNORE | sed 's/\. /\n/g' | sed 's/[[:digit:]]/ /g; s/[^[:alnum:]_]/ /g; s/[ˇ]/ /g; s/ \+/ /g' | sed 's/[[:lower:]]*/\U&/g' | sed s/[\%s→€…│]//g > %s" % \
              (indomain_data_text_dev,
コード例 #12
0
ファイル: da.py プロジェクト: oplatek/alex
def save_das(file_name, das, encoding = 'UTF-8'):
    """
    Save dialogue acts to a file by delegating to ``save_wavaskey``.

    :param file_name: passed to ``save_wavaskey`` as the target file name
    :param das: the collection of dialogue acts to save; presumably a
        dictionary keyed by wave name -- verify against ``save_wavaskey``
    :param encoding: passed on as the third positional argument
    :return: None (the result of ``save_wavaskey`` is not propagated)
    """
    save_wavaskey(file_name, das, encoding)
コード例 #13
0
ファイル: test.py プロジェクト: elnaaz/alex
def hdc_slu_test(fn_input, constructor, fn_reference):
    """
    Tests a SLU DAILogRegClassifier model.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "="*120
    print "Testing HDC SLU: ", fn_input, fn_reference
    print "-"*120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(preprocessing)

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    fn_sem = os.path.basename(fn_input)+'.hdc.sem.out'

    save_wavaskey(fn_sem, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem)+'.score', 'w+', encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
コード例 #14
0
ファイル: decode_indomain.py プロジェクト: beka-evature/alex
def compute_rt_factor(outdir, trn_dict, dec_dict, wavlen_dict, declen_dict, fwlen_dict):
    """
    Prints RTF statistics for decoding and (decoding + ASR extraction)

    Args:
        outdir(str): path to directory for the generated log files are saved.
        trn_dict(dict): (Wave name, transcription) dictionary
        dec_dict(dict): (Wave name, decoded transcription) dictionary
        wavlen_dict(dict): (Wave name, Wave length) dictionary
        declen_dict(dict): (Wave name, (decoding time + extraction time)) dictionary
        fwlen_dict(dict): (Wave name, decoding time) dictionary
    """

    reference = os.path.join(outdir, 'ref_trn.txt')
    hypothesis = os.path.join(outdir, 'dec_trn.txt')
    save_wavaskey(reference, trn_dict)
    save_wavaskey(hypothesis, dec_dict)
    save_wavaskey(os.path.join(outdir, 'wavlen.txt'), wavlen_dict)
    save_wavaskey(os.path.join(outdir, 'dec_duration.txt'), declen_dict)

    rtf, latency, fw_rtf, fw_latency, d_tot, w_tot = [], [], [], [], 0, 0
    for k in declen_dict.keys():
        w, d, f = wavlen_dict[k], declen_dict[k], fwlen_dict[k]
        d_tot, w_tot = d_tot + d, w_tot + w
        rtf.append(float(d) / w)
        latency.append(float(d) - w)
        fw_rtf.append(float(f) / w)
        fw_latency.append(max(float(f) - w, 0))
    rtf_global = float(d_tot) / w_tot

    print
    print """    # waws:                  %d""" % len(rtf)
    print """    Global RTF mean:         %(rtfglob)f""" % {'rtfglob': rtf_global}

    try:
        rtf.sort()
        rm = rtf[int(len(rtf)*0.5)]
        rc95 = rtf[int(len(rtf)*0.95)]

        print """    RTF median:              %(median)f  RTF       < %(c95)f [in 95%%]""" % {'median': rm, 'c95': rc95}
    except:
        pass

    try:
        fw_rtf.sort()
        fm = fw_rtf[int(len(fw_rtf)*0.5)]
        fc95 = fw_rtf[int(len(fw_rtf)*0.95)]

        print """    Forward RTF median:      %(median)f  FWRTF     < %(c95)f [in 95%%]""" % {'median': fm, 'c95': fc95}
    except:
        pass

    try:
        latency.sort()
        lm = latency[int(len(latency)*0.5)]
        lc95 = latency[int(len(latency)*0.95)]

        print """    Latency median:          %(median)f  Latency   < %(c95)f [in 95%%]""" % {'median': lm, 'c95': lc95}
    except:
        pass

    try:
        fw_latency.sort()
        flm = fw_latency[int(len(fw_latency)*0.5)]
        flc95 = fw_latency[int(len(fw_latency)*0.95)]

        print """    Forward latency median:  %(median)f  FWLatency < %(c95)f [in 95%%]
    """ % {'median': flm, 'c95': flc95}
    except:
        pass

    try:
        print "    95%RTF={rtf:0.2f} 95%FWRTF={fwrtf:0.2f} " \
              "95%LAT={lat:0.2f} 95%FWLAT={fwlat:0.2f}".format(rtf=rc95, fwrtf=fc95, lat=lc95, fwlat=flc95)
        print
    except:
        pass
コード例 #15
0
def main():
    """
    Prepare the SLU data files from the call logs found under ``indomain_data``.

    Parses the ``--num_workers`` and ``--asr_log`` options into the module
    level globals of the same name, processes every ``asr_transcribed.xml``
    file with ``process_call_log`` in a pool of worker processes and saves the
    collected items as wav-as-key files under relative file names: the unique
    transcriptions (``uniq.*``), all items (``all.*``) and an 80/10/10
    train/dev/test split (``train.*``, ``dev.*``, ``test.*``).
    """

    global asr_log
    global num_workers

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""This program prepares data for training Alex PTIcs SLU.
      """)

    parser.add_argument('--num_workers',
                        action="store",
                        default=num_workers,
                        type=int,
                        help='number of workers used for ASR: default %d' %
                        num_workers)
    parser.add_argument('--asr_log',
                        action="store",
                        default=asr_log,
                        type=int,
                        help='use ASR results from logs: default %d' % asr_log)

    args = parser.parse_args()

    # Publish the parsed options through the module level globals.
    asr_log = args.asr_log
    num_workers = args.num_workers

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    # NOTE(review): the fn_*_sem names of the all/train/dev/test groups are
    # assigned but not used anywhere in this function.
    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_nbl = 'all.nbl'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_nbl = 'train.nbl'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_nbl = 'dev.nbl'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_nbl = 'test.nbl'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-" * 120
    ###############################################################################################

    # Collect the call logs from the data directory itself and from up to
    # five levels of its subdirectories.
    files = []
    files.append(
        glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(
        glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*',
                         'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*', '*',
                         'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*', '*', '*',
                         'asr_transcribed.xml')))
    files = various.flatten(files)

    # Cap the number of processed call logs.
    files = files[:100000]
    asr = []
    nbl = []
    sem = []
    trn = []
    trn_hdc_sem = []

    # imap_unordered yields the results as the workers finish, so the order
    # of the collected items does not follow the order of `files`.
    p_process_call_logs = multiprocessing.Pool(num_workers)
    processed_cls = p_process_call_logs.imap_unordered(process_call_log, files)

    count = 0
    for pcl in processed_cls:
        count += 1
        #process_call_log(fn) # uniq utterances
        #print pcl

        print "=" * 80
        print "Processed files ", count, "/", len(files)
        print "=" * 80

        # Each result is indexed positionally: asr, nbl, sem, trn, trn_hdc_sem.
        asr.extend(pcl[0])
        nbl.extend(pcl[1])
        sem.extend(pcl[2])
        trn.extend(pcl[3])
        trn_hdc_sem.extend(pcl[4])

    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    # NOTE(review): this rebinds `sem`, discarding the items collected from
    # pcl[2] above; from here on it is a lookup built from trn_hdc_sem.
    sem = dict(trn_hdc_sem)
    # Keep only the first item seen for each distinct value in `trn`.
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    # The trans callback sorts the '&'-separated items of each saved value.
    save_wavaskey(fn_uniq_trn_hdc_sem,
                  uniq_trn_hdc_sem,
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem,
                  dict(trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_all_asr, dict(asr))
    save_wavaskey(fn_all_nbl, dict(nbl))

    seed_value = 10

    # Reseeding before each shuffle feeds every list the same pseudo-random
    # sequence -- presumably to keep the splits of the four lists aligned;
    # that only holds if the lists have equal length and matching order
    # (TODO confirm).
    random.seed(seed_value)
    random.shuffle(trn)
    random.seed(seed_value)
    random.shuffle(trn_hdc_sem)
    random.seed(seed_value)
    random.shuffle(asr)
    random.seed(seed_value)
    random.shuffle(nbl)

    # Each list below is split 80% train / 10% dev / 10% test.

    # trn
    train_trn = trn[:int(0.8 * len(trn))]
    dev_trn = trn[int(0.8 * len(trn)):int(0.9 * len(trn))]
    test_trn = trn[int(0.9 * len(trn)):]

    save_wavaskey(fn_train_trn, dict(train_trn))
    save_wavaskey(fn_dev_trn, dict(dev_trn))
    save_wavaskey(fn_test_trn, dict(test_trn))

    # trn_hdc_sem
    train_trn_hdc_sem = trn_hdc_sem[:int(0.8 * len(trn_hdc_sem))]
    dev_trn_hdc_sem = trn_hdc_sem[int(0.8 *
                                      len(trn_hdc_sem)):int(0.9 *
                                                            len(trn_hdc_sem))]
    test_trn_hdc_sem = trn_hdc_sem[int(0.9 * len(trn_hdc_sem)):]

    save_wavaskey(fn_train_trn_hdc_sem,
                  dict(train_trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_dev_trn_hdc_sem,
                  dict(dev_trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_test_trn_hdc_sem,
                  dict(test_trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

    # asr
    train_asr = asr[:int(0.8 * len(asr))]
    dev_asr = asr[int(0.8 * len(asr)):int(0.9 * len(asr))]
    test_asr = asr[int(0.9 * len(asr)):]

    save_wavaskey(fn_train_asr, dict(train_asr))
    save_wavaskey(fn_dev_asr, dict(dev_asr))
    save_wavaskey(fn_test_asr, dict(test_asr))

    # n-best lists
    train_nbl = nbl[:int(0.8 * len(nbl))]
    dev_nbl = nbl[int(0.8 * len(nbl)):int(0.9 * len(nbl))]
    test_nbl = nbl[int(0.9 * len(nbl)):]

    save_wavaskey(fn_train_nbl, dict(train_nbl))
    save_wavaskey(fn_dev_nbl, dict(dev_nbl))
    save_wavaskey(fn_test_nbl, dict(test_nbl))
コード例 #16
0
ファイル: test.py プロジェクト: tkraut/alex
def trained_slu_test(fn_model, fn_input, constructor, fn_reference):
    """
    Tests a SLU DAILogRegClassifier model.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "="*120
    print "Testing: ", fn_model, fn_input, fn_reference
    print "-"*120

    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.components.slu.base import CategoryLabelDatabase
    from alex.components.slu.dailrclassifier import DAILogRegClassifier
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing)

    slu.load_model(fn_model)

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            slu.parse(obs, verbose=True)

    if 'trn' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.trn.sem.out'
    elif 'asr' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.asr.sem.out'
    elif 'nbl' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.nbl.sem.out'
    else:
        fn_sem = os.path.basename(fn_input)+'.XXX.sem.out'

    save_wavaskey(fn_sem, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem)+'.score', 'w+', encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
コード例 #17
0
def save_das(file_name, das, encoding='UTF-8'):
    """
    Save ``das`` into ``file_name`` by delegating to ``save_wavaskey``.

    :param file_name: name of the output file
    :param das: the data to be saved, forwarded unchanged
    :param encoding: forwarded as the third positional argument of
        ``save_wavaskey`` (presumably the encoding of the output file)
    :return: None
    """
    save_wavaskey(file_name, das, encoding)
コード例 #18
0
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    output_alignment = False
    output_utterance = True
    output_abstraction = False
    output_da = True

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]
    fn_uniq_trn_sem = fn_uniq_trn + '.sem.tmp'

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_alignment:
            norm_utterance = slu.preprocessing.normalise_utterance(
                Utterance(utterance))
            abutterance, _, _ = slu.abstract_utterance(norm_utterance)
            abutterance = slu.handle_false_abstractions(abutterance)
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()

            max_alignment_idx = lambda _dai: max(
                _dai.alignment) if _dai.alignment else len(abutterance)
            for i, dai in enumerate(sorted(da, key=max_alignment_idx)):
                if not dai.alignment:
                    print "Empty alignment:", unicode(abutterance), ";", dai

                if not dai.alignment or dai.alignment == {-1}:
                    dai_alignment_idx = len(abutterance)
                else:
                    dai_alignment_idx = max(dai.alignment) + i + 1
                abutterance.insert(
                    dai_alignment_idx, "[{} - {}]".format(
                        unicode(dai),
                        list(dai.alignment if dai.alignment else [])))
            annotation += [unicode(abutterance)]
        else:
            if output_utterance:
                annotation += [utterance.rstrip()]
            if output_abstraction:
                norm_utterance = slu.preprocessing.normalise_utterance(
                    Utterance(utterance))
                abutterance, _ = slu.abstract_utterance(norm_utterance)
                annotation += [abutterance]
            if output_da:
                da = slu.parse_1_best({
                    'utt': Utterance(utterance)
                }).get_best_da()
                annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #19
0
ファイル: prepare_data.py プロジェクト: zhangziliang04/alex
def main():
    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTIENSLUPreprocessing(cldb)
    slu = PTIENHDCSLU(preprocessing, cfg={'SLU': {PTIENHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoEN/data/utt2da_dict.txt")}}})
    cfg = Config.load_configs(['../kaldi.cfg',], use_default=True)
    asr_rec = asr_factory(cfg)                    

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_asr_hdc_sem = 'all.asr.hdc.sem'
    fn_all_nbl = 'all.nbl'
    fn_all_nbl_hdc_sem = 'all.nbl.hdc.sem'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_asr_hdc_sem = 'train.asr.hdc.sem'
    fn_train_nbl = 'train.nbl'
    fn_train_nbl_hdc_sem = 'train.nbl.hdc.sem'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_asr_hdc_sem = 'dev.asr.hdc.sem'
    fn_dev_nbl = 'dev.nbl'
    fn_dev_nbl_hdc_sem = 'dev.nbl.hdc.sem'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_asr_hdc_sem = 'test.asr.hdc.sem'
    fn_test_nbl = 'test.nbl'
    fn_test_nbl_hdc_sem = 'test.nbl.hdc.sem'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-"*120
    ###############################################################################################

    files = []
    files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
    files = various.flatten(files)

    sem = []
    trn = []
    trn_hdc_sem = []
    asr = []
    asr_hdc_sem = []
    nbl = []
    nbl_hdc_sem = []

    for fn in files[:100000]:
        f_dir = os.path.dirname(fn)

        print "Processing:", fn
        doc = xml.dom.minidom.parse(fn)
        turns = doc.getElementsByTagName("turn")

        for i, turn in enumerate(turns):
            if turn.getAttribute('speaker') != 'user':
                continue

            recs = turn.getElementsByTagName("rec")
            trans = turn.getElementsByTagName("asr_transcription")
            asrs = turn.getElementsByTagName("asr")

            if len(recs) != 1:
                print "Skipping a turn {turn} in file: {fn} - recs: {recs}".format(turn=i,fn=fn, recs=len(recs))
                continue

            if len(asrs) == 0 and (i + 1) < len(turns):
                next_asrs = turns[i+1].getElementsByTagName("asr")
                if len(next_asrs) != 2:
                    print "Skipping a turn {turn} in file: {fn} - asrs: {asrs} - next_asrs: {next_asrs}".format(turn=i, fn=fn, asrs=len(asrs), next_asrs=len(next_asrs))
                    continue
                print "Recovered from missing ASR output by using a delayed ASR output from the following turn of turn {turn}. File: {fn} - next_asrs: {asrs}".format(turn=i, fn=fn, asrs=len(next_asrs))
                hyps = next_asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 1:
                hyps = asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 2:
                print "Recovered from EXTRA ASR outputs by using a the last ASR output from the turn. File: {fn} - asrs: {asrs}".format(fn=fn, asrs=len(asrs))
                hyps = asrs[-1].getElementsByTagName("hypothesis")
            else:
                print "Skipping a turn {turn} in file {fn} - asrs: {asrs}".format(turn=i,fn=fn, asrs=len(asrs))
                continue

            if len(trans) == 0:
                print "Skipping a turn in {fn} - trans: {trans}".format(fn=fn, trans=len(trans))
                continue

            wav_key = recs[0].getAttribute('fname')
            wav_path = os.path.join(f_dir, wav_key)
            
            # FIXME: Check whether the last transcription is really the best! FJ
            t = various.get_text_from_xml_node(trans[-1])
            t = normalise_text(t)

            
            if '--asr-log' not in sys.argv:
                asr_rec_nbl = asr_rec.rec_wav_file(wav_path)
                a = unicode(asr_rec_nbl.get_best())
            else:  
                a = various.get_text_from_xml_node(hyps[0])
                a = normalise_semi_words(a)

            if exclude_slu(t) or 'DOM Element:' in a:
                print "Skipping transcription:", unicode(t)
                print "Skipping ASR output:   ", unicode(a)
                continue

            # The silence does not have a label in the language model.
            t = t.replace('_SIL_','')

            trn.append((wav_key, t))

            print "Parsing transcription:", unicode(t)
            print "                  ASR:", unicode(a)

            # HDC SLU on transcription
            s = slu.parse_1_best({'utt':Utterance(t)}).get_best_da()
            trn_hdc_sem.append((wav_key, s))

            if '--uniq' not in sys.argv:
                # HDC SLU on 1 best ASR
                if '--asr-log' not in sys.argv:
                    a = unicode(asr_rec_nbl.get_best())
                else:  
                    a = various.get_text_from_xml_node(hyps[0])
                    a = normalise_semi_words(a)

                asr.append((wav_key, a))

                s = slu.parse_1_best({'utt':Utterance(a)}).get_best_da()
                asr_hdc_sem.append((wav_key, s))

                # HDC SLU on N best ASR
                n = UtteranceNBList()
                if '--asr-log' not in sys.argv:
                   n = asr_rec_nbl
                   
                   print 'ASR RECOGNITION NBLIST\n',unicode(n)
                else:
                    for h in hyps:
                        txt = various.get_text_from_xml_node(h)
                        txt = normalise_semi_words(txt)

                        n.add(abs(float(h.getAttribute('p'))),Utterance(txt))

                n.merge()
                n.normalise()

                nbl.append((wav_key, n.serialise()))

                if '--fast' not in sys.argv:
                    s = slu.parse_nblist({'utt_nbl':n}).get_best_da()
                nbl_hdc_sem.append((wav_key, s))

            # there is no manual semantics in the transcriptions yet
            sem.append((wav_key, None))


    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem, uniq_trn_hdc_sem, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem, dict(trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    if '--uniq' not in sys.argv:
        save_wavaskey(fn_all_asr, dict(asr))
        save_wavaskey(fn_all_asr_hdc_sem, dict(asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        save_wavaskey(fn_all_nbl, dict(nbl))
        save_wavaskey(fn_all_nbl_hdc_sem, dict(nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))


        seed_value = 10

        random.seed(seed_value)
        random.shuffle(trn)
        random.seed(seed_value)
        random.shuffle(trn_hdc_sem)
        random.seed(seed_value)
        random.shuffle(asr)
        random.seed(seed_value)
        random.shuffle(asr_hdc_sem)
        random.seed(seed_value)
        random.shuffle(nbl)
        random.seed(seed_value)
        random.shuffle(nbl_hdc_sem)

        # trn
        train_trn = trn[:int(0.8*len(trn))]
        dev_trn = trn[int(0.8*len(trn)):int(0.9*len(trn))]
        test_trn = trn[int(0.9*len(trn)):]

        save_wavaskey(fn_train_trn, dict(train_trn))
        save_wavaskey(fn_dev_trn, dict(dev_trn))
        save_wavaskey(fn_test_trn, dict(test_trn))

        # trn_hdc_sem
        train_trn_hdc_sem = trn_hdc_sem[:int(0.8*len(trn_hdc_sem))]
        dev_trn_hdc_sem = trn_hdc_sem[int(0.8*len(trn_hdc_sem)):int(0.9*len(trn_hdc_sem))]
        test_trn_hdc_sem = trn_hdc_sem[int(0.9*len(trn_hdc_sem)):]

        save_wavaskey(fn_train_trn_hdc_sem, dict(train_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_trn_hdc_sem, dict(dev_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_trn_hdc_sem, dict(test_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # asr
        train_asr = asr[:int(0.8*len(asr))]
        dev_asr = asr[int(0.8*len(asr)):int(0.9*len(asr))]
        test_asr = asr[int(0.9*len(asr)):]

        save_wavaskey(fn_train_asr, dict(train_asr))
        save_wavaskey(fn_dev_asr, dict(dev_asr))
        save_wavaskey(fn_test_asr, dict(test_asr))

        # asr_hdc_sem
        train_asr_hdc_sem = asr_hdc_sem[:int(0.8*len(asr_hdc_sem))]
        dev_asr_hdc_sem = asr_hdc_sem[int(0.8*len(asr_hdc_sem)):int(0.9*len(asr_hdc_sem))]
        test_asr_hdc_sem = asr_hdc_sem[int(0.9*len(asr_hdc_sem)):]

        save_wavaskey(fn_train_asr_hdc_sem, dict(train_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_asr_hdc_sem, dict(dev_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_asr_hdc_sem, dict(test_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # n-best lists
        train_nbl = nbl[:int(0.8*len(nbl))]
        dev_nbl = nbl[int(0.8*len(nbl)):int(0.9*len(nbl))]
        test_nbl = nbl[int(0.9*len(nbl)):]

        save_wavaskey(fn_train_nbl, dict(train_nbl))
        save_wavaskey(fn_dev_nbl, dict(dev_nbl))
        save_wavaskey(fn_test_nbl, dict(test_nbl))

        # nbl_hdc_sem
        train_nbl_hdc_sem = nbl_hdc_sem[:int(0.8*len(nbl_hdc_sem))]
        dev_nbl_hdc_sem = nbl_hdc_sem[int(0.8*len(nbl_hdc_sem)):int(0.9*len(nbl_hdc_sem))]
        test_nbl_hdc_sem = nbl_hdc_sem[int(0.9*len(nbl_hdc_sem)):]

        save_wavaskey(fn_train_nbl_hdc_sem, dict(train_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_nbl_hdc_sem, dict(dev_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_nbl_hdc_sem, dict(test_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
コード例 #20
0
ファイル: build.py プロジェクト: jakub-stejskal/alex
        sf_train = sorted(sf[:int(train_data_size*len(sf))], key=lambda k: k[1][0])
        sf_dev = sorted(sf[int(train_data_size*len(sf)):], key=lambda k: k[1][0])

        t_train = [a for a, b in sf_train]
        pt_train = [b for a, b in sf_train]

        t_dev = [a for a, b in sf_dev]
        pt_dev = [b for a, b in sf_dev]

        with codecs.open(indomain_data_text_trn,"w", "UTF-8") as w:
            w.write('\n'.join(t_train))
        with codecs.open(indomain_data_text_dev,"w", "UTF-8") as w:
            w.write('\n'.join(t_dev))

        save_wavaskey(fn_pt_trn, dict(pt_train))
        save_wavaskey(fn_pt_dev, dict(pt_dev))

        # train data
        cmd = r"cat %s %s | iconv -f UTF-8 -t UTF-8//IGNORE | sed 's/\. /\n/g' | sed 's/[[:digit:]]/ /g; s/[^[:alnum:]_]/ /g; s/[ˇ]/ /g; s/ \+/ /g' | sed 's/[[:lower:]]*/\U&/g' | sed s/[\%s→€…│]//g > %s" % \
              (bootstrap_text,
               indomain_data_text_trn,
               "'",
               indomain_data_text_trn_norm)

        print cmd
        exit_on_system_fail(cmd)

        # dev data
        cmd = r"cat %s | iconv -f UTF-8 -t UTF-8//IGNORE | sed 's/\. /\n/g' | sed 's/[[:digit:]]/ /g; s/[^[:alnum:]_]/ /g; s/[ˇ]/ /g; s/ \+/ /g' | sed 's/[[:lower:]]*/\U&/g' | sed s/[\%s→€…│]//g > %s" % \
              (indomain_data_text_dev,
コード例 #21
0
ファイル: prepare_data.py プロジェクト: choko/alex
def main():
    global asr_log
    global num_workers

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""This program prepares data for training Alex PTIcs SLU.
      """)

    parser.add_argument('--num_workers', action="store", default=num_workers, type=int,
                        help='number of workers used for ASR: default %d' % num_workers)
    parser.add_argument('--asr_log', action="store", default=asr_log, type=int,
                        help='use ASR results from logs: default %d' % asr_log)

    args = parser.parse_args()

    asr_log = args.asr_log
    num_workers = args.num_workers

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_nbl = 'all.nbl'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_nbl = 'train.nbl'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_nbl = 'dev.nbl'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_nbl = 'test.nbl'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-"*120
    ###############################################################################################

    files = []
    files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
    files = various.flatten(files)

    files = files[:100000]
    asr = []
    nbl = []
    sem = []
    trn = []
    trn_hdc_sem = []


    p_process_call_logs = multiprocessing.Pool(num_workers)
    processed_cls = p_process_call_logs.imap_unordered(process_call_log, files)

    count = 0
    for pcl in processed_cls:
        count += 1
        #process_call_log(fn) # uniq utterances
        #print pcl

        print "="*80
        print "Processed files ", count, "/", len(files)
        print "="*80

        asr.extend(pcl[0])
        nbl.extend(pcl[1])
        sem.extend(pcl[2])
        trn.extend(pcl[3])
        trn_hdc_sem.extend(pcl[4])

    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem, uniq_trn_hdc_sem, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem, dict(trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_all_asr, dict(asr))
    save_wavaskey(fn_all_nbl, dict(nbl))

    seed_value = 10

    random.seed(seed_value)
    random.shuffle(trn)
    random.seed(seed_value)
    random.shuffle(trn_hdc_sem)
    random.seed(seed_value)
    random.shuffle(asr)
    random.seed(seed_value)
    random.shuffle(nbl)

    # trn
    train_trn = trn[:int(0.8*len(trn))]
    dev_trn = trn[int(0.8*len(trn)):int(0.9*len(trn))]
    test_trn = trn[int(0.9*len(trn)):]

    save_wavaskey(fn_train_trn, dict(train_trn))
    save_wavaskey(fn_dev_trn, dict(dev_trn))
    save_wavaskey(fn_test_trn, dict(test_trn))

    # trn_hdc_sem
    train_trn_hdc_sem = trn_hdc_sem[:int(0.8*len(trn_hdc_sem))]
    dev_trn_hdc_sem = trn_hdc_sem[int(0.8*len(trn_hdc_sem)):int(0.9*len(trn_hdc_sem))]
    test_trn_hdc_sem = trn_hdc_sem[int(0.9*len(trn_hdc_sem)):]

    save_wavaskey(fn_train_trn_hdc_sem, dict(train_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_dev_trn_hdc_sem, dict(dev_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_test_trn_hdc_sem, dict(test_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    # asr
    train_asr = asr[:int(0.8*len(asr))]
    dev_asr = asr[int(0.8*len(asr)):int(0.9*len(asr))]
    test_asr = asr[int(0.9*len(asr)):]

    save_wavaskey(fn_train_asr, dict(train_asr))
    save_wavaskey(fn_dev_asr, dict(dev_asr))
    save_wavaskey(fn_test_asr, dict(test_asr))

    # n-best lists
    train_nbl = nbl[:int(0.8*len(nbl))]
    dev_nbl = nbl[int(0.8*len(nbl)):int(0.9*len(nbl))]
    test_nbl = nbl[int(0.9*len(nbl)):]

    save_wavaskey(fn_train_nbl, dict(train_nbl))
    save_wavaskey(fn_dev_nbl, dict(dev_nbl))
    save_wavaskey(fn_test_nbl, dict(test_nbl))