Example #1
0
File: train.py — Project: UFAL-DSG/alex
def train(fn_model,
          fn_transcription, constructor, fn_annotation,
          fn_bs_transcription, fn_bs_annotation,
          min_pos_feature_count,
          min_neg_feature_count,
          min_classifier_count,
          limit=100000):
    """
    Trains a SLU DAILogRegClassifier model and saves it to ``fn_model``.

    :param fn_model: path the trained model is saved to
    :param fn_transcription: wav-as-key file with training inputs, each
        entry parsed with ``constructor``
    :param constructor: callable building one observation per entry
        (e.g. Utterance or UtteranceNBList)
    :param fn_annotation: wav-as-key file with DialogueAct annotations
    :param fn_bs_transcription: bootstrap transcriptions; up-weighted so
        pruning cannot discard them
    :param fn_bs_annotation: bootstrap annotations, up-weighted likewise
    :param min_pos_feature_count: positive-feature threshold passed to
        ``gen_classifiers_data``
    :param min_neg_feature_count: negative-feature threshold passed to
        ``gen_classifiers_data``
    :param min_classifier_count: threshold passed to ``prune_classifiers``
    :param limit: maximum number of entries loaded from each file
    :return: None (the model file is written as a side effect)
    """
    # Bootstrap data are weighted above the positive-feature threshold so
    # that feature pruning can never drop them.
    bs_utterances = load_wavaskey(fn_bs_transcription, Utterance, limit=limit)
    increase_weight(bs_utterances, min_pos_feature_count + 10)
    bs_das = load_wavaskey(fn_bs_annotation, DialogueAct, limit=limit)
    increase_weight(bs_das, min_pos_feature_count + 10)

    utterances = load_wavaskey(fn_transcription, constructor, limit=limit)
    das = load_wavaskey(fn_annotation, DialogueAct, limit=limit)

    # Merge the bootstrap examples into the main training data.
    utterances.update(bs_utterances)
    das.update(bs_das)

    cldb = CategoryLabelDatabase('../../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing, features_size=4)

    slu.extract_classifiers(das, utterances, verbose=True)
    slu.prune_classifiers(min_classifier_count=min_classifier_count)
    slu.print_classifiers()
    slu.gen_classifiers_data(min_pos_feature_count=min_pos_feature_count,
                             min_neg_feature_count=min_neg_feature_count,
                             verbose2=True)

    slu.train(inverse_regularisation=1e1, verbose=True)

    slu.save_model(fn_model)
Example #2
0
File: test.py — Project: tkraut/alex
def trained_slu_test(fn_model, fn_input, constructor, fn_reference):
    """
    Tests a SLU DAILogRegClassifier model.

    Loads the model from ``fn_model``, parses every observation found in
    ``fn_input``, writes the 1-best dialogue acts to a ``*.sem.out`` file
    (named after the model-variant tag found in ``fn_model``) and scores
    them against ``fn_reference``.

    :param fn_model: path to a saved DAILogRegClassifier model
    :param fn_input: wav-as-key file with test inputs, parsed by ``constructor``
    :param constructor: callable building each observation
        (Utterance or UtteranceNBList)
    :param fn_reference: wav-as-key file with reference dialogue acts
    :return: None (output and score files are written as side effects)
    """
    # Python 2 print statements throughout: this snippet targets Python 2.
    print "="*120
    print "Testing: ", fn_model, fn_input, fn_reference
    print "-"*120

    # NOTE(review): imports are deferred into the function body — presumably
    # to avoid module-load cost or import cycles; confirm before hoisting.
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.components.slu.base import CategoryLabelDatabase
    from alex.components.slu.dailrclassifier import DAILogRegClassifier
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing)

    slu.load_model(fn_model)

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    # Sort by key so the console log and the output file order are
    # deterministic across runs.
    for utt_key, utt in sorted(test_utterances.iteritems()):
        # Wrap the observation under the key the parser expects.
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            # NOTE(review): BaseException is overly broad for a bad input
            # type; Exception or ValueError would be more conventional.
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        # Prune the confusion network, then keep only the single best
        # dialogue-act hypothesis.
        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        # An unexpanded 'CL_' category label left in the output signals a
        # decoding problem; re-parse verbosely to aid debugging.
        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            slu.parse(obs, verbose=True)

    # Derive the output file name from the model-variant tag in fn_model.
    if 'trn' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.trn.sem.out'
    elif 'asr' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.asr.sem.out'
    elif 'nbl' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.nbl.sem.out'
    else:
        fn_sem = os.path.basename(fn_input)+'.XXX.sem.out'

    # Dialogue act items are sorted so the score is order-insensitive.
    save_wavaskey(fn_sem, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    # Score hypotheses against the reference and write the report file.
    f = codecs.open(os.path.basename(fn_sem)+'.score', 'w+', encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()
Example #3
0
    def test_parse_X(self):
        """Check parse_X on an utterance n-best list with a tiny trained model.

        Trains a DAILogRegClassifier on three toy utterances, then verifies
        that parsing an n-best list dominated by 'pocasi' (weather) readings
        yields inform(task=weather) with high probability and
        inform(time=now) with low probability.
        """
        cldb = CategoryLabelDatabase()

        # Minimal in-memory surface-form database standing in for a db module.
        class db:
            database = {
                "task": {
                    "find_connection": ["najít spojení", "najít spoj", "zjistit spojení",
                                        "zjistit spoj", "hledám spojení", 'spojení', 'spoj',
                                       ],
                    "find_platform": ["najít nástupiště", "zjistit nástupiště", ],
                    'weather': ['pocasi', ],
                },
                "number": {
                    "1": ["jednu"]
                },
                "time": {
                    "now": ["nyní", "teď", "teďka", "hned", "nejbližší", "v tuto chvíli", "co nejdřív"],
                },
            }

        cldb.load(db_mod=db)

        preprocessing = SLUPreprocessing(cldb)
        clf = DAILogRegClassifier(cldb, preprocessing, features_size=4)

        # Train a simple classifier.
        das = {
            '1': DialogueAct('inform(task=weather)'),
            '2': DialogueAct('inform(time=now)'),
            '3': DialogueAct('inform(task=weather)'),
        }
        utterances = {
            '1': Utterance('pocasi'),
            '2': Utterance('hned'),
            '3': Utterance('jak bude'),
        }
        clf.extract_classifiers(das, utterances, verbose=False)
        clf.prune_classifiers(min_classifier_count=0)
        clf.gen_classifiers_data(min_pos_feature_count=0,
                                 min_neg_feature_count=0,
                                 verbose2=False)

        clf.train(inverse_regularisation=1e1, verbose=False)

        # Parse some sentences.
        utterance_list = UtteranceNBList()
        utterance_list.add(0.7, Utterance('pocasi'))
        utterance_list.add(0.7, Utterance('pocasi jak bude'))
        utterance_list.add(0.2, Utterance('hned'))

        da_confnet = clf.parse_X(utterance_list, verbose=False)

        # assertGreater/assertLess report the compared values on failure,
        # unlike assertTrue(a > b) which only prints "False is not true".
        self.assertGreater(
            da_confnet.get_prob(DialogueActItem(dai='inform(task=weather)')), 0.5)
        self.assertLess(
            da_confnet.get_prob(DialogueActItem(dai='inform(time=now)')), 0.5)
Example #4
0
def train(fn_model,
          fn_transcription,
          constructor,
          fn_annotation,
          fn_bs_transcription,
          fn_bs_annotation,
          min_feature_count=2,
          min_classifier_count=2,
          limit=100000):
    """
    Trains a SLU DAILogRegClassifier model and stores it at ``fn_model``.

    :param fn_model: destination path for the trained model
    :param fn_transcription: wav-as-key file with training inputs, each
        entry parsed by ``constructor``
    :param constructor: callable producing one observation per entry
        (e.g. Utterance or UtteranceNBList)
    :param fn_annotation: wav-as-key file with DialogueAct annotations
    :param fn_bs_transcription: bootstrap transcriptions; up-weighted so
        pruning cannot discard them
    :param fn_bs_annotation: bootstrap annotations, up-weighted likewise
    :param min_feature_count: feature pruning threshold
    :param min_classifier_count: classifier pruning threshold
    :param limit: maximum number of entries read from each file
    :return: None (the model file is written as a side effect)
    """
    # Main training data.
    train_utts = load_wavaskey(fn_transcription, constructor, limit=limit)
    train_acts = load_wavaskey(fn_annotation, DialogueAct, limit=limit)

    # Bootstrap data, boosted past the feature-pruning threshold so it is
    # never discarded.
    boost = min_feature_count + 10
    boot_utts = load_wavaskey(fn_bs_transcription, Utterance, limit=limit)
    boot_acts = load_wavaskey(fn_bs_annotation, DialogueAct, limit=limit)
    increase_weight(boot_utts, boost)
    increase_weight(boot_acts, boost)

    # Fold the bootstrap examples into the training set.
    train_utts.update(boot_utts)
    train_acts.update(boot_acts)

    database = CategoryLabelDatabase('../data/database.py')
    classifier = DAILogRegClassifier(database,
                                     PTICSSLUPreprocessing(database),
                                     features_size=4)

    classifier.extract_classifiers(train_acts, train_utts, verbose=True)
    classifier.prune_classifiers(min_classifier_count=min_classifier_count)
    classifier.print_classifiers()
    classifier.gen_classifiers_data()
    classifier.prune_features(min_feature_count=min_feature_count, verbose=True)

    classifier.train(inverse_regularisation=1e1, verbose=True)

    classifier.save_model(fn_model)
Example #5
0
File: test.py — Project: elnaaz/alex
def trained_slu_test(fn_model, fn_input, constructor, fn_reference):
    """
    Tests a SLU DAILogRegClassifier model.

    Loads the model from ``fn_model``, parses every observation found in
    ``fn_input``, writes the 1-best dialogue acts to a ``*.sem.out`` file
    (named after the model-variant tag found in ``fn_model``) and scores
    them against ``fn_reference``.

    :param fn_model: path to a saved DAILogRegClassifier model
    :param fn_input: wav-as-key file with test inputs, parsed by ``constructor``
    :param constructor: callable building each observation
        (Utterance or UtteranceNBList)
    :param fn_reference: wav-as-key file with reference dialogue acts
    :return: None (output and score files are written as side effects)
    """
    # Python 2 print statements throughout: this snippet targets Python 2.
    print "="*120
    print "Testing: ", fn_model, fn_input, fn_reference
    print "-"*120

    # NOTE(review): imports are deferred into the function body — presumably
    # to avoid module-load cost or import cycles; confirm before hoisting.
    from alex.applications.PublicTransportInfoCS.preprocessing import PTICSSLUPreprocessing
    from alex.components.slu.base import CategoryLabelDatabase
    from alex.components.slu.dailrclassifier import DAILogRegClassifier
    from alex.corpustools.wavaskey import load_wavaskey, save_wavaskey
    from alex.corpustools.semscore import score

    cldb = CategoryLabelDatabase('../data/database.py')
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing)

    slu.load_model(fn_model)

    test_utterances = load_wavaskey(fn_input, constructor, limit=100000)

    parsed_das = {}
    # Sort by key so the console log and the output file order are
    # deterministic across runs.
    for utt_key, utt in sorted(test_utterances.iteritems()):
        # Wrap the observation under the key the parser expects.
        if isinstance(utt, Utterance):
            obs = {'utt': utt}
        elif isinstance(utt, UtteranceNBList):
            obs = {'utt_nbl': utt}
        else:
            # NOTE(review): BaseException is overly broad for a bad input
            # type; Exception or ValueError would be more conventional.
            raise BaseException('Unsupported observation type')

        print '-' * 120
        print "Observation:"
        print utt_key, " ==> "
        print unicode(utt)

        da_confnet = slu.parse(obs, verbose=False)

        print "Conf net:"
        print unicode(da_confnet)

        # Prune the confusion network, then keep only the single best
        # dialogue-act hypothesis.
        da_confnet.prune()
        dah = da_confnet.get_best_da_hyp()

        print "1 best: "
        print unicode(dah)

        parsed_das[utt_key] = dah.da

        # An unexpanded 'CL_' category label left in the output signals a
        # decoding problem; re-parse verbosely to aid debugging.
        if 'CL_' in str(dah.da):
            print '*' * 120
            print utt
            print dah.da
            slu.parse(obs, verbose=True)

    # Derive the output file name from the model-variant tag in fn_model.
    if 'trn' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.trn.sem.out'
    elif 'asr' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.asr.sem.out'
    elif 'nbl' in fn_model:
        fn_sem = os.path.basename(fn_input)+'.model.nbl.sem.out'
    else:
        fn_sem = os.path.basename(fn_input)+'.XXX.sem.out'

    # Dialogue act items are sorted so the score is order-insensitive.
    save_wavaskey(fn_sem, parsed_das, trans = lambda da: '&'.join(sorted(unicode(da).split('&'))))

    # Score hypotheses against the reference and write the report file.
    f = codecs.open(os.path.basename(fn_sem)+'.score', 'w+', encoding='UTF-8')
    score(fn_reference, fn_sem, True, True, f)
    f.close()