Example #1
0
def build_model(bert_model_name, last_hidden_dropout_prob=0.0):
    """Assemble the EmmentalTask for the module-level ``TASK_NAME``.

    The task runs three steps in sequence: ``BertModule`` on the token ids,
    segments and masks of the input, ``BertLastCLSModule`` on the first
    output of that step, and a linear prediction head on the first output
    of the feature step.

    Args:
        bert_model_name: Name handed to ``BertModule``. If it contains
            ``"base"`` the head input width is 768, otherwise 1024.
        last_hidden_dropout_prob: Passed to ``BertLastCLSModule`` as
            ``dropout_prob``.

    Returns:
        The configured ``EmmentalTask``.
    """
    encoder = BertModule(bert_model_name)
    hidden_size = 1024
    if "base" in bert_model_name:
        hidden_size = 768

    # A ``None`` label mapping yields a single-output head.
    label_map = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    num_classes = 1 if label_map is None else len(label_map.keys())

    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    # Names reused as module-pool keys and/or task-flow step names.
    encoder_step = f"{TASK_NAME}_bert_module"
    feature_key = f"{TASK_NAME}_feature"
    head_key = f"{TASK_NAME}_pred_head"

    # Loss and output helpers are both bound to the prediction head's name.
    head_loss = partial(utils.ce_loss, head_key)
    head_output = partial(utils.output, head_key)

    modules = nn.ModuleDict(
        {
            "bert_module": encoder,
            feature_key: BertLastCLSModule(dropout_prob=last_hidden_dropout_prob),
            head_key: nn.Linear(hidden_size, num_classes),
        }
    )

    encoder_inputs = [
        ("_input_", field)
        for field in ("token_ids", "token_segments", "token_masks")
    ]
    flow = [
        {"name": encoder_step, "module": "bert_module", "inputs": encoder_inputs},
        {"name": feature_key, "module": feature_key, "inputs": [(encoder_step, 0)]},
        {"name": head_key, "module": head_key, "inputs": [(feature_key, 0)]},
    ]

    # No custom metric functions are registered for this task.
    scorer = Scorer(metrics=metric_names, customize_metric_funcs={})

    return EmmentalTask(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=flow,
        loss_func=head_loss,
        output_func=head_output,
        scorer=scorer,
    )
Example #2
0
def get_superglue_task(task_names, bert_model_name):
    """Build Emmental tasks for the requested SuperGLUE task names.

    Only ``"MultiRC"`` currently produces a task; any other name is looked
    up in the label mapping and then left out of the result. Every task
    that is built references the single ``BertModule`` created here.

    Args:
        task_names: Iterable of task names to consider.
        bert_model_name: Name handed to ``BertModule``. If it contains
            ``"base"`` the head input width is 768, otherwise 1024.

    Returns:
        Dict mapping task name to its ``EmmentalTask``.
    """
    built = {}

    encoder = BertModule(bert_model_name)
    hidden_size = 1024
    if "base" in bert_model_name:
        hidden_size = 768

    for name in task_names:
        # The label lookup runs for every requested name, including the
        # ones skipped just below.
        label_map = SuperGLUE_LABEL_MAPPING[name]
        num_classes = 1 if label_map is None else len(label_map.keys())

        if name != "MultiRC":
            continue

        if name in SuperGLUE_TASK_METRIC_MAPPING:
            metric_names = SuperGLUE_TASK_METRIC_MAPPING[name]
        else:
            metric_names = []

        encoder_step = f"{name}_bert_module"
        cls_step = f"{name}_bert_last_CLS"
        head_key = f"{name}_pred_head"

        modules = nn.ModuleDict(
            {
                "bert_module": encoder,
                "bert_last_CLS": BertLastCLSModule(),
                head_key: nn.Linear(hidden_size, num_classes),
            }
        )

        # Unlike the module-pool keys above, step names carry the task prefix.
        flow = [
            {
                "name": encoder_step,
                "module": "bert_module",
                "inputs": [("_input_", "token_ids")],
            },
            {
                "name": cls_step,
                "module": "bert_last_CLS",
                "inputs": [(encoder_step, 0)],
            },
            {
                "name": head_key,
                "module": head_key,
                "inputs": [(cls_step, 0)],
            },
        ]

        # MultiRC is scored with the two custom functions in addition to
        # whatever metrics the mapping lists.
        scorer = Scorer(
            metrics=metric_names,
            customize_metric_funcs={"em": em, "em_f1": em_f1},
        )

        built[name] = EmmentalTask(
            name=name,
            module_pool=modules,
            task_flow=flow,
            loss_func=partial(ce_loss, head_key),
            output_func=partial(output, head_key),
            scorer=scorer,
        )

    return built
Example #3
0
def build_model(bert_model_name, last_hidden_dropout_prob=0.0):
    """Assemble the multiple-choice EmmentalTask for ``TASK_NAME``.

    The flow has four steps: a ``MultipleChoiceModule`` declared with no
    flow inputs, ``BertModule`` fed by outputs 0, 1 and 2 of that step,
    ``BertLastCLSModule`` on the encoder's first output, and a
    single-output linear head on the feature step's first output. Loss and
    output helpers are bound to the head's name and ``NUM_CHOICES``.

    Args:
        bert_model_name: Name handed to ``BertModule``. If it contains
            ``"base"`` the head input width is 768, otherwise 1024.
        last_hidden_dropout_prob: Passed to ``BertLastCLSModule`` as
            ``dropout_prob``.

    Returns:
        The configured ``EmmentalTask``.
    """
    encoder = BertModule(bert_model_name)
    hidden_size = 1024
    if "base" in bert_model_name:
        hidden_size = 768

    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    # Names reused as module-pool keys and/or task-flow step names.
    choice_key = f"{TASK_NAME}_multiple_choice_module"
    encoder_step = f"{TASK_NAME}_bert_module"
    feature_key = f"{TASK_NAME}_feature"
    head_key = f"{TASK_NAME}_pred_head"

    head_loss = partial(utils.ce_loss_multiple_choice, head_key, NUM_CHOICES)
    head_output = partial(utils.output_multiple_choice, head_key, NUM_CHOICES)

    modules = nn.ModuleDict(
        {
            choice_key: MultipleChoiceModule(NUM_CHOICES),
            "bert_module": encoder,
            feature_key: BertLastCLSModule(dropout_prob=last_hidden_dropout_prob),
            # The head emits one value per encoded row.
            head_key: nn.Linear(hidden_size, 1),
        }
    )

    flow = [
        {"name": choice_key, "module": choice_key, "inputs": []},
        {
            "name": encoder_step,
            "module": "bert_module",
            # The encoder consumes the first three outputs of the choice step.
            "inputs": [(choice_key, index) for index in range(3)],
        },
        {"name": feature_key, "module": feature_key, "inputs": [(encoder_step, 0)]},
        {"name": head_key, "module": head_key, "inputs": [(feature_key, 0)]},
    ]

    # No custom metric functions are registered for this task.
    scorer = Scorer(metrics=metric_names, customize_metric_funcs={})

    return EmmentalTask(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=flow,
        loss_func=head_loss,
        output_func=head_output,
        scorer=scorer,
    )