Пример #1
0
"""
Evaluate DST (dialogue state tracking) models on a specified dataset
Usage: python evaluate.py [MultiWOZ|CrossWOZ] [TRADE|mdbt|sumbt|rule]
"""
import random
import numpy
import torch
from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE


def format_history(context):
    """Tag each utterance of a dialogue with the speaker who produced it.

    Speakers are assumed to alternate, starting with the user: utterances at
    even positions (0, 2, ...) are tagged ``'user'`` and those at odd
    positions are tagged ``'system'``.

    Args:
        context: Iterable of utterances in dialogue order. Any iterable is
            accepted; it does not need to support ``len()`` or indexing.

    Returns:
        A new list of ``[speaker, utterance]`` pairs, in the same order as
        *context*. An empty *context* yields an empty list.
    """
    return [['system' if i % 2 else 'user', utterance]
            for i, utterance in enumerate(context)]


if __name__ == '__main__':
    # Seed Python's, NumPy's and PyTorch's RNGs (in that order) with one
    # fixed value so repeated evaluation runs are reproducible.
    seed = 2020
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(seed)

    # NOTE(review): the checkpoint directory name says "multiwozdst" although
    # a CrossWOZ tracker is constructed here — confirm this is the intended
    # model.
    checkpoint = 'model/TRADE-multiwozdst/HDD100BSZ4DR0.2ACC-0.3228'
    model = CrossWOZTRADE(checkpoint)
    model.evaluate()
Пример #2
0
        evaluation_metrics = {
            "Joint Acc": joint_acc_score_ptr,
            "Turn Acc": turn_acc_score_ptr,
            "Joint F1": F1_score_ptr
        }
        print(evaluation_metrics)
    elif dataset_name.startswith('CrossWOZ'):
        en = dataset_name.endswith('en')
        if en:
            if model_name == 'sumbt':
                from convlab2.dst.sumbt.crosswoz_en.sumbt import SUMBTTracker
                model = SUMBTTracker()
        else:
            if model_name == 'TRADE':
                from convlab2.dst.trade.crosswoz.trade import CrossWOZTRADE
                model = CrossWOZTRADE()
            elif model_name == 'mdbt':
                pass
            elif model_name == 'sumbt':
                pass
            elif model_name == 'rule':
                pass
            else:
                raise Exception("Available models: TRADE")

        ## load data
        from convlab2.util.dataloader.module_dataloader import CrossWOZAgentDSTDataloader
        from convlab2.util.dataloader.dataset_dataloader import CrossWOZDataloader

        dataloader = CrossWOZAgentDSTDataloader(
            dataset_dataloader=CrossWOZDataloader(en))