Example #1
    def __init__(self, task_def_path):
        self._task_def_dic = yaml.safe_load(open(task_def_path))
        global_map = {}
        n_class_map = {}
        data_type_map = {}
        task_type_map = {}
        metric_meta_map = {}
        enable_san_map = {}
        dropout_p_map = {}
        encoderType_map = {}
        loss_map = {}
        kd_loss_map = {}

        uniq_encoderType = set()
        for task, task_def in self._task_def_dic.items():
            assert "_" not in task, "task name should not contain '_', current task name: %s" % task
            n_class_map[task] = task_def["n_class"]
            data_format = DataFormat[task_def["data_format"]]
            data_type_map[task] = data_format
            task_type_map[task] = TaskType[task_def["task_type"]]
            metric_meta_map[task] = tuple(
                Metric[metric_name] for metric_name in task_def["metric_meta"])
            enable_san_map[task] = task_def["enable_san"]
            uniq_encoderType.add(EncoderModelType[task_def["encoder_type"]])
            if "labels" in task_def:
                labels = task_def["labels"]
                label_mapper = Vocabulary(True)
                for label in labels:
                    label_mapper.add(label)
                global_map[task] = label_mapper
            if "dropout_p" in task_def:
                dropout_p_map[task] = task_def["dropout_p"]
            # loss map
            if "loss" in task_def:
                t_loss = task_def["loss"]
                loss_crt = LossCriterion[t_loss]
                loss_map[task] = loss_crt
            else:
                loss_map[task] = None

            if "kd_loss" in task_def:
                t_loss = task_def["kd_loss"]
                loss_crt = LossCriterion[t_loss]
                kd_loss_map[task] = loss_crt
            else:
                kd_loss_map[task] = None

        assert len(
            uniq_encoderType) == 1, 'The shared encoder has to be the same.'
        self.global_map = global_map
        self.n_class_map = n_class_map
        self.data_type_map = data_type_map
        self.task_type_map = task_type_map
        self.metric_meta_map = metric_meta_map
        self.enable_san_map = enable_san_map
        self.dropout_p_map = dropout_p_map
        self.encoderType = uniq_encoderType.pop()
        self.loss_map = loss_map
        self.kd_loss_map = kd_loss_map
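
For orientation, the sketch below shows a hypothetical task-definition YAML that this constructor would accept, and how the resulting maps could be queried. The class name TaskDefs and the enum members Classification, ACC, BERT, and CeCriterion are illustrative assumptions; only the field names are taken from the code above.

# A hypothetical single-task definition file; enum member names and the
# wrapping class name TaskDefs are assumptions, not shown in the snippet above.
example_yaml = """\
cola:
  data_format: PremiseOnly
  task_type: Classification
  n_class: 2
  metric_meta:
  - ACC
  enable_san: false
  encoder_type: BERT
  labels:
  - '0'
  - '1'
  dropout_p: 0.05
  loss: CeCriterion
"""

with open("cola_task_def.yml", "w", encoding="utf8") as writer:
    writer.write(example_yaml)

task_defs = TaskDefs("cola_task_def.yml")   # class wrapping the __init__ above (assumed name)
print(task_defs.n_class_map["cola"])        # -> 2
print(task_defs.loss_map["cola"])           # -> the configured loss; kd_loss_map["cola"] stays None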
Example #2
    def __init__(self, task_def_path):
        self._task_def_dic = yaml.safe_load(open(task_def_path))
        global_map = {}
        n_class_map = {}
        data_type_map = {}
        task_type_map = {}
        metric_meta_map = {}
        split_names_map = {}
        enable_san_map = {}
        dropout_p_map = {}
        loss_map = {}
        kd_loss_map = {}

        for task, task_def in self._task_def_dic.items():
            assert "_" not in task, "task name should not contain '_', current task name: %s" % task
            n_class_map[task] = task_def["n_class"]
            data_format = DataFormat[task_def["data_format"]]
            data_type_map[task] = data_format
            task_type_map[task] = TaskType[task_def["task_type"]]
            metric_meta_map[task] = tuple(
                Metric[metric_name] for metric_name in task_def["metric_meta"])
            split_names_map[task] = task_def.get("split_names",
                                                 ["train", "dev", "test"])
            enable_san_map[task] = task_def["enable_san"]
            if "labels" in task_def:
                labels = task_def["labels"]
                label_mapper = Vocabulary(True)
                for label in labels:
                    label_mapper.add(label)
                global_map[task] = label_mapper
            if "dropout_p" in task_def:
                dropout_p_map[task] = task_def["dropout_p"]
            # loss map
            if "loss" in task_def:
                t_loss = task_def["loss"]
                loss_crt = LossCriterion[t_loss]
                loss_map[task] = loss_crt
            else:
                loss_map[task] = None

            if "kd_loss" in task_def:
                t_loss = task_def["kd_loss"]
                loss_crt = LossCriterion[t_loss]
                kd_loss_map[task] = loss_crt
            else:
                kd_loss_map[task] = None

        self._global_map = global_map
        self._n_class_map = n_class_map
        self._data_type_map = data_type_map
        self._task_type_map = task_type_map
        self._metric_meta_map = metric_meta_map
        self._split_names_map = split_names_map
        self._enable_san_map = enable_san_map
        self._dropout_p_map = dropout_p_map
        self._loss_map = loss_map
        self._kd_loss_map = kd_loss_map

        self._task_def_dic = {}
Example #3
    def __init__(self, task_def_path):
        self._task_def_dic = yaml.safe_load(open(task_def_path))
        global_map = {}
        n_class_map = {}
        data_format_map = {}
        data_type_map = {}
        task_type_map = {}
        metric_meta_map = {}
        enable_san_map = {}
        dropout_p_map = {}
        tasks = []
        split_names_map = {}
        for task, task_def in self._task_def_dic.items():
            tasks.append(task)
            assert "_" not in task, "task name should not contain '_', current task name: %s" % task
            n_class_map[task] = task_def["n_class"]
            data_format = DataFormat[task_def["data_format"]]
            data_format_map[task] = data_format
            if data_format == DataFormat.PremiseOnly:
                data_type_map[task] = 1
            elif data_format in (DataFormat.PremiseAndMultiHypothesis,
                                 DataFormat.PremiseAndOneHypothesis):
                data_type_map[task] = 0
            else:
                raise ValueError(data_format)
            task_type_map[task] = TaskType[task_def["task_type"]]
            metric_meta_map[task] = tuple(
                Metric[metric_name] for metric_name in task_def["metric_meta"])
            enable_san_map[task] = task_def["enable_san"]
            if "labels" in task_def:
                labels = task_def["labels"]
                label_mapper = Vocabulary(True)
                for label in labels:
                    label_mapper.add(label)
                global_map[task] = label_mapper
            if "dropout_p" in task_def:
                dropout_p_map[task] = task_def["dropout_p"]
            split_names = task_def.get("split_names", ["train", "dev", "test"])
            split_names_map[task] = split_names

        self.tasks = tasks
        self.global_map = global_map
        self.n_class_map = n_class_map
        self.data_format_map = data_format_map
        self.data_type_map = data_type_map
        self.task_type_map = task_type_map
        self.metric_meta_map = metric_meta_map
        self.enable_san_map = enable_san_map
        self.dropout_p_map = dropout_p_map
        self.split_names_map = split_names_map
Example #4
    def __init__(self, task_def_path):
        self._task_def_dic = yaml.safe_load(open(task_def_path))
        global_map = {}
        n_class_map = {}
        data_type_map = {}
        task_type_map = {}
        metric_meta_map = {}
        enable_san_map = {}
        dropout_p_map = {}
        encoderType_map = {}
        uniq_encoderType = set()
        for task, task_def in self._task_def_dic.items():
            assert "_" not in task, "task name should not contain '_', current task name: %s" % task
            n_class_map[task] = task_def["n_class"]
            data_format = DataFormat[task_def["data_format"]]
            if data_format == DataFormat.PremiseOnly:
                data_type_map[task] = 1
            elif data_format in (DataFormat.PremiseAndMultiHypothesis,
                                 DataFormat.PremiseAndOneHypothesis):
                data_type_map[task] = 0
            else:
                raise ValueError(data_format)
            task_type_map[task] = TaskType[task_def["task_type"]]
            metric_meta_map[task] = tuple(
                Metric[metric_name] for metric_name in task_def["metric_meta"])
            enable_san_map[task] = task_def["enable_san"]
            uniq_encoderType.add(EncoderModelType[task_def["encoder_type"]])
            if "labels" in task_def:
                labels = task_def["labels"]
                label_mapper = Vocabulary(True)
                for label in labels:
                    label_mapper.add(label)
                global_map[task] = label_mapper
            if "dropout_p" in task_def:
                dropout_p_map[task] = task_def["dropout_p"]

        assert len(
            uniq_encoderType) == 1, 'The shared encoder has to be the same.'
        self.global_map = global_map
        self.n_class_map = n_class_map
        self.data_type_map = data_type_map
        self.task_type_map = task_type_map
        self.metric_meta_map = metric_meta_map
        self.enable_san_map = enable_san_map
        self.dropout_p_map = dropout_p_map
        self.encoderType = uniq_encoderType.pop()
Example #5
def main(args):
    ## hyper param
    do_lower_case = args.do_lower_case
    root = args.root_dir
    assert os.path.exists(root)
    is_uncased = False
    if 'uncased' in args.bert_model:
        is_uncased = True

    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=do_lower_case)

    mt_dnn_suffix = 'mt_dnn'
    if is_uncased:
        mt_dnn_suffix = '{}_uncased'.format(mt_dnn_suffix)
    else:
        mt_dnn_suffix = '{}_cased'.format(mt_dnn_suffix)

    if do_lower_case:
        mt_dnn_suffix = '{}_lower'.format(mt_dnn_suffix)

    mt_dnn_root = os.path.join(root, mt_dnn_suffix)
    if not os.path.isdir(mt_dnn_root):
        os.mkdir(mt_dnn_root)

    task_def_dic = yaml.safe_load(open(args.task_def))

    for task, task_def in task_def_dic.items():
        logger.info("Task %s" % task)
        data_format = DataFormat[task_def["data_format"]]
        task_type = TaskType[task_def["task_type"]]
        label_mapper = None
        if "labels" in task_def:
            labels = task_def["labels"]
            label_mapper = Vocabulary(True)
            for label in labels:
                label_mapper.add(label)
        split_names = task_def.get("split_names", ["train", "dev", "test"])
        for split_name in split_names:
            rows = load_data(os.path.join(root, "%s_%s.tsv" % (task, split_name)), data_format, task_type, label_mapper)
            dump_path = os.path.join(mt_dnn_root, "%s_%s.json" % (task, split_name))
            logger.info(dump_path)
            build_data(rows, dump_path, bert_tokenizer, data_format)
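
A minimal sketch of how this entry point might be wired up. The flag names mirror the attributes read in main() (task_def, root_dir, bert_model, do_lower_case); the defaults are illustrative placeholders.

import argparse

def build_parser():
    # Flag names follow the attributes accessed in main(); defaults are placeholders.
    parser = argparse.ArgumentParser(description="Tokenize task data into MT-DNN JSON format.")
    parser.add_argument("--task_def", type=str, default="task_def.yml")
    parser.add_argument("--root_dir", type=str, default="data")
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased")
    parser.add_argument("--do_lower_case", action="store_true")
    return parser

if __name__ == "__main__":
    main(build_parser().parse_args())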
Example #6
def load_dict(path):
    vocab = Vocabulary(neat=True)
    vocab.add('<s>')
    vocab.add('<pad>')
    vocab.add('</s>')
    vocab.add('<unk>')
    with open(path, 'r', encoding='utf8') as reader:
        for line in reader:
            idx = line.rfind(' ')
            if idx == -1:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt>'")
            word = line[:idx]
            vocab.add(word)
    return vocab
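
For illustration, a tiny dictionary file in the expected '<token> <cnt>' format and a round trip through load_dict; the file name and counts are made up.

# Write a small dictionary file in the '<token> <cnt>' format load_dict expects.
with open("dict.txt", "w", encoding="utf8") as writer:
    writer.write("the 1061396\n")
    writer.write("of 593677\n")
    writer.write("and 416629\n")

vocab = load_dict("dict.txt")
# The four special tokens (<s>, <pad>, </s>, <unk>) are added first, then the words.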
Example #7
# Copyright (c) Microsoft. All rights reserved.

from data_utils.vocab import Vocabulary
from data_utils.metrics import compute_acc, compute_f1, compute_mcc, compute_pearson, compute_spearman

# scitail
ScitailLabelMapper = Vocabulary(True)
ScitailLabelMapper.add('neutral')
ScitailLabelMapper.add('entails')

# label map
SNLI_LabelMapper = Vocabulary(True)
SNLI_LabelMapper.add('contradiction')
SNLI_LabelMapper.add('neutral')
SNLI_LabelMapper.add('entailment')

# qnli
QNLILabelMapper = Vocabulary(True)
QNLILabelMapper.add('not_entailment')
QNLILabelMapper.add('entailment')

GLOBAL_MAP = {
    'scitail': ScitailLabelMapper,
    'mnli': SNLI_LabelMapper,
    'snli': SNLI_LabelMapper,
    'qnli': QNLILabelMapper,
    'qnnli': QNLILabelMapper,
    'rte': QNLILabelMapper,
    'diag': SNLI_LabelMapper,
}
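
A hedged sketch of extending GLOBAL_MAP with further datasets. The keys 'mynli' and 'paraphrase' and their labels are hypothetical, and only the add() call demonstrated above is used.

# Reuse an existing mapper for a hypothetical dataset that shares SNLI-style labels.
GLOBAL_MAP['mynli'] = SNLI_LabelMapper

# Or build a fresh two-way mapper, following the same pattern as above.
ParaphraseLabelMapper = Vocabulary(True)
ParaphraseLabelMapper.add('not_paraphrase')
ParaphraseLabelMapper.add('paraphrase')
GLOBAL_MAP['paraphrase'] = ParaphraseLabelMapper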