コード例 #1
0
def process_file(file_path):
    """
    Re-parse the user turns of a dialogue file with the HDC SLU.

    Every non-empty input line must contain three tab-separated fields:
    speaker, dialogue act and utterance. Empty lines (dialogue boundaries)
    and turns whose speaker field contains 'SYSTEM' are skipped. For each
    remaining turn one tab-separated, UTF-8 encoded line is printed to stdout:
    the utterance, its abstracted form (as produced by get_abutt_str), the
    best dialogue act, and the best dialogue act after abstract_da().

    :param file_path: path to the UTF-8 encoded input file
    """
    label_db = CategoryLabelDatabase(as_project_path('applications/PublicTransportInfoCS/data/database.py'))
    slu = PTICSHDCSLU(
        PTICSSLUPreprocessing(label_db),
        cfg={'SLU': {PTICSHDCSLU: {
            'utt2da': as_project_path('applications/PublicTransportInfoCS/data/utt2da_dict.txt'),
        }}})
    out = codecs.getwriter('UTF-8')(sys.stdout)

    with open(file_path, 'r') as raw_fh:
        for raw_line in codecs.getreader('UTF-8')(raw_fh):
            raw_line = raw_line.strip("\r\n")
            if not raw_line:
                # an empty line separates two dialogues
                continue

            speaker, _, text = raw_line.split("\t")
            if 'SYSTEM' in speaker:
                # only user turns are processed
                continue

            # commas are turned into spaces before the transcription is re-parsed
            utterance = Utterance(text.replace(',', ' '))
            parsed = slu.parse({'utt': utterance})

            abstracted_text = get_abutt_str(utterance, slu.abstract_utterance(utterance))

            top_da = parsed.get_best_da()
            top_da_text = unicode(top_da)
            # abstract_da's return value is unused, so it presumably modifies
            # top_da in place; the non-abstracted text was saved just above
            abstract_da(top_da)

            print >> out, "\t".join([unicode(utterance), abstracted_text, top_da_text, unicode(top_da)])
コード例 #2
0
def hdc_slu(fn_input, constructor, fn_output):
    """
    Use for transcription a HDC SLU model.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "="*120
    print "HDC SLU: ", fn_input, fn_output
    print "-"*120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(preprocessing, cfg = {'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

    test_utterances = load_wavaskey(fn_input, constructor, limit=1000000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    save_wavaskey(fn_output, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
コード例 #3
0
ファイル: gen_uniq.py プロジェクト: kangliqiang/alex
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    output_utterance = True
    output_abstraction = False
    output_da = True

    fn_uniq_trn_sem = 'uniq.trn.sem.tmp'

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_utterance:
            annotation += [utterance.rstrip()]
        if output_abstraction:
            norm_utterance = slu.preprocessing.normalise_utterance(utterance)
            abutterance, _ = slu.abstract_utterance(norm_utterance)
            annotation += [abutterance]
        if output_da:
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
            annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #4
0
ファイル: gen_uniq.py プロジェクト: beka-evature/alex
def main():
    import autopath

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(preprocessing, cfg = {'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

    output_utterance = True
    output_abstraction = False
    output_da = True

    fn_uniq_trn_sem = 'uniq.trn.sem.tmp'

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_utterance:
            annotation += [utterance.rstrip()]
        if output_abstraction:
            norm_utterance = slu.preprocessing.normalise_utterance(utterance)
            abutterance, _ = slu.abstract_utterance(norm_utterance)
            annotation += [abutterance]
        if output_da:
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
            annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #5
0
ファイル: gen_uniq.py プロジェクト: UFAL-DSG/alex
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(preprocessing, cfg={'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

    output_alignment = False
    output_utterance = True
    output_abstraction = False
    output_da = True

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]
    fn_uniq_trn_sem = fn_uniq_trn + '.sem.tmp'

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_alignment:
            norm_utterance = slu.preprocessing.normalise_utterance(Utterance(utterance))
            abutterance, _, _ = slu.abstract_utterance(norm_utterance)
            abutterance = slu.handle_false_abstractions(abutterance)
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()

            max_alignment_idx = lambda _dai: max(_dai.alignment) if _dai.alignment else len(abutterance)
            for i, dai in enumerate(sorted(da, key=max_alignment_idx)):
                if not dai.alignment:
                    print "Empty alignment:", unicode(abutterance), ";", dai

                if not dai.alignment or dai.alignment == {-1}:
                    dai_alignment_idx = len(abutterance)
                else:
                    dai_alignment_idx = max(dai.alignment) + i + 1
                abutterance.insert(dai_alignment_idx, "[{} - {}]".format(unicode(dai), list(dai.alignment if dai.alignment else [])))
            annotation += [unicode(abutterance)]
        else:
            if output_utterance:
                annotation += [utterance.rstrip()]
            if output_abstraction:
                norm_utterance = slu.preprocessing.normalise_utterance(Utterance(utterance))
                abutterance, _ = slu.abstract_utterance(norm_utterance)
                annotation += [abutterance]
            if output_da:
                da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
                annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #6
0
def hdc_slu_test(fn_input, constructor, fn_reference):
    """
    Tests the HDC SLU.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "=" * 120
    print "Testing HDC SLU: ", fn_input, fn_reference
    print "-" * 120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    fn_sem = os.path.basename(fn_input) + '.hdc.sem.out'

    save_wavaskey(fn_sem,
                  parsed_das,
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem) + '.score',
                    'w+',
                    encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
コード例 #7
0
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    output_alignment = False
    output_utterance = True
    output_abstraction = False
    output_da = True

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]
    fn_uniq_trn_sem = fn_uniq_trn + '.sem.tmp'

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 2)
        annotation = []
        if output_alignment:
            norm_utterance = slu.preprocessing.normalise_utterance(
                Utterance(utterance))
            abutterance, _, _ = slu.abstract_utterance(norm_utterance)
            abutterance = slu.handle_false_abstractions(abutterance)
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()

            max_alignment_idx = lambda _dai: max(
                _dai.alignment) if _dai.alignment else len(abutterance)
            for i, dai in enumerate(sorted(da, key=max_alignment_idx)):
                if not dai.alignment:
                    print "Empty alignment:", unicode(abutterance), ";", dai

                if not dai.alignment or dai.alignment == {-1}:
                    dai_alignment_idx = len(abutterance)
                else:
                    dai_alignment_idx = max(dai.alignment) + i + 1
                abutterance.insert(
                    dai_alignment_idx, "[{} - {}]".format(
                        unicode(dai),
                        list(dai.alignment if dai.alignment else [])))
            annotation += [unicode(abutterance)]
        else:
            if output_utterance:
                annotation += [utterance.rstrip()]
            if output_abstraction:
                norm_utterance = slu.preprocessing.normalise_utterance(
                    Utterance(utterance))
                abutterance, _ = slu.abstract_utterance(norm_utterance)
                annotation += [abutterance]
            if output_da:
                da = slu.parse_1_best({
                    'utt': Utterance(utterance)
                }).get_best_da()
                annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
コード例 #8
0
from alex.components.slu.base import CategoryLabelDatabase


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print "No utterance entered as argument. Processing sample utterance instead..."
        utterance = u"CHTĚL BYCH JET ZE ZASTÁVKY ANDĚL DO ZASTÁVKY MALOSTRANSKÉ NÁMĚSTÍ"
    else:
        utterance = sys.argv[1].decode("utf-8")
        sys.argv = sys.argv[:1]

    cldb = CategoryLabelDatabase("../data/database.py")
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            "SLU": {PTICSHDCSLU: {"utt2da": as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}
        },
    )

    norm_utterance = slu.preprocessing.normalise_utterance(Utterance(utterance))
    abutterance, _, _ = slu.abstract_utterance(norm_utterance)
    da = slu.parse_1_best({"utt": Utterance(utterance)}, verbose=True).get_best_da()
    print "Abstracted utterance:", unicode(abutterance)
    print "Dialogue act:", unicode(da)

    max_alignment_idx = lambda _dai: max(_dai.alignment) if _dai.alignment else len(abutterance)
    for i, dai in enumerate(sorted(da, key=max_alignment_idx)):
        if not dai.alignment:
            print "Empty alignment:", unicode(abutterance), ";", dai

        if not dai.alignment or dai.alignment == -1:
コード例 #9
0
--asr-log   it uses the asr hypotheses from call logs

"""

# Default option values. They are presumably overwritten by command-line
# parsing elsewhere in this script (the docstring above mentions --asr-log)
# — TODO confirm.
asr_log = 0
num_workers = 1

# HDC SLU: category label database -> preprocessing -> parser, configured
# with the application's utt2da dictionary.
# NOTE: '../data/database.py' is resolved against the current working directory.
cldb = CategoryLabelDatabase('../data/database.py')
preprocessing = PTICSSLUPreprocessing(cldb)
slu = PTICSHDCSLU(
    preprocessing,
    cfg={
        'SLU': {
            PTICSHDCSLU: {
                'utt2da':
                as_project_path(
                    "applications/PublicTransportInfoCS/data/utt2da_dict.txt")
            }
        }
    })
# The ASR recogniser is built from ../kaldi.cfg (on top of the default config,
# use_default=True). The config must be loaded before asr_factory is called.
cfg = Config.load_configs([
    '../kaldi.cfg',
], use_default=True)
asr_rec = asr_factory(cfg)


def normalise_semi_words(txt):
    # normalise these semi-words
    if txt == '__other__':
        txt = '_other_'
コード例 #10
0
from alex.components.slu.base import CategoryLabelDatabase
"""
Serves to quickly test HDC SLU with a single utterance supplied as argument
"""

if len(sys.argv) < 2:
    print "No utterance entered as argument. Processing sample utterance instead..."
    utterance = u"CHTĚL BYCH JET ZE ZASTÁVKY ANDĚL DO ZASTÁVKY MALOSTRANSKÉ NÁMĚSTÍ"
else:
    utterance = sys.argv[1].decode('utf-8')
    sys.argv = sys.argv[:1]

cldb = CategoryLabelDatabase('../data/database.py')
preprocessing = PTICSSLUPreprocessing(cldb)
slu = PTICSHDCSLU(
    preprocessing,
    cfg={
        'SLU': {
            PTICSHDCSLU: {
                'utt2da':
                as_project_path(
                    "applications/PublicTransportInfoCS/data/utt2da_dict.txt")
            }
        }
    })

da = slu.parse_1_best({
    'utt': Utterance(utterance)
}, verbose=True).get_best_da()

print "Resulting dialogue act: \n", unicode(da)
コード例 #11
0
ファイル: prepare_data.py プロジェクト: AoJ/alex
def main():
    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(preprocessing)
    cfg = Config.load_configs(['../kaldi.cfg',], use_default=True)
    asr_rec = asr_factory(cfg)                    

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_asr_hdc_sem = 'all.asr.hdc.sem'
    fn_all_nbl = 'all.nbl'
    fn_all_nbl_hdc_sem = 'all.nbl.hdc.sem'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_asr_hdc_sem = 'train.asr.hdc.sem'
    fn_train_nbl = 'train.nbl'
    fn_train_nbl_hdc_sem = 'train.nbl.hdc.sem'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_asr_hdc_sem = 'dev.asr.hdc.sem'
    fn_dev_nbl = 'dev.nbl'
    fn_dev_nbl_hdc_sem = 'dev.nbl.hdc.sem'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_asr_hdc_sem = 'test.asr.hdc.sem'
    fn_test_nbl = 'test.nbl'
    fn_test_nbl_hdc_sem = 'test.nbl.hdc.sem'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-"*120
    ###############################################################################################

    files = []
    files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
    files = various.flatten(files)

    sem = []
    trn = []
    trn_hdc_sem = []
    asr = []
    asr_hdc_sem = []
    nbl = []
    nbl_hdc_sem = []

    for fn in files[:100000]:
        f_dir = os.path.dirname(fn)

        print "Processing:", fn
        doc = xml.dom.minidom.parse(fn)
        turns = doc.getElementsByTagName("turn")

        for i, turn in enumerate(turns):
            if turn.getAttribute('speaker') != 'user':
                continue

            recs = turn.getElementsByTagName("rec")
            trans = turn.getElementsByTagName("asr_transcription")
            asrs = turn.getElementsByTagName("asr")

            if len(recs) != 1:
                print "Skipping a turn {turn} in file: {fn} - recs: {recs}".format(turn=i,fn=fn, recs=len(recs))
                continue

            if len(asrs) == 0 and (i + 1) < len(turns):
                next_asrs = turns[i+1].getElementsByTagName("asr")
                if len(next_asrs) != 2:
                    print "Skipping a turn {turn} in file: {fn} - asrs: {asrs} - next_asrs: {next_asrs}".format(turn=i, fn=fn, asrs=len(asrs), next_asrs=len(next_asrs))
                    continue
                print "Recovered from missing ASR output by using a delayed ASR output from the following turn of turn {turn}. File: {fn} - next_asrs: {asrs}".format(turn=i, fn=fn, asrs=len(next_asrs))
                hyps = next_asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 1:
                hyps = asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 2:
                print "Recovered from EXTRA ASR outputs by using a the last ASR output from the turn. File: {fn} - asrs: {asrs}".format(fn=fn, asrs=len(asrs))
                hyps = asrs[-1].getElementsByTagName("hypothesis")
            else:
                print "Skipping a turn {turn} in file {fn} - asrs: {asrs}".format(turn=i,fn=fn, asrs=len(asrs))
                continue

            if len(trans) == 0:
                print "Skipping a turn in {fn} - trans: {trans}".format(fn=fn, trans=len(trans))
                continue

            wav_key = recs[0].getAttribute('fname')
            wav_path = os.path.join(f_dir, wav_key)
            
            # FIXME: Check whether the last transcription is really the best! FJ
            t = various.get_text_from_xml_node(trans[-1])
            t = normalise_text(t)

            
            if '--asr-log' not in sys.argv:
                asr_rec_nbl = asr_rec.rec_wav_file(wav_path)
                a = unicode(asr_rec_nbl.get_best())
            else:  
                a = various.get_text_from_xml_node(hyps[0])
                a = normalise_semi_words(a)

            if exclude_slu(t) or 'DOM Element:' in a:
                print "Skipping transcription:", unicode(t)
                print "Skipping ASR output:   ", unicode(a)
                continue

            # The silence does not have a label in the language model.
            t = t.replace('_SIL_','')

            trn.append((wav_key, t))

            print "Parsing transcription:", unicode(t)
            print "                  ASR:", unicode(a)

            # HDC SLU on transcription
            s = slu.parse_1_best({'utt':Utterance(t)}).get_best_da()
            trn_hdc_sem.append((wav_key, s))

            if '--uniq' not in sys.argv:
                # HDC SLU on 1 best ASR
                if '--asr-log' not in sys.argv:
                    a = unicode(asr_rec_nbl.get_best())
                else:  
                    a = various.get_text_from_xml_node(hyps[0])
                    a = normalise_semi_words(a)

                asr.append((wav_key, a))

                s = slu.parse_1_best({'utt':Utterance(a)}).get_best_da()
                asr_hdc_sem.append((wav_key, s))

                # HDC SLU on N best ASR
                n = UtteranceNBList()
                if '--asr-log' not in sys.argv:
                   n = asr_rec_nbl
                   
                   print 'ASR RECOGNITION NBLIST\n',unicode(n)
                else:
                    for h in hyps:
                        txt = various.get_text_from_xml_node(h)
                        txt = normalise_semi_words(txt)

                        n.add(abs(float(h.getAttribute('p'))),Utterance(txt))

                n.merge()
                n.normalise()

                nbl.append((wav_key, n.serialise()))

                if '--fast' not in sys.argv:
                    s = slu.parse_nblist({'utt_nbl':n}).get_best_da()
                nbl_hdc_sem.append((wav_key, s))

            # there is no manual semantics in the transcriptions yet
            sem.append((wav_key, None))


    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem, uniq_trn_hdc_sem, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem, dict(trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    if '--uniq' not in sys.argv:
        save_wavaskey(fn_all_asr, dict(asr))
        save_wavaskey(fn_all_asr_hdc_sem, dict(asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        save_wavaskey(fn_all_nbl, dict(nbl))
        save_wavaskey(fn_all_nbl_hdc_sem, dict(nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))


        seed_value = 10

        random.seed(seed_value)
        random.shuffle(trn)
        random.seed(seed_value)
        random.shuffle(trn_hdc_sem)
        random.seed(seed_value)
        random.shuffle(asr)
        random.seed(seed_value)
        random.shuffle(asr_hdc_sem)
        random.seed(seed_value)
        random.shuffle(nbl)
        random.seed(seed_value)
        random.shuffle(nbl_hdc_sem)

        # trn
        train_trn = trn[:int(0.8*len(trn))]
        dev_trn = trn[int(0.8*len(trn)):int(0.9*len(trn))]
        test_trn = trn[int(0.9*len(trn)):]

        save_wavaskey(fn_train_trn, dict(train_trn))
        save_wavaskey(fn_dev_trn, dict(dev_trn))
        save_wavaskey(fn_test_trn, dict(test_trn))

        # trn_hdc_sem
        train_trn_hdc_sem = trn_hdc_sem[:int(0.8*len(trn_hdc_sem))]
        dev_trn_hdc_sem = trn_hdc_sem[int(0.8*len(trn_hdc_sem)):int(0.9*len(trn_hdc_sem))]
        test_trn_hdc_sem = trn_hdc_sem[int(0.9*len(trn_hdc_sem)):]

        save_wavaskey(fn_train_trn_hdc_sem, dict(train_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_trn_hdc_sem, dict(dev_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_trn_hdc_sem, dict(test_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # asr
        train_asr = asr[:int(0.8*len(asr))]
        dev_asr = asr[int(0.8*len(asr)):int(0.9*len(asr))]
        test_asr = asr[int(0.9*len(asr)):]

        save_wavaskey(fn_train_asr, dict(train_asr))
        save_wavaskey(fn_dev_asr, dict(dev_asr))
        save_wavaskey(fn_test_asr, dict(test_asr))

        # asr_hdc_sem
        train_asr_hdc_sem = asr_hdc_sem[:int(0.8*len(asr_hdc_sem))]
        dev_asr_hdc_sem = asr_hdc_sem[int(0.8*len(asr_hdc_sem)):int(0.9*len(asr_hdc_sem))]
        test_asr_hdc_sem = asr_hdc_sem[int(0.9*len(asr_hdc_sem)):]

        save_wavaskey(fn_train_asr_hdc_sem, dict(train_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_asr_hdc_sem, dict(dev_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_asr_hdc_sem, dict(test_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # n-best lists
        train_nbl = nbl[:int(0.8*len(nbl))]
        dev_nbl = nbl[int(0.8*len(nbl)):int(0.9*len(nbl))]
        test_nbl = nbl[int(0.9*len(nbl)):]

        save_wavaskey(fn_train_nbl, dict(train_nbl))
        save_wavaskey(fn_dev_nbl, dict(dev_nbl))
        save_wavaskey(fn_test_nbl, dict(test_nbl))

        # nbl_hdc_sem
        train_nbl_hdc_sem = nbl_hdc_sem[:int(0.8*len(nbl_hdc_sem))]
        dev_nbl_hdc_sem = nbl_hdc_sem[int(0.8*len(nbl_hdc_sem)):int(0.9*len(nbl_hdc_sem))]
        test_nbl_hdc_sem = nbl_hdc_sem[int(0.9*len(nbl_hdc_sem)):]

        save_wavaskey(fn_train_nbl_hdc_sem, dict(train_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_nbl_hdc_sem, dict(dev_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_nbl_hdc_sem, dict(test_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
コード例 #12
0
ファイル: test_hdc_utt_dict.py プロジェクト: choko/alex
# -*- coding: utf-8 -*-

import sys
import autopath

from alex.utils.config import as_project_path

from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU

from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
from alex.components.asr.utterance import Utterance
from alex.components.slu.base import CategoryLabelDatabase

"""
Serves to quickly test HDC SLU with a single utterance supplied as argument
"""

if len(sys.argv) < 2:
    print "No utterance entered as argument. Processing sample utterance instead..."
    utterance = u"CHTĚL BYCH JET ZE ZASTÁVKY ANDĚL DO ZASTÁVKY MALOSTRANSKÉ NÁMĚSTÍ"
else:
    utterance = sys.argv[1].decode('utf-8')
    sys.argv = sys.argv[:1]

cldb = CategoryLabelDatabase('../data/database.py')
preprocessing = PTICSSLUPreprocessing(cldb)
slu = PTICSHDCSLU(preprocessing, cfg = {'SLU': {PTICSHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoCS/data/utt2da_dict.txt")}}})

da = slu.parse_1_best({'utt':Utterance(utterance)}, verbose=True).get_best_da()

print "Resulting dialogue act: \n", unicode(da)
コード例 #13
0
ファイル: test.py プロジェクト: elnaaz/alex
def hdc_slu_test(fn_input, constructor, fn_reference):
    """
    Tests a SLU DAILogRegClassifier model.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "="*120
    print "Testing HDC SLU: ", fn_input, fn_reference
    print "-"*120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(preprocessing)

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    fn_sem = os.path.basename(fn_input)+'.hdc.sem.out'

    save_wavaskey(fn_sem, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem)+'.score', 'w+', encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
コード例 #14
0
def main():
    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(preprocessing)
    cfg = Config.load_configs([
        '../kaldi.cfg',
    ], use_default=True)
    asr_rec = asr_factory(cfg)

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_asr_hdc_sem = 'all.asr.hdc.sem'
    fn_all_nbl = 'all.nbl'
    fn_all_nbl_hdc_sem = 'all.nbl.hdc.sem'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_asr_hdc_sem = 'train.asr.hdc.sem'
    fn_train_nbl = 'train.nbl'
    fn_train_nbl_hdc_sem = 'train.nbl.hdc.sem'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_asr_hdc_sem = 'dev.asr.hdc.sem'
    fn_dev_nbl = 'dev.nbl'
    fn_dev_nbl_hdc_sem = 'dev.nbl.hdc.sem'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_asr_hdc_sem = 'test.asr.hdc.sem'
    fn_test_nbl = 'test.nbl'
    fn_test_nbl_hdc_sem = 'test.nbl.hdc.sem'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-" * 120
    ###############################################################################################

    files = []
    files.append(
        glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(
        glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*',
                         'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*', '*',
                         'asr_transcribed.xml')))
    files.append(
        glob.glob(
            os.path.join(indomain_data_dir, '*', '*', '*', '*', '*',
                         'asr_transcribed.xml')))
    files = various.flatten(files)

    sem = []
    trn = []
    trn_hdc_sem = []
    asr = []
    asr_hdc_sem = []
    nbl = []
    nbl_hdc_sem = []

    for fn in files[:100000]:
        f_dir = os.path.dirname(fn)

        print "Processing:", fn
        doc = xml.dom.minidom.parse(fn)
        turns = doc.getElementsByTagName("turn")

        for i, turn in enumerate(turns):
            if turn.getAttribute('speaker') != 'user':
                continue

            recs = turn.getElementsByTagName("rec")
            trans = turn.getElementsByTagName("asr_transcription")
            asrs = turn.getElementsByTagName("asr")

            if len(recs) != 1:
                print "Skipping a turn {turn} in file: {fn} - recs: {recs}".format(
                    turn=i, fn=fn, recs=len(recs))
                continue

            if len(asrs) == 0 and (i + 1) < len(turns):
                next_asrs = turns[i + 1].getElementsByTagName("asr")
                if len(next_asrs) != 2:
                    print "Skipping a turn {turn} in file: {fn} - asrs: {asrs} - next_asrs: {next_asrs}".format(
                        turn=i,
                        fn=fn,
                        asrs=len(asrs),
                        next_asrs=len(next_asrs))
                    continue
                print "Recovered from missing ASR output by using a delayed ASR output from the following turn of turn {turn}. File: {fn} - next_asrs: {asrs}".format(
                    turn=i, fn=fn, asrs=len(next_asrs))
                hyps = next_asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 1:
                hyps = asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 2:
                print "Recovered from EXTRA ASR outputs by using a the last ASR output from the turn. File: {fn} - asrs: {asrs}".format(
                    fn=fn, asrs=len(asrs))
                hyps = asrs[-1].getElementsByTagName("hypothesis")
            else:
                print "Skipping a turn {turn} in file {fn} - asrs: {asrs}".format(
                    turn=i, fn=fn, asrs=len(asrs))
                continue

            if len(trans) == 0:
                print "Skipping a turn in {fn} - trans: {trans}".format(
                    fn=fn, trans=len(trans))
                continue

            wav_key = recs[0].getAttribute('fname')
            wav_path = os.path.join(f_dir, wav_key)

            # FIXME: Check whether the last transcription is really the best! FJ
            t = various.get_text_from_xml_node(trans[-1])
            t = normalise_text(t)

            if '--asr-log' not in sys.argv:
                asr_rec_nbl = asr_rec.rec_wav_file(wav_path)
                a = unicode(asr_rec_nbl.get_best())
            else:
                a = various.get_text_from_xml_node(hyps[0])
                a = normalise_semi_words(a)

            if exclude_slu(t) or 'DOM Element:' in a:
                print "Skipping transcription:", unicode(t)
                print "Skipping ASR output:   ", unicode(a)
                continue

            # The silence does not have a label in the language model.
            t = t.replace('_SIL_', '')

            trn.append((wav_key, t))

            print "Parsing transcription:", unicode(t)
            print "                  ASR:", unicode(a)

            # HDC SLU on transcription
            s = slu.parse_1_best({'utt': Utterance(t)}).get_best_da()
            trn_hdc_sem.append((wav_key, s))

            if '--uniq' not in sys.argv:
                # HDC SLU on 1 best ASR
                if '--asr-log' not in sys.argv:
                    a = unicode(asr_rec_nbl.get_best())
                else:
                    a = various.get_text_from_xml_node(hyps[0])
                    a = normalise_semi_words(a)

                asr.append((wav_key, a))

                s = slu.parse_1_best({'utt': Utterance(a)}).get_best_da()
                asr_hdc_sem.append((wav_key, s))

                # HDC SLU on N best ASR
                n = UtteranceNBList()
                if '--asr-log' not in sys.argv:
                    n = asr_rec_nbl

                    print 'ASR RECOGNITION NBLIST\n', unicode(n)
                else:
                    for h in hyps:
                        txt = various.get_text_from_xml_node(h)
                        txt = normalise_semi_words(txt)

                        n.add(abs(float(h.getAttribute('p'))), Utterance(txt))

                n.merge()
                n.normalise()

                nbl.append((wav_key, n.serialise()))

                if '--fast' not in sys.argv:
                    s = slu.parse_nblist({'utt_nbl': n}).get_best_da()
                nbl_hdc_sem.append((wav_key, s))

            # there is no manual semantics in the transcriptions yet
            sem.append((wav_key, None))

    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if not v in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem,
                  uniq_trn_hdc_sem,
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem,
                  dict(trn_hdc_sem),
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

    if '--uniq' not in sys.argv:
        save_wavaskey(fn_all_asr, dict(asr))
        save_wavaskey(
            fn_all_asr_hdc_sem,
            dict(asr_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

        save_wavaskey(fn_all_nbl, dict(nbl))
        save_wavaskey(
            fn_all_nbl_hdc_sem,
            dict(nbl_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

        seed_value = 10

        random.seed(seed_value)
        random.shuffle(trn)
        random.seed(seed_value)
        random.shuffle(trn_hdc_sem)
        random.seed(seed_value)
        random.shuffle(asr)
        random.seed(seed_value)
        random.shuffle(asr_hdc_sem)
        random.seed(seed_value)
        random.shuffle(nbl)
        random.seed(seed_value)
        random.shuffle(nbl_hdc_sem)

        # trn
        train_trn = trn[:int(0.8 * len(trn))]
        dev_trn = trn[int(0.8 * len(trn)):int(0.9 * len(trn))]
        test_trn = trn[int(0.9 * len(trn)):]

        save_wavaskey(fn_train_trn, dict(train_trn))
        save_wavaskey(fn_dev_trn, dict(dev_trn))
        save_wavaskey(fn_test_trn, dict(test_trn))

        # trn_hdc_sem
        train_trn_hdc_sem = trn_hdc_sem[:int(0.8 * len(trn_hdc_sem))]
        dev_trn_hdc_sem = trn_hdc_sem[int(0.8 * len(trn_hdc_sem)
                                          ):int(0.9 * len(trn_hdc_sem))]
        test_trn_hdc_sem = trn_hdc_sem[int(0.9 * len(trn_hdc_sem)):]

        save_wavaskey(
            fn_train_trn_hdc_sem,
            dict(train_trn_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(
            fn_dev_trn_hdc_sem,
            dict(dev_trn_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(
            fn_test_trn_hdc_sem,
            dict(test_trn_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # asr
        train_asr = asr[:int(0.8 * len(asr))]
        dev_asr = asr[int(0.8 * len(asr)):int(0.9 * len(asr))]
        test_asr = asr[int(0.9 * len(asr)):]

        save_wavaskey(fn_train_asr, dict(train_asr))
        save_wavaskey(fn_dev_asr, dict(dev_asr))
        save_wavaskey(fn_test_asr, dict(test_asr))

        # asr_hdc_sem
        train_asr_hdc_sem = asr_hdc_sem[:int(0.8 * len(asr_hdc_sem))]
        dev_asr_hdc_sem = asr_hdc_sem[int(0.8 * len(asr_hdc_sem)
                                          ):int(0.9 * len(asr_hdc_sem))]
        test_asr_hdc_sem = asr_hdc_sem[int(0.9 * len(asr_hdc_sem)):]

        save_wavaskey(
            fn_train_asr_hdc_sem,
            dict(train_asr_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(
            fn_dev_asr_hdc_sem,
            dict(dev_asr_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(
            fn_test_asr_hdc_sem,
            dict(test_asr_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # n-best lists
        train_nbl = nbl[:int(0.8 * len(nbl))]
        dev_nbl = nbl[int(0.8 * len(nbl)):int(0.9 * len(nbl))]
        test_nbl = nbl[int(0.9 * len(nbl)):]

        save_wavaskey(fn_train_nbl, dict(train_nbl))
        save_wavaskey(fn_dev_nbl, dict(dev_nbl))
        save_wavaskey(fn_test_nbl, dict(test_nbl))

        # nbl_hdc_sem
        train_nbl_hdc_sem = nbl_hdc_sem[:int(0.8 * len(nbl_hdc_sem))]
        dev_nbl_hdc_sem = nbl_hdc_sem[int(0.8 * len(nbl_hdc_sem)
                                          ):int(0.9 * len(nbl_hdc_sem))]
        test_nbl_hdc_sem = nbl_hdc_sem[int(0.9 * len(nbl_hdc_sem)):]

        save_wavaskey(
            fn_train_nbl_hdc_sem,
            dict(train_nbl_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(
            fn_dev_nbl_hdc_sem,
            dict(dev_nbl_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(
            fn_test_nbl_hdc_sem,
            dict(test_nbl_hdc_sem),
            trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))