Example #1
def slu_factory(cfg, slu_type=None):
    """
    Creates an SLU parser.

    :param cfg: the Alex configuration dictionary
    :param slu_type: the SLU parser class to construct; if None, it is
        determined from the configuration via get_slu_type(cfg)

    """

    # This is the new, simplified factory code.
    if slu_type is None:
        slu_type = get_slu_type(cfg)

    if inspect.isclass(slu_type) and issubclass(slu_type, DAILogRegClassifier):
        cldb = CategoryLabelDatabase(cfg['SLU'][slu_type]['cldb_fname'])
        preprocessing = cfg['SLU'][slu_type]['preprocessing_cls'](cldb)
        slu = slu_type(cldb, preprocessing)
        slu.load_model(cfg['SLU'][slu_type]['model_fname'])
        return slu
    elif inspect.isclass(slu_type) and issubclass(slu_type, SLUInterface):
        cldb = CategoryLabelDatabase(cfg['SLU'][slu_type]['cldb_fname'])
        preprocessing = cfg['SLU'][slu_type]['preprocessing_cls'](cldb)
        slu = slu_type(preprocessing)
        return slu

    raise SLUException('Unsupported SLU parser: %s' % slu_type)
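
A minimal usage sketch for this factory, assuming a DAILogRegClassifier-based parser; the file names are placeholders, and the cfg layout simply mirrors the keys the function reads above:

cfg = {
    'SLU': {
        'type': DAILogRegClassifier,  # presumably what get_slu_type(cfg) returns
        DAILogRegClassifier: {
            'cldb_fname': 'database.py',      # placeholder path
            'preprocessing_cls': SLUPreprocessing,
            'model_fname': 'slu.model',       # placeholder path
        },
    },
}
slu = slu_factory(cfg)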
Example #2
    def test_parse_X(self):
        from alex.components.slu.dainnclassifier import DAINNClassifier
        
        np.random.seed(0)

        cldb = CategoryLabelDatabase()
        class db:
            database = {
                "task": {
                    "find_connection": ["najít spojení", "najít spoj", "zjistit spojení",
                                        "zjistit spoj", "hledám spojení", 'spojení', 'spoj',
                                       ],
                    "find_platform": ["najít nástupiště", "zjistit nástupiště", ],
                    'weather': ['pocasi', 'jak bude', ],
                },
                "number": {
                    "1": ["jednu"]
                },
                "time": {
                    "now": ["nyní", "teď", "teďka", "hned", "nejbližší", "v tuto chvíli", "co nejdřív"],
                },
            }

        cldb.load(db_mod=db)

        preprocessing = SLUPreprocessing(cldb)
        clf = DAINNClassifier(cldb, preprocessing, features_size=4)

        # Train a simple classifier.
        das = {
            '1': DialogueAct('inform(task=weather)'),
            '2': DialogueAct('inform(time=now)'),
            '3': DialogueAct('inform(task=weather)'),
            '4': DialogueAct('inform(task=connection)'),
        }
        utterances = {
            '1': Utterance('pocasi pocasi pocasi pocasi pocasi'),
            '2': Utterance('hned ted nyni hned ted nyni'),
            '3': Utterance('jak bude jak bude jak bude jak bude'),
            '4': Utterance('kdy a odkat mi to jede'),
        }
        clf.extract_classifiers(das, utterances, verbose=False)
        clf.prune_classifiers(min_classifier_count=0)
        clf.gen_classifiers_data(min_pos_feature_count=0,
                                 min_neg_feature_count=0,
                                 verbose2=False)

        clf.train(inverse_regularisation=1e1, verbose=False)

        # Parse some sentences.
        utterance_list = UtteranceNBList()
        utterance_list.add(0.7, Utterance('pocasi'))
        utterance_list.add(0.7, Utterance('jak bude pocasi'))
        utterance_list.add(0.2, Utterance('hned'))
        utterance_list.add(0.2, Utterance('hned'))
        da_confnet = clf.parse_X(utterance_list, verbose=False)

        self.assertTrue(da_confnet.get_prob(DialogueActItem(dai='inform(task=weather)')) != 0.0)
        self.assertTrue(da_confnet.get_prob(DialogueActItem(dai='inform(time=now)')) != 0.0)
Example #3
    def setUpClass(cls):
        cfg = {
            'SLU': {
                'debug': True,
                'type': PTICSHDCSLU,
                PTICSHDCSLU: {
                    'preprocessing_cls': PTICSSLUPreprocessing
                },
            },
        }
        slu_type = cfg['SLU']['type']
        cldb = CategoryLabelDatabase()

        class db:
            database = {
                "task": {
                    "find_connection": [
                        "najít spojení",
                        "najít spoj",
                        "zjistit spojení",
                        "zjistit spoj",
                        "hledám spojení",
                        'spojení',
                        'spoj',
                    ],
                    "find_platform": [
                        "najít nástupiště",
                        "zjistit nástupiště",
                    ],
                    'weather': [
                        'pocasi',
                    ],
                },
                "number": {
                    "1": ["jednu"]
                },
                "time": {
                    "now": [
                        "nyní", "teď", "teďka", "hned", "nejbližší",
                        "v tuto chvíli", "co nejdřív"
                    ],
                    "18": ["osmnáct", "osmnact"]
                },
                "date_rel": {
                    "tomorrow": ["zítra", "zitra"],
                }
            }

        cldb.load(db_mod=db)
        preprocessing = cfg['SLU'][slu_type]['preprocessing_cls'](cldb)
        cls.slu = slu_type(preprocessing, cfg)
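
Given this fixture, a test method can exercise the parser the same way the other examples do; the following sketch is hypothetical (the method name and the expected dialogue act are illustrative, not taken from the real suite):

    def test_parse_weather(self):
        # 'pocasi' is listed under task=weather in the toy database above.
        da = self.slu.parse_1_best({'utt': Utterance('pocasi')}).get_best_da()
        self.assertTrue('weather' in unicode(da))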
Example #4
    def test_catlab_substitution(self):
        utterances_dict = load_utterances(
            os.path.join(SCRIPT_DIR, 'resources', 'towninfo-train.trn'))
        semantics_dict = load_das(
            os.path.join(SCRIPT_DIR, 'resources', 'towninfo-train.sem'))

        cldb = CategoryLabelDatabase(
            os.path.join(SCRIPT_DIR, 'resources', 'database.py'))
        preprocessing = SLUPreprocessing(cldb)

        for k in semantics_dict:
            print '=' * 120
            print utterances_dict[k]
            print semantics_dict[k]

            utterance, da, category_labels = (
                preprocessing.values2category_labels_in_da(
                    utterances_dict[k], semantics_dict[k]))

            print '-' * 120
            print utterance
            print da
            print category_labels
            print '-' * 120

            full_utt = preprocessing.category_labels2values_in_utterance(
                utterance, category_labels)
            full_da = preprocessing.category_labels2values_in_da(
                da, category_labels)

            print full_utt
            print full_da
Example #5
def process_file(file_path):

    cldb = CategoryLabelDatabase(as_project_path('applications/PublicTransportInfoCS/data/database.py'))
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da': as_project_path(
                        'applications/PublicTransportInfoCS/data/utt2da_dict.txt')
                }
            }
        })
    stdout = codecs.getwriter('UTF-8')(sys.stdout)

    with open(file_path, 'r') as fh:
        for line in codecs.getreader('UTF-8')(fh):
            line = line.strip("\r\n")

            # skip empty lines (dialogue boundaries)
            if not line:
                continue

            person, da, utt = line.split("\t")
            # skip system utterances, use just user utterances
            if 'SYSTEM' in person:
                continue

            # reparse utterance using transcription
            utt = re.sub(r',', r' ', utt)
            utt = Utterance(utt)
            sem = hdc_slu.parse({'utt': utt})

            # get abstracted utterance text
            abutt = hdc_slu.abstract_utterance(utt)
            abutt_str = get_abutt_str(utt, abutt)

            # get abstracted DA; abstract_da modifies best_da in place,
            # so its surface form is saved first
            best_da = sem.get_best_da()
            best_da_str = unicode(best_da)
            abstract_da(best_da)

            print >> stdout, unicode(utt) + "\t" + abutt_str + "\t" + best_da_str + "\t" + unicode(best_da)
Example #6
    def setUpClass(cls):
        cfg = {
            'SLU': {
                'debug': True,
                'type': PTICSHDCSLU,
                PTICSHDCSLU: {
                    'preprocessing_cls': PTICSSLUPreprocessing
                },
            },
        }
        slu_type = cfg['SLU']['type']
        cldb = CategoryLabelDatabase()
        class db:
            database = {
                "task": {
                    "find_connection": ["najít spojení", "najít spoj", "zjistit spojení",
                                        "zjistit spoj", "hledám spojení", 'spojení', 'spoj',
                                        ],
                    "find_platform": ["najít nástupiště", "zjistit nástupiště", ],
                    'weather': ['pocasi', ],
                },
                "number": {
                    "1": ["jednu"]
                },
                "time": {
                    "now": ["nyní", "teď", "teďka", "hned", "nejbližší", "v tuto chvíli", "co nejdřív"],
                    "18": ["osmnáct", "osmnact"]
                },
                "date_rel": {
                    "tomorrow": ["zítra", "zitra"],
                }
            }

        cldb.load(db_mod=db)
        preprocessing = cfg['SLU'][slu_type]['preprocessing_cls'](cldb)
        cls.slu = slu_type(preprocessing, cfg)
Example #7
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    output_utterance = True
    output_abstraction = False
    output_da = True

    fn_uniq_trn_sem = 'uniq.trn.sem.tmp'

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 1)
        annotation = []
        if output_utterance:
            annotation += [utterance.rstrip()]
        if output_abstraction:
            norm_utterance = slu.preprocessing.normalise_utterance(
                Utterance(utterance))
            abutterance, _ = slu.abstract_utterance(norm_utterance)
            annotation += [abutterance]
        if output_da:
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()
            annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
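
For reference, the uniq.trn file consumed above is expected to hold one "key => utterance" pair per line, matching the split in the loop; an illustrative line (both the key and the text are made up):

bootstrap_000001.wav => chtěl bych jet zítra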
Example #8
def train(fn_model,
          fn_transcription,
          constructor,
          fn_annotation,
          fn_bs_transcription,
          fn_bs_annotation,
          min_pos_feature_count,
          min_neg_feature_count,
          min_classifier_count,
          limit=100000):
    """
    Trains a SLU DAILogRegClassifier model.

    :param fn_model:
    :param fn_transcription:
    :param constructor:
    :param fn_annotation:
    :param limit:
    :return:
    """
    bs_utterances = load_wavaskey(fn_bs_transcription, Utterance, limit=limit)
    increase_weight(bs_utterances, min_pos_feature_count + 10)
    bs_das = load_wavaskey(fn_bs_annotation, DialogueAct, limit=limit)
    increase_weight(bs_das, min_pos_feature_count + 10)

    utterances = load_wavaskey(fn_transcription, constructor, limit=limit)
    das = load_wavaskey(fn_annotation, DialogueAct, limit=limit)

    utterances.update(bs_utterances)
    das.update(bs_das)

    cldb = CategoryLabelDatabase('../../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing, features_size=4)

    slu.extract_classifiers(das, utterances, verbose=True)
    slu.prune_classifiers(min_classifier_count=min_classifier_count)
    slu.print_classifiers()
    slu.gen_classifiers_data(min_pos_feature_count=min_pos_feature_count,
                             min_neg_feature_count=min_neg_feature_count,
                             verbose2=True)

    slu.train(inverse_regularisation=1e1, verbose=True)

    slu.save_model(fn_model)
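
A hypothetical invocation of this trainer; every file name below is a placeholder (though it follows the naming produced by the data-preparation script in Example #14), and the thresholds are illustrative:

train(fn_model='dailogreg.trn.model',
      fn_transcription='train.trn',
      constructor=Utterance,
      fn_annotation='train.trn.hdc.sem',
      fn_bs_transcription='bootstrap.trn',
      fn_bs_annotation='bootstrap.sem',
      min_pos_feature_count=2,
      min_neg_feature_count=2,
      min_classifier_count=4)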
Example #9
def hdc_slu_test(fn_input, constructor, fn_reference):
    """
    Tests the HDC SLU.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "=" * 120
    print "Testing HDC SLU: ", fn_input, fn_reference
    print "-" * 120

    from alex.components.slu.base import CategoryLabelDatabase
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    hdc_slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise Exception('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = hdc_slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            hdc_slu.parse(obs, verbose=True)

    fn_sem = os.path.basename(fn_input) + '.hdc.sem.out'

    save_wavaskey(fn_sem,
                  parsed_das,
                  trans=lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem) + '.score',
                    'w+',
                    encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
Example #10
def main():

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = PTICSHDCSLU(
        preprocessing,
        cfg={
            'SLU': {
                PTICSHDCSLU: {
                    'utt2da':
                    as_project_path(
                        "applications/PublicTransportInfoCS/data/utt2da_dict.txt"
                    )
                }
            }
        })

    output_alignment = False
    output_utterance = True
    output_abstraction = False
    output_da = True

    if len(sys.argv) < 2:
        fn_uniq_trn = 'uniq.trn'
    else:
        fn_uniq_trn = sys.argv[1]
    fn_uniq_trn_sem = fn_uniq_trn + '.sem.tmp'

    print "Processing input from file", fn_uniq_trn
    uniq_trn = codecs.open(fn_uniq_trn, "r", encoding='utf8')
    uniq_trn_sem = {}
    for line in uniq_trn:
        wav_key, utterance = line.split(" => ", 1)
        annotation = []
        if output_alignment:
            norm_utterance = slu.preprocessing.normalise_utterance(
                Utterance(utterance))
            abutterance, _, _ = slu.abstract_utterance(norm_utterance)
            abutterance = slu.handle_false_abstractions(abutterance)
            da = slu.parse_1_best({'utt': Utterance(utterance)}).get_best_da()

            max_alignment_idx = lambda _dai: max(
                _dai.alignment) if _dai.alignment else len(abutterance)
            for i, dai in enumerate(sorted(da, key=max_alignment_idx)):
                if not dai.alignment:
                    print "Empty alignment:", unicode(abutterance), ";", dai

                if not dai.alignment or dai.alignment == {-1}:
                    dai_alignment_idx = len(abutterance)
                else:
                    dai_alignment_idx = max(dai.alignment) + i + 1
                abutterance.insert(
                    dai_alignment_idx, "[{} - {}]".format(
                        unicode(dai),
                        list(dai.alignment if dai.alignment else [])))
            annotation += [unicode(abutterance)]
        else:
            if output_utterance:
                annotation += [utterance.rstrip()]
            if output_abstraction:
                norm_utterance = slu.preprocessing.normalise_utterance(
                    Utterance(utterance))
                abutterance, _ = slu.abstract_utterance(norm_utterance)
                annotation += [abutterance]
            if output_da:
                da = slu.parse_1_best({
                    'utt': Utterance(utterance)
                }).get_best_da()
                annotation += [unicode(da)]

        uniq_trn_sem[wav_key] = " <=> ".join(annotation)

    print "Saving output to file", fn_uniq_trn_sem
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)
Example #11
File: test.py  Project: tkraut/alex
def trained_slu_test(fn_model, fn_input, constructor, fn_reference):
    """
    Tests a SLU DAILogRegClassifier model.

    :param fn_model:
    :param fn_input:
    :param constructor:
    :param fn_reference:
    :return:
    """
    print "="*120
    print "Testing: ", fn_model, fn_input, fn_reference
    print "-"*120

    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.components.slu.base import CategoryLabelDatabase
    from alex.components.slu.dailrclassifier import DAILogRegClassifier
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing)

    slu.load_model(fn_model)

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    for utt_key, utt in sorted(test_utterances.iteritems()):
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            raise Exception('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            slu.parse(obs, verbose=True)

    if 'trn' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.trn.sem.out'
    elif 'asr' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.asr.sem.out'
    elif 'nbl' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.nbl.sem.out'
    else:
        fn_sem = os.path.basename(fn_input)+'.XXX.sem.out'

    save_wavaskey(fn_sem, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    f = codecs.open(os.path.basename(fn_sem)+'.score', 'w+', encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
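
A hypothetical invocation; the file names are placeholders that follow the naming used by the data-preparation script in Example #14:

trained_slu_test('dailogreg.trn.model', 'test.trn', Utterance,
                 'test.trn.hdc.sem')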
Example #12
from alex.components.asr.common import asr_factory
from alex.components.asr.utterance import Utterance, UtteranceNBList
from alex.components.slu.base import CategoryLabelDatabase
from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
from alex.applications.PublicTransportInfoCS.hdc_slu import PTICSHDCSLU
from alex.utils.config import Config, as_project_path
""" The script has commands:

--asr-log   it uses the asr hypotheses from call logs

"""

asr_log = 0
num_workers = 1

cldb = CategoryLabelDatabase('../data/database.py')
preprocessing = PTICSSLUPreprocessing(cldb)
slu = PTICSHDCSLU(
    preprocessing,
    cfg={
        'SLU': {
            PTICSHDCSLU: {
                'utt2da':
                as_project_path(
                    "applications/PublicTransportInfoCS/data/utt2da_dict.txt")
            }
        }
    })
cfg = Config.load_configs([
    '../kaldi.cfg',
], use_default=True)
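
This snippet leaves asr_log fixed at 0; elsewhere these scripts test flags directly against sys.argv (see Example #14), so wiring the --asr-log flag up could look like the following sketch (an assumption, not part of the original file):

import sys

# Hypothetical: mirrors the "'--asr-log' in sys.argv" test used in Example #14.
asr_log = 1 if '--asr-log' in sys.argv else 0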
Example #13
def main():

    cldb = CategoryLabelDatabase('../data/database.py')

    f_dupl = 0
    f_subs = 0
    examples = defaultdict(list)
    
    for f in cldb.form2value2cl:
        if len(cldb.form2value2cl[f]) >= 2:
            f_dupl += 1
            
        if len(f) < 2:
            continue
            
        for w in f:
            w = (w,)
            if w in cldb.form2value2cl:
                for v in cldb.form2value2cl[w]:
                    for c in cldb.form2value2cl[w][v]:
                        cc = c
                        break
                
                print '{w},{cc} -> {f}'.format(w=w, cc=cc, f=' '.join(f))
                break
        else:
            continue
        
        f_subs += 1
        for v in cldb.form2value2cl[f]:
            for c in cldb.form2value2cl[f][v]:
                examples[(c,cc)].extend(inform(f,v,c))
                examples[(c,cc)].extend(confirm(f,v,c))

    print "There were {f} surface forms.".format(f=len(cldb.form2value2cl))
    print "There were {f_dupl} surface form duplicits.".format(f_dupl=f_dupl)
    print "There were {f_subs} surface forms with substring surface forms.".format(f_subs=f_subs)
    
    max_examples = 100
    ex = []
    for c in sorted(examples.keys()):             
        print c
        z = examples[c]
        if max_examples < len(z):
            z = random.sample(z, max_examples)
        for s, t in z:
            print ' - ', s, '<=>', t
            ex.append((s, t))
            
    examples = ex
    
    examples.sort()
    
    sem = {}
    trn = {}
    for i, e in enumerate(examples):
        key = 'bootstrap_gen_{i:06d}.wav'.format(i=i)
        sem[key] = e[0]
        trn[key] = e[1]

    save_wavaskey('bootstrap_gen.sem', sem)
    save_wavaskey('bootstrap_gen.trn', trn)
Example #14
def main():
    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTIENSLUPreprocessing(cldb)
    slu = PTIENHDCSLU(preprocessing, cfg={'SLU': {PTIENHDCSLU: {'utt2da': as_project_path("applications/PublicTransportInfoEN/data/utt2da_dict.txt")}}})
    cfg = Config.load_configs(['../kaldi.cfg',], use_default=True)
    asr_rec = asr_factory(cfg)                    

    fn_uniq_trn = 'uniq.trn'
    fn_uniq_trn_hdc_sem = 'uniq.trn.hdc.sem'
    fn_uniq_trn_sem = 'uniq.trn.sem'

    fn_all_sem = 'all.sem'
    fn_all_trn = 'all.trn'
    fn_all_trn_hdc_sem = 'all.trn.hdc.sem'
    fn_all_asr = 'all.asr'
    fn_all_asr_hdc_sem = 'all.asr.hdc.sem'
    fn_all_nbl = 'all.nbl'
    fn_all_nbl_hdc_sem = 'all.nbl.hdc.sem'

    fn_train_sem = 'train.sem'
    fn_train_trn = 'train.trn'
    fn_train_trn_hdc_sem = 'train.trn.hdc.sem'
    fn_train_asr = 'train.asr'
    fn_train_asr_hdc_sem = 'train.asr.hdc.sem'
    fn_train_nbl = 'train.nbl'
    fn_train_nbl_hdc_sem = 'train.nbl.hdc.sem'

    fn_dev_sem = 'dev.sem'
    fn_dev_trn = 'dev.trn'
    fn_dev_trn_hdc_sem = 'dev.trn.hdc.sem'
    fn_dev_asr = 'dev.asr'
    fn_dev_asr_hdc_sem = 'dev.asr.hdc.sem'
    fn_dev_nbl = 'dev.nbl'
    fn_dev_nbl_hdc_sem = 'dev.nbl.hdc.sem'

    fn_test_sem = 'test.sem'
    fn_test_trn = 'test.trn'
    fn_test_trn_hdc_sem = 'test.trn.hdc.sem'
    fn_test_asr = 'test.asr'
    fn_test_asr_hdc_sem = 'test.asr.hdc.sem'
    fn_test_nbl = 'test.nbl'
    fn_test_nbl_hdc_sem = 'test.nbl.hdc.sem'

    indomain_data_dir = "indomain_data"

    print "Generating the SLU train and test data"
    print "-"*120
    ###############################################################################################

    files = []
    files.append(glob.glob(os.path.join(indomain_data_dir, 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', 'asr_transcribed.xml')))
    files.append(glob.glob(os.path.join(indomain_data_dir, '*', '*', '*', '*', '*', 'asr_transcribed.xml')))
    files = various.flatten(files)

    sem = []
    trn = []
    trn_hdc_sem = []
    asr = []
    asr_hdc_sem = []
    nbl = []
    nbl_hdc_sem = []

    for fn in files[:100000]:
        f_dir = os.path.dirname(fn)

        print "Processing:", fn
        doc = xml.dom.minidom.parse(fn)
        turns = doc.getElementsByTagName("turn")

        for i, turn in enumerate(turns):
            if turn.getAttribute('speaker') != 'user':
                continue

            recs = turn.getElementsByTagName("rec")
            trans = turn.getElementsByTagName("asr_transcription")
            asrs = turn.getElementsByTagName("asr")

            if len(recs) != 1:
                print "Skipping a turn {turn} in file: {fn} - recs: {recs}".format(turn=i,fn=fn, recs=len(recs))
                continue

            if len(asrs) == 0 and (i + 1) < len(turns):
                next_asrs = turns[i+1].getElementsByTagName("asr")
                if len(next_asrs) != 2:
                    print "Skipping a turn {turn} in file: {fn} - asrs: {asrs} - next_asrs: {next_asrs}".format(turn=i, fn=fn, asrs=len(asrs), next_asrs=len(next_asrs))
                    continue
                print "Recovered from missing ASR output by using a delayed ASR output from the following turn of turn {turn}. File: {fn} - next_asrs: {asrs}".format(turn=i, fn=fn, asrs=len(next_asrs))
                hyps = next_asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 1:
                hyps = asrs[0].getElementsByTagName("hypothesis")
            elif len(asrs) == 2:
                print "Recovered from EXTRA ASR outputs by using a the last ASR output from the turn. File: {fn} - asrs: {asrs}".format(fn=fn, asrs=len(asrs))
                hyps = asrs[-1].getElementsByTagName("hypothesis")
            else:
                print "Skipping a turn {turn} in file {fn} - asrs: {asrs}".format(turn=i,fn=fn, asrs=len(asrs))
                continue

            if len(trans) == 0:
                print "Skipping a turn in {fn} - trans: {trans}".format(fn=fn, trans=len(trans))
                continue

            wav_key = recs[0].getAttribute('fname')
            wav_path = os.path.join(f_dir, wav_key)
            
            # FIXME: Check whether the last transcription is really the best! FJ
            t = various.get_text_from_xml_node(trans[-1])
            t = normalise_text(t)

            
            if '--asr-log' not in sys.argv:
                asr_rec_nbl = asr_rec.rec_wav_file(wav_path)
                a = unicode(asr_rec_nbl.get_best())
            else:  
                a = various.get_text_from_xml_node(hyps[0])
                a = normalise_semi_words(a)

            if exclude_slu(t) or 'DOM Element:' in a:
                print "Skipping transcription:", unicode(t)
                print "Skipping ASR output:   ", unicode(a)
                continue

            # The silence does not have a label in the language model.
            t = t.replace('_SIL_','')

            trn.append((wav_key, t))

            print "Parsing transcription:", unicode(t)
            print "                  ASR:", unicode(a)

            # HDC SLU on transcription
            s = slu.parse_1_best({'utt':Utterance(t)}).get_best_da()
            trn_hdc_sem.append((wav_key, s))

            if '--uniq' not in sys.argv:
                # HDC SLU on 1 best ASR
                if '--asr-log' not in sys.argv:
                    a = unicode(asr_rec_nbl.get_best())
                else:  
                    a = various.get_text_from_xml_node(hyps[0])
                    a = normalise_semi_words(a)

                asr.append((wav_key, a))

                s = slu.parse_1_best({'utt':Utterance(a)}).get_best_da()
                asr_hdc_sem.append((wav_key, s))

                # HDC SLU on N best ASR
                n = UtteranceNBList()
                if '--asr-log' not in sys.argv:
                    n = asr_rec_nbl

                    print 'ASR RECOGNITION NBLIST\n', unicode(n)
                else:
                    for h in hyps:
                        txt = various.get_text_from_xml_node(h)
                        txt = normalise_semi_words(txt)

                        n.add(abs(float(h.getAttribute('p'))),Utterance(txt))

                n.merge()
                n.normalise()

                nbl.append((wav_key, n.serialise()))

                if '--fast' not in sys.argv:
                    s = slu.parse_nblist({'utt_nbl': n}).get_best_da()
                # with --fast, the 1-best SLU result from above is reused
                nbl_hdc_sem.append((wav_key, s))

            # there is no manual semantics in the transcriptions yet
            sem.append((wav_key, None))


    uniq_trn = {}
    uniq_trn_hdc_sem = {}
    uniq_trn_sem = {}
    trn_set = set()

    sem = dict(trn_hdc_sem)
    for k, v in trn:
        if v not in trn_set:
            trn_set.add(v)
            uniq_trn[k] = v
            uniq_trn_hdc_sem[k] = sem[k]
            uniq_trn_sem[k] = v + " <=> " + unicode(sem[k])

    save_wavaskey(fn_uniq_trn, uniq_trn)
    save_wavaskey(fn_uniq_trn_hdc_sem, uniq_trn_hdc_sem, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
    save_wavaskey(fn_uniq_trn_sem, uniq_trn_sem)

    # all
    save_wavaskey(fn_all_trn, dict(trn))
    save_wavaskey(fn_all_trn_hdc_sem, dict(trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    if '--uniq' not in sys.argv:
        save_wavaskey(fn_all_asr, dict(asr))
        save_wavaskey(fn_all_asr_hdc_sem, dict(asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        save_wavaskey(fn_all_nbl, dict(nbl))
        save_wavaskey(fn_all_nbl_hdc_sem, dict(nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))


        seed_value = 10

        random.seed(seed_value)
        random.shuffle(trn)
        random.seed(seed_value)
        random.shuffle(trn_hdc_sem)
        random.seed(seed_value)
        random.shuffle(asr)
        random.seed(seed_value)
        random.shuffle(asr_hdc_sem)
        random.seed(seed_value)
        random.shuffle(nbl)
        random.seed(seed_value)
        random.shuffle(nbl_hdc_sem)

        # trn
        train_trn = trn[:int(0.8*len(trn))]
        dev_trn = trn[int(0.8*len(trn)):int(0.9*len(trn))]
        test_trn = trn[int(0.9*len(trn)):]

        save_wavaskey(fn_train_trn, dict(train_trn))
        save_wavaskey(fn_dev_trn, dict(dev_trn))
        save_wavaskey(fn_test_trn, dict(test_trn))

        # trn_hdc_sem
        train_trn_hdc_sem = trn_hdc_sem[:int(0.8*len(trn_hdc_sem))]
        dev_trn_hdc_sem = trn_hdc_sem[int(0.8*len(trn_hdc_sem)):int(0.9*len(trn_hdc_sem))]
        test_trn_hdc_sem = trn_hdc_sem[int(0.9*len(trn_hdc_sem)):]

        save_wavaskey(fn_train_trn_hdc_sem, dict(train_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_trn_hdc_sem, dict(dev_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_trn_hdc_sem, dict(test_trn_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # asr
        train_asr = asr[:int(0.8*len(asr))]
        dev_asr = asr[int(0.8*len(asr)):int(0.9*len(asr))]
        test_asr = asr[int(0.9*len(asr)):]

        save_wavaskey(fn_train_asr, dict(train_asr))
        save_wavaskey(fn_dev_asr, dict(dev_asr))
        save_wavaskey(fn_test_asr, dict(test_asr))

        # asr_hdc_sem
        train_asr_hdc_sem = asr_hdc_sem[:int(0.8*len(asr_hdc_sem))]
        dev_asr_hdc_sem = asr_hdc_sem[int(0.8*len(asr_hdc_sem)):int(0.9*len(asr_hdc_sem))]
        test_asr_hdc_sem = asr_hdc_sem[int(0.9*len(asr_hdc_sem)):]

        save_wavaskey(fn_train_asr_hdc_sem, dict(train_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_asr_hdc_sem, dict(dev_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_asr_hdc_sem, dict(test_asr_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

        # n-best lists
        train_nbl = nbl[:int(0.8*len(nbl))]
        dev_nbl = nbl[int(0.8*len(nbl)):int(0.9*len(nbl))]
        test_nbl = nbl[int(0.9*len(nbl)):]

        save_wavaskey(fn_train_nbl, dict(train_nbl))
        save_wavaskey(fn_dev_nbl, dict(dev_nbl))
        save_wavaskey(fn_test_nbl, dict(test_nbl))

        # nbl_hdc_sem
        train_nbl_hdc_sem = nbl_hdc_sem[:int(0.8*len(nbl_hdc_sem))]
        dev_nbl_hdc_sem = nbl_hdc_sem[int(0.8*len(nbl_hdc_sem)):int(0.9*len(nbl_hdc_sem))]
        test_nbl_hdc_sem = nbl_hdc_sem[int(0.9*len(nbl_hdc_sem)):]

        save_wavaskey(fn_train_nbl_hdc_sem, dict(train_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_dev_nbl_hdc_sem, dict(dev_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
        save_wavaskey(fn_test_nbl_hdc_sem, dict(test_nbl_hdc_sem), trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))
Example #15
    def setUpClass(cls):
        cfg = cls.get_cfg()
        slu_type = cfg['SLU']['type']
        cldb = CategoryLabelDatabase(cfg['SLU'][slu_type]['cldb_fname'])
        preprocessing = cfg['SLU'][slu_type]['preprocessing_cls'](cldb)
        cls.slu = slu_type(preprocessing, cfg)
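
This fixture relies on an abstract cls.get_cfg(); a plausible shape for its return value, inferred from the configurations in the other examples (the SLU class and the database path are illustrative):

    @classmethod
    def get_cfg(cls):
        # Illustrative only; 'database.py' is a placeholder path.
        return {
            'SLU': {
                'type': PTICSHDCSLU,
                PTICSHDCSLU: {
                    'cldb_fname': 'database.py',
                    'preprocessing_cls': PTICSSLUPreprocessing,
                },
            },
        }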