示例#1
0
    def test_parse_X(self):
        """Train a tiny classifier and check ``parse_X`` on an utterance n-best list.

        A minimal in-memory category label database is loaded, the classifier
        is trained on three utterance / dialogue-act pairs, and the object
        returned by ``parse_X`` must score ``inform(task=weather)`` above 0.5
        and ``inform(time=now)`` below 0.5.
        """
        cldb = CategoryLabelDatabase()

        # In-memory stand-in that is handed to ``cldb.load`` as ``db_mod``.
        class db:
            database = {
                "task": {
                    "find_connection": ["najít spojení", "najít spoj", "zjistit spojení",
                                        "zjistit spoj", "hledám spojení", 'spojení', 'spoj',
                                       ],
                    "find_platform": ["najít nástupiště", "zjistit nástupiště", ],
                    'weather': ['pocasi', ],
                },
                "number": {
                    "1": ["jednu"]
                },
                "time": {
                    "now": ["nyní", "teď", "teďka", "hned", "nejbližší", "v tuto chvíli", "co nejdřív"],
                },
            }

        cldb.load(db_mod=db)

        preprocessing = SLUPreprocessing(cldb)
        clf = DAILogRegClassifier(cldb, preprocessing, features_size=4)

        # Train a simple classifier; the keys pair each dialogue act with its
        # utterance.
        das = {
            '1': DialogueAct('inform(task=weather)'),
            '2': DialogueAct('inform(time=now)'),
            '3': DialogueAct('inform(task=weather)'),
        }
        utterances = {
            '1': Utterance('pocasi'),
            '2': Utterance('hned'),
            '3': Utterance('jak bude'),
        }
        clf.extract_classifiers(das, utterances, verbose=False)
        # Thresholds of 0 are passed so that the three training examples are
        # not filtered by the count limits (presumably; confirm in the
        # classifier implementation).
        clf.prune_classifiers(min_classifier_count=0)
        clf.gen_classifiers_data(min_pos_feature_count=0,
                                 min_neg_feature_count=0,
                                 verbose2=False)

        clf.train(inverse_regularisation=1e1, verbose=False)

        # Parse some sentences.
        # NOTE(review): the weights 0.7 + 0.7 + 0.2 add up to 1.6 -- confirm
        # that an unnormalised n-best list is intended here.
        utterance_list = UtteranceNBList()
        utterance_list.add(0.7, Utterance('pocasi'))
        utterance_list.add(0.7, Utterance('pocasi jak bude'))
        utterance_list.add(0.2, Utterance('hned'))

        da_confnet = clf.parse_X(utterance_list, verbose=False)

        # assertGreater / assertLess report the offending value on failure,
        # unlike assertTrue on a comparison.
        weather_prob = da_confnet.get_prob(DialogueActItem(dai='inform(task=weather)'))
        self.assertGreater(weather_prob, 0.5)

        now_prob = da_confnet.get_prob(DialogueActItem(dai='inform(time=now)'))
        self.assertLess(now_prob, 0.5)
示例#2
0
def train(fn_model,
          fn_transcription,
          constructor,
          fn_annotation,
          fn_bs_transcription,
          fn_bs_annotation,
          min_pos_feature_count,
          min_neg_feature_count,
          min_classifier_count,
          limit=100000,
          fn_database='../../data/database.py'):
    """
    Trains a SLU DAILogRegClassifier model and saves it to ``fn_model``.

    :param fn_model: destination passed to ``save_model``
    :param fn_transcription: transcription file loaded with ``load_wavaskey``
    :param constructor: constructor handed to ``load_wavaskey`` when loading
        ``fn_transcription``
    :param fn_annotation: annotation file loaded as ``DialogueAct`` items
    :param fn_bs_transcription: "bs" transcription file loaded as
        ``Utterance`` items (presumably bootstrap data -- confirm)
    :param fn_bs_annotation: "bs" annotation file loaded as ``DialogueAct``
        items
    :param min_pos_feature_count: forwarded to ``gen_classifiers_data``; also
        used, plus 10, as the weight given to ``increase_weight`` for the "bs"
        data
    :param min_neg_feature_count: forwarded to ``gen_classifiers_data``
    :param min_classifier_count: forwarded to ``prune_classifiers``
    :param limit: ``limit`` argument of every ``load_wavaskey`` call
    :param fn_database: path given to ``CategoryLabelDatabase``; the default
        is relative, so it is resolved against the current working directory
    :return: None
    """
    # Both "bs" collections receive the same weight, so compute it once.
    bs_weight = min_pos_feature_count + 10

    bs_utterances = load_wavaskey(fn_bs_transcription, Utterance, limit=limit)
    increase_weight(bs_utterances, bs_weight)
    bs_das = load_wavaskey(fn_bs_annotation, DialogueAct, limit=limit)
    increase_weight(bs_das, bs_weight)

    utterances = load_wavaskey(fn_transcription, constructor, limit=limit)
    das = load_wavaskey(fn_annotation, DialogueAct, limit=limit)

    # Merge the "bs" data into the regular data; with dict semantics a "bs"
    # entry replaces a regular entry stored under the same key.
    utterances.update(bs_utterances)
    das.update(bs_das)

    cldb = CategoryLabelDatabase(fn_database)
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing, features_size=4)

    slu.extract_classifiers(das, utterances, verbose=True)
    slu.prune_classifiers(min_classifier_count=min_classifier_count)
    slu.print_classifiers()
    slu.gen_classifiers_data(min_pos_feature_count=min_pos_feature_count,
                             min_neg_feature_count=min_neg_feature_count,
                             verbose2=True)

    slu.train(inverse_regularisation=1e1, verbose=True)

    slu.save_model(fn_model)
示例#3
0
文件: train.py 项目: UFAL-DSG/alex
def train(fn_model,
          fn_transcription, constructor, fn_annotation,
          fn_bs_transcription, fn_bs_annotation,
          min_pos_feature_count,
          min_neg_feature_count,
          min_classifier_count,
          limit=100000,
          fn_database='../../data/database.py'):
    """
    Trains a SLU DAILogRegClassifier model and saves it to ``fn_model``.

    :param fn_model: destination passed to ``save_model``
    :param fn_transcription: transcription file loaded with ``load_wavaskey``
    :param constructor: constructor handed to ``load_wavaskey`` when loading
        ``fn_transcription``
    :param fn_annotation: annotation file loaded as ``DialogueAct`` items
    :param fn_bs_transcription: "bs" transcription file loaded as
        ``Utterance`` items (presumably bootstrap data -- confirm)
    :param fn_bs_annotation: "bs" annotation file loaded as ``DialogueAct``
        items
    :param min_pos_feature_count: forwarded to ``gen_classifiers_data``; also
        used, plus 10, as the weight given to ``increase_weight`` for the "bs"
        data
    :param min_neg_feature_count: forwarded to ``gen_classifiers_data``
    :param min_classifier_count: forwarded to ``prune_classifiers``
    :param limit: ``limit`` argument of every ``load_wavaskey`` call
    :param fn_database: path given to ``CategoryLabelDatabase``; the default
        is relative, so it is resolved against the current working directory
    :return: None
    """
    # Both "bs" collections receive the same weight, so compute it once.
    bs_weight = min_pos_feature_count + 10

    bs_utterances = load_wavaskey(fn_bs_transcription, Utterance, limit=limit)
    increase_weight(bs_utterances, bs_weight)
    bs_das = load_wavaskey(fn_bs_annotation, DialogueAct, limit=limit)
    increase_weight(bs_das, bs_weight)

    utterances = load_wavaskey(fn_transcription, constructor, limit=limit)
    das = load_wavaskey(fn_annotation, DialogueAct, limit=limit)

    # Merge the "bs" data into the regular data; with dict semantics a "bs"
    # entry replaces a regular entry stored under the same key.
    utterances.update(bs_utterances)
    das.update(bs_das)

    cldb = CategoryLabelDatabase(fn_database)
    preprocessing = PTICSSLUPreprocessing(cldb)
    slu = DAILogRegClassifier(cldb, preprocessing, features_size=4)

    slu.extract_classifiers(das, utterances, verbose=True)
    slu.prune_classifiers(min_classifier_count=min_classifier_count)
    slu.print_classifiers()
    slu.gen_classifiers_data(min_pos_feature_count=min_pos_feature_count,
                             min_neg_feature_count=min_neg_feature_count,
                             verbose2=True)

    slu.train(inverse_regularisation=1e1, verbose=True)

    slu.save_model(fn_model)