Esempio n. 1
0
def build_task(xlnet_model_name, last_hidden_dropout_prob=0.0):
    """Build the ``Task`` for ``TASK_NAME`` on top of an XLNet encoder.

    The task flow runs ``XLNetModule`` over the token ids, segments and
    masks. It then extracts a feature with ``XLNetLastCLSModule`` and feeds
    that feature to an ``nn.Linear`` prediction head.

    Args:
        xlnet_model_name: Model name handed to ``XLNetModule``. Names that
            contain "base" are treated as 768-wide, all others as 1024-wide.
        last_hidden_dropout_prob: Dropout probability passed to
            ``XLNetLastCLSModule``.

    Returns:
        The assembled ``Task``.
    """
    encoder = XLNetModule(xlnet_model_name)
    # NOTE(review): hidden size is guessed from the model name (768 for
    # "*base*", 1024 otherwise) -- confirm for any non-standard checkpoint.
    hidden_dim = 768 if "base" in xlnet_model_name else 1024

    label_mapping = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    if label_mapping is None:
        # No label mapping registered: use a single output unit.
        cardinality = 1
    else:
        cardinality = len(label_mapping.keys())

    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    encoder_op = f"{TASK_NAME}_xlnet_module"
    feature_op = f"{TASK_NAME}_feature"
    head_op = f"{TASK_NAME}_pred_head"

    modules = nn.ModuleDict(
        {
            "xlnet_module": encoder,
            feature_op: XLNetLastCLSModule(dropout_prob=last_hidden_dropout_prob),
            head_op: nn.Linear(hidden_dim, cardinality),
        }
    )

    flow = [
        Operation(
            name=encoder_op,
            module_name="xlnet_module",
            inputs=[
                ("_input_", "token_ids"),
                ("_input_", "token_segments"),
                ("_input_", "token_masks"),
            ],
        ),
        # Each later stage consumes output index 0 of the previous operation.
        Operation(name=feature_op, module_name=feature_op, inputs=[(encoder_op, 0)]),
        Operation(name=head_op, module_name=head_op, inputs=[(feature_op, 0)]),
    ]

    return Task(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=flow,
        loss_func=partial(utils.ce_loss, head_op),
        output_func=partial(utils.output, head_op),
        scorer=Scorer(metrics=metric_names, custom_metric_funcs={}),
    )
Esempio n. 2
0
def build_task(xlnet_model_name, last_hidden_dropout_prob=None):
    """Build the ``Task`` for ``TASK_NAME`` with an XLNet span classifier.

    The task flow runs ``XLNetModule`` over the token ids, segments and
    masks. It then passes output index 0 of the encoder, together with the
    ``token1_idx``, ``token2_idx`` and ``token_masks`` inputs, to a
    ``SpanClassifierModule`` head.

    Args:
        xlnet_model_name: Model name handed to ``XLNetModule``. Names that
            contain "base" are treated as 768-wide, all others as 1024-wide.
        last_hidden_dropout_prob: Not supported for this task. Any truthy
            value is rejected.

    Returns:
        The assembled ``Task``.

    Raises:
        NotImplementedError: If ``last_hidden_dropout_prob`` is truthy.
    """
    if last_hidden_dropout_prob:
        raise NotImplementedError(
            f"TODO: last_hidden_dropout_prob for {TASK_NAME}")

    xlnet_module = XLNetModule(xlnet_model_name)
    # NOTE(review): hidden size is guessed from the model name (768 for
    # "*base*", 1024 otherwise) -- confirm for any non-standard checkpoint.
    xlnet_output_dim = 768 if "base" in xlnet_model_name else 1024

    # The previously computed label cardinality was never used: the span
    # classifier below is not parameterised by it, so it has been dropped.
    # NOTE(review): presumably SpanClassifierModule fixes its own number of
    # output classes -- confirm it matches SuperGLUE_LABEL_MAPPING[TASK_NAME].

    metrics = (SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
               if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING else [])

    custom_metric_funcs = {}

    pred_head_name = f"{TASK_NAME}_pred_head"
    loss_fn = partial(utils.ce_loss, pred_head_name)
    output_fn = partial(utils.output, pred_head_name)

    task = Task(
        name=TASK_NAME,
        module_pool=nn.ModuleDict({
            "xlnet_module":
            xlnet_module,
            pred_head_name:
            SpanClassifierModule(d_inp=xlnet_output_dim,
                                 proj_dim=xlnet_output_dim // 2),
        }),
        task_flow=[
            Operation(
                name=f"{TASK_NAME}_xlnet_module",
                module_name="xlnet_module",
                inputs=[
                    ("_input_", "token_ids"),
                    ("_input_", "token_segments"),
                    ("_input_", "token_masks"),
                ],
            ),
            Operation(
                name=pred_head_name,
                module_name=pred_head_name,
                inputs=[
                    (f"{TASK_NAME}_xlnet_module", 0),
                    ("_input_", "token1_idx"),
                    ("_input_", "token2_idx"),
                    ("_input_", "token_masks"),
                ],
            ),
        ],
        loss_func=loss_fn,
        output_func=output_fn,
        scorer=Scorer(metrics=metrics,
                      custom_metric_funcs=custom_metric_funcs),
    )

    return task
Esempio n. 3
0
def build_task(bert_model_name, last_hidden_dropout_prob=0.0):
    """Build the multiple-choice ``Task`` for ``TASK_NAME`` on a BERT encoder.

    Each of the four candidate inputs (``token1_ids`` .. ``token4_ids``)
    goes through the following stages:

    1. It is encoded by the shared ``BertModule``.
    2. It is reduced by the shared ``BertLastCLSModule``.
    3. It is projected to a single score by the shared ``nn.Linear``.

    A ``ChoiceModule`` then acts as the prediction head.

    Args:
        bert_model_name: Model name handed to ``BertModule``. Names that
            contain "base" are treated as 768-wide, all others as 1024-wide.
        last_hidden_dropout_prob: Dropout probability passed to
            ``BertLastCLSModule``.

    Returns:
        The assembled ``Task``.
    """
    encoder = BertModule(bert_model_name)
    # NOTE(review): hidden size is guessed from the model name (768 for
    # "*base*", 1024 otherwise) -- confirm for any non-standard checkpoint.
    hidden_dim = 768 if "base" in bert_model_name else 1024

    label_mapping = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    cardinality = 1 if label_mapping is None else len(label_mapping.keys())

    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    head_op = f"{TASK_NAME}_pred_head"

    modules = nn.ModuleDict({
        "bert_module": encoder,
        "bert_last_cls": BertLastCLSModule(dropout_prob=last_hidden_dropout_prob),
        "linear_module": nn.Linear(hidden_dim, 1),
        head_op: ChoiceModule(cardinality),
    })

    # The flow is hard-wired to four candidates. NOTE(review): presumably
    # this must agree with the cardinality handed to ChoiceModule -- confirm.
    choice_indices = range(4)

    # Stage 1: encode every candidate (choice{i} <- token{i+1}_ids).
    flow = [
        Operation(
            name=f"choice{i}",
            module_name="bert_module",
            inputs=[("_input_", f"token{i + 1}_ids")],
        )
        for i in choice_indices
    ]
    # Stage 2: reduce each encoding with the shared CLS module.
    flow += [
        Operation(
            name=f"choice{i}_bert_last_cls",
            module_name="bert_last_cls",
            inputs=[(f"choice{i}", 0)],
        )
        for i in choice_indices
    ]
    # Stage 3: score each candidate with the shared linear layer.
    flow += [
        Operation(
            name=f"choice{i}rep",
            module_name="linear_module",
            inputs=[(f"choice{i}_bert_last_cls", 0)],
        )
        for i in choice_indices
    ]
    # The head gets no explicit inputs. NOTE(review): presumably ChoiceModule
    # reads the choice{i}rep outputs from the intermediate results -- confirm.
    flow.append(Operation(name=head_op, module_name=head_op, inputs=[]))

    return Task(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=flow,
        loss_func=partial(utils.ce_loss, head_op),
        output_func=partial(utils.output, head_op),
        scorer=Scorer(metrics=metric_names, custom_metric_funcs={}),
    )
Esempio n. 4
0
def build_task(bert_model_name, last_hidden_dropout_prob=None):
    """Build the ``Task`` for ``TASK_NAME`` from BERT CLS and two-token features.

    The task flow runs ``BertModule`` over the token ids, segments and masks.
    It then builds a feature from output index 0 of the encoder plus the
    ``token1_idx`` / ``token2_idx`` inputs using
    ``BertContactLastCLSWithTwoTokensModule``. Finally it classifies that
    feature with an ``nn.Linear`` head of input width ``3 * hidden``.

    Args:
        bert_model_name: Model name handed to ``BertModule``. Names that
            contain "base" are treated as 768-wide, all others as 1024-wide.
        last_hidden_dropout_prob: Not supported for this task. Any truthy
            value is rejected.

    Returns:
        The assembled ``Task``.

    Raises:
        NotImplementedError: If ``last_hidden_dropout_prob`` is truthy.
    """
    if last_hidden_dropout_prob:
        raise NotImplementedError(
            f"TODO: last_hidden_dropout_prob for {TASK_NAME}")

    encoder = BertModule(bert_model_name)
    # NOTE(review): hidden size is guessed from the model name (768 for
    # "*base*", 1024 otherwise) -- confirm for any non-standard checkpoint.
    hidden_dim = 768 if "base" in bert_model_name else 1024

    label_mapping = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    if label_mapping is None:
        cardinality = 1
    else:
        cardinality = len(label_mapping.keys())

    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    encoder_op = f"{TASK_NAME}_bert_module"
    feature_op = f"{TASK_NAME}_feature"
    head_op = f"{TASK_NAME}_pred_head"

    modules = nn.ModuleDict({
        "bert_module": encoder,
        feature_op: BertContactLastCLSWithTwoTokensModule(),
        # Width 3 * hidden: presumably the feature module concatenates the
        # CLS vector with the two token vectors -- confirm against the module.
        head_op: nn.Linear(hidden_dim * 3, cardinality),
    })

    flow = [
        Operation(
            name=encoder_op,
            module_name="bert_module",
            inputs=[
                ("_input_", "token_ids"),
                ("_input_", "token_segments"),
                ("_input_", "token_masks"),
            ],
        ),
        Operation(
            name=feature_op,
            module_name=feature_op,
            inputs=[
                (encoder_op, 0),
                ("_input_", "token1_idx"),
                ("_input_", "token2_idx"),
            ],
        ),
        Operation(name=head_op, module_name=head_op, inputs=[(feature_op, 0)]),
    ]

    return Task(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=flow,
        loss_func=partial(utils.ce_loss, head_op),
        output_func=partial(utils.output, head_op),
        scorer=Scorer(metrics=metric_names, custom_metric_funcs={}),
    )