def build_task(bert_model_name, last_hidden_dropout_prob=0.0):
    """Assemble the ``TASK_NAME`` task with a linear head on a BERT encoder.

    The task flow has three steps: the BERT module receives the
    ``token_ids``, ``token_segments`` and ``token_masks`` input fields, its
    first output is passed through ``BertLastCLSModule``, and the first
    output of that step is passed to a linear prediction head.

    Args:
        bert_model_name: Name handed to ``BertModule``. If it contains
            ``"base"`` the head input width is 768, otherwise 1024.
        last_hidden_dropout_prob: Forwarded to ``BertLastCLSModule`` as
            ``dropout_prob``.

    Returns:
        The configured ``Task`` instance.
    """
    encoder = BertModule(bert_model_name)

    # The head input width is chosen purely from the model name.
    if "base" in bert_model_name:
        hidden_size = 768
    else:
        hidden_size = 1024

    # A label mapping of None yields a single output unit.
    label_map = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    if label_map is None:
        num_outputs = 1
    else:
        num_outputs = len(label_map)

    # Tasks without a registered metric list are scored with no metrics.
    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    # Operation names are reused below to wire each step to the previous one.
    encoder_op_name = f"{TASK_NAME}_bert_module"
    pooler_op_name = f"{TASK_NAME}_bert_last_CLS"
    head_name = f"{TASK_NAME}_pred_head"

    # Loss and output callables are bound to the prediction head's name.
    compute_loss = partial(utils.ce_loss, head_name)
    compute_output = partial(utils.output, head_name)

    # Modules are created in the same order as before (pooler, then head).
    cls_pooler = BertLastCLSModule(dropout_prob=last_hidden_dropout_prob)
    linear_head = nn.Linear(hidden_size, num_outputs)
    modules = nn.ModuleDict(
        {
            "bert_module": encoder,
            "bert_last_CLS": cls_pooler,
            head_name: linear_head,
        }
    )

    encode_step = Operation(
        name=encoder_op_name,
        module_name="bert_module",
        inputs=[
            ("_input_", "token_ids"),
            ("_input_", "token_segments"),
            ("_input_", "token_masks"),
        ],
    )
    pool_step = Operation(
        name=pooler_op_name,
        module_name="bert_last_CLS",
        inputs=[(encoder_op_name, 0)],
    )
    predict_step = Operation(
        name=head_name,
        module_name=head_name,
        inputs=[(pooler_op_name, 0)],
    )

    task_scorer = Scorer(metrics=metric_names, custom_metric_funcs={})

    return Task(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=[encode_step, pool_step, predict_step],
        loss_func=compute_loss,
        output_func=compute_output,
        scorer=task_scorer,
    )
def build_task(bert_model_name, last_hidden_dropout_prob=None):
    """Assemble the ``TASK_NAME`` task with a span-classifier head on BERT.

    The task flow has two steps: the BERT module receives the
    ``token_ids``, ``token_segments`` and ``token_masks`` input fields, and
    ``SpanClassifierModule`` receives the BERT step's first output together
    with the ``token1_idx``, ``token2_idx`` and ``token_masks`` input fields.

    Args:
        bert_model_name: Name handed to ``BertModule``. If it contains
            ``"base"`` the head input width is 768, otherwise 1024.
        last_hidden_dropout_prob: Not supported for this task; any truthy
            value is rejected.

    Returns:
        The configured ``Task`` instance.

    Raises:
        NotImplementedError: If ``last_hidden_dropout_prob`` is truthy.
    """
    if last_hidden_dropout_prob:
        raise NotImplementedError(
            f"TODO: last_hidden_dropout_prob for {TASK_NAME}"
        )

    encoder = BertModule(bert_model_name)

    # The head input width is chosen purely from the model name.
    if "base" in bert_model_name:
        hidden_size = 768
    else:
        hidden_size = 1024

    # NOTE(review): this count is never passed to the span head below. The
    # lookup is kept so a missing TASK_NAME entry still raises here.
    label_map = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    if label_map is None:
        num_outputs = 1
    else:
        num_outputs = len(label_map)

    # Tasks without a registered metric list are scored with no metrics.
    if TASK_NAME in SuperGLUE_TASK_METRIC_MAPPING:
        metric_names = SuperGLUE_TASK_METRIC_MAPPING[TASK_NAME]
    else:
        metric_names = []

    # Operation names are reused below to wire the head to the encoder.
    encoder_op_name = f"{TASK_NAME}_bert_module"
    head_name = f"{TASK_NAME}_pred_head"

    # Loss and output callables are bound to the prediction head's name.
    compute_loss = partial(utils.ce_loss, head_name)
    compute_output = partial(utils.output, head_name)

    # The projection width is half the encoder width (integer division).
    span_head = SpanClassifierModule(
        d_inp=hidden_size,
        proj_dim=hidden_size // 2,
    )
    modules = nn.ModuleDict(
        {
            "bert_module": encoder,
            head_name: span_head,
        }
    )

    encode_step = Operation(
        name=encoder_op_name,
        module_name="bert_module",
        inputs=[
            ("_input_", "token_ids"),
            ("_input_", "token_segments"),
            ("_input_", "token_masks"),
        ],
    )
    predict_step = Operation(
        name=head_name,
        module_name=head_name,
        inputs=[
            (encoder_op_name, 0),
            ("_input_", "token1_idx"),
            ("_input_", "token2_idx"),
            ("_input_", "token_masks"),
        ],
    )

    task_scorer = Scorer(metrics=metric_names, custom_metric_funcs={})

    return Task(
        name=TASK_NAME,
        module_pool=modules,
        task_flow=[encode_step, predict_step],
        loss_func=compute_loss,
        output_func=compute_output,
        scorer=task_scorer,
    )