def build_model(bert_model_name, last_hidden_dropout_prob=0.0):
    """Assemble the single-head ``EmmentalTask`` for ``TASK_NAME``.

    The task flow runs three steps in sequence: the shared ``bert_module``,
    the ``{TASK_NAME}_feature`` module, and the ``{TASK_NAME}_pred_head``
    linear layer. Each step consumes output ``0`` of the previous step.

    Args:
        bert_model_name: Name handed to ``BertModule``. When it contains
            ``"base"`` the prediction head is given 768 input features,
            otherwise 1024.
        last_hidden_dropout_prob: Forwarded as ``dropout_prob`` to
            ``BertLastCLSModule``.

    Returns:
        The configured ``EmmentalTask``.
    """
    encoder = BertModule(bert_model_name)
    hidden_size = 768 if "base" in bert_model_name else 1024

    # A ``None`` label mapping means a single output unit.
    label_map = SuperGLUE_LABEL_MAPPING[TASK_NAME]
    if label_map is None:
        num_outputs = 1
    else:
        num_outputs = len(label_map)

    metric_names = SuperGLUE_TASK_METRIC_MAPPING.get(TASK_NAME, [])
    extra_metric_funcs = {}

    # Step / module names used throughout the task definition.
    bert_step = f"{TASK_NAME}_bert_module"
    feature_key = f"{TASK_NAME}_feature"
    head_key = f"{TASK_NAME}_pred_head"

    # Loss and output functions are bound to the prediction-head step name.
    compute_loss = partial(utils.ce_loss, head_key)
    compute_output = partial(utils.output, head_key)

    module_pool = nn.ModuleDict(
        {
            "bert_module": encoder,
            feature_key: BertLastCLSModule(dropout_prob=last_hidden_dropout_prob),
            head_key: nn.Linear(hidden_size, num_outputs),
        }
    )

    task_flow = [
        {
            "name": bert_step,
            "module": "bert_module",
            "inputs": [
                ("_input_", "token_ids"),
                ("_input_", "token_segments"),
                ("_input_", "token_masks"),
            ],
        },
        {
            "name": feature_key,
            "module": feature_key,
            "inputs": [(bert_step, 0)],
        },
        {
            "name": head_key,
            "module": head_key,
            "inputs": [(feature_key, 0)],
        },
    ]

    scorer = Scorer(metrics=metric_names, customize_metric_funcs=extra_metric_funcs)

    return EmmentalTask(
        name=TASK_NAME,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=compute_loss,
        output_func=compute_output,
        scorer=scorer,
    )
def get_superglue_task(task_names, bert_model_name):
    """Build the ``EmmentalTask`` objects for the requested SuperGLUE tasks.

    A single ``BertModule`` instance is created up front and placed in the
    module pool of every task built here, so all tasks share it.

    Only ``"MultiRC"`` has a task definition in this function. Any other
    name in ``task_names`` is skipped and gets no entry in the result; a
    task is never registered under a name it was not built for.

    Args:
        task_names: Iterable of task names. Each name must be a key of
            ``SuperGLUE_LABEL_MAPPING``, which is indexed directly.
        bert_model_name: Name handed to ``BertModule``. When it contains
            ``"base"`` the prediction head is given 768 input features,
            otherwise 1024.

    Returns:
        Dict mapping task name to its ``EmmentalTask``.

    Raises:
        KeyError: If a task name is missing from ``SuperGLUE_LABEL_MAPPING``.
    """
    tasks = {}

    bert_module = BertModule(bert_model_name)
    bert_output_dim = 768 if "base" in bert_model_name else 1024

    for task_name in task_names:
        # A ``None`` label mapping means a single output unit.
        label_map = SuperGLUE_LABEL_MAPPING[task_name]
        task_cardinality = len(label_map) if label_map is not None else 1

        metrics = SuperGLUE_TASK_METRIC_MAPPING.get(task_name, [])
        customize_metric_funcs = (
            {"em": em, "em_f1": em_f1} if task_name == "MultiRC" else {}
        )

        # Loss and output functions are bound to the prediction-head step name.
        loss_fn = partial(ce_loss, f"{task_name}_pred_head")
        output_fn = partial(output, f"{task_name}_pred_head")

        if task_name == "MultiRC":
            # Register inside the branch that builds the task, so an
            # unsupported name can never pick up an unbound or stale ``task``
            # left over from a previous iteration.
            tasks[task_name] = EmmentalTask(
                name=task_name,
                module_pool=nn.ModuleDict(
                    {
                        "bert_module": bert_module,
                        "bert_last_CLS": BertLastCLSModule(),
                        f"{task_name}_pred_head": nn.Linear(
                            bert_output_dim, task_cardinality
                        ),
                    }
                ),
                task_flow=[
                    {
                        "name": f"{task_name}_bert_module",
                        "module": "bert_module",
                        "inputs": [("_input_", "token_ids")],
                    },
                    {
                        "name": f"{task_name}_bert_last_CLS",
                        "module": "bert_last_CLS",
                        "inputs": [(f"{task_name}_bert_module", 0)],
                    },
                    {
                        "name": f"{task_name}_pred_head",
                        "module": f"{task_name}_pred_head",
                        "inputs": [(f"{task_name}_bert_last_CLS", 0)],
                    },
                ],
                loss_func=loss_fn,
                output_func=output_fn,
                scorer=Scorer(
                    metrics=metrics,
                    customize_metric_funcs=customize_metric_funcs,
                ),
            )

    return tasks
def build_model(bert_model_name, last_hidden_dropout_prob=0.0):
    """Assemble the multiple-choice ``EmmentalTask`` for ``TASK_NAME``.

    The task flow runs four steps in sequence:

    1. ``{TASK_NAME}_multiple_choice_module`` - a
       ``MultipleChoiceModule(NUM_CHOICES)`` declared with no flow inputs.
    2. ``{TASK_NAME}_bert_module`` - the ``BertModule``, fed outputs ``0``,
       ``1`` and ``2`` of the multiple-choice step.
    3. ``{TASK_NAME}_feature`` - a ``BertLastCLSModule`` on output ``0`` of
       the BERT step.
    4. ``{TASK_NAME}_pred_head`` - a linear layer with one output unit.

    The loss and output functions are bound to the prediction-head step name
    and ``NUM_CHOICES``.

    Args:
        bert_model_name: Name handed to ``BertModule``. When it contains
            ``"base"`` the prediction head is given 768 input features,
            otherwise 1024.
        last_hidden_dropout_prob: Forwarded as ``dropout_prob`` to
            ``BertLastCLSModule``.

    Returns:
        The configured ``EmmentalTask``.
    """
    encoder = BertModule(bert_model_name)
    hidden_size = 768 if "base" in bert_model_name else 1024

    metric_names = SuperGLUE_TASK_METRIC_MAPPING.get(TASK_NAME, [])
    extra_metric_funcs = {}

    # Step / module names used throughout the task definition.
    choice_key = f"{TASK_NAME}_multiple_choice_module"
    bert_step = f"{TASK_NAME}_bert_module"
    feature_key = f"{TASK_NAME}_feature"
    head_key = f"{TASK_NAME}_pred_head"

    compute_loss = partial(utils.ce_loss_multiple_choice, head_key, NUM_CHOICES)
    compute_output = partial(utils.output_multiple_choice, head_key, NUM_CHOICES)

    module_pool = nn.ModuleDict(
        {
            choice_key: MultipleChoiceModule(NUM_CHOICES),
            "bert_module": encoder,
            feature_key: BertLastCLSModule(dropout_prob=last_hidden_dropout_prob),
            head_key: nn.Linear(hidden_size, 1),
        }
    )

    task_flow = [
        {
            "name": choice_key,
            "module": choice_key,
            "inputs": [],
        },
        {
            "name": bert_step,
            "module": "bert_module",
            # The first three outputs of the multiple-choice step feed BERT.
            "inputs": [(choice_key, idx) for idx in range(3)],
        },
        {
            "name": feature_key,
            "module": feature_key,
            "inputs": [(bert_step, 0)],
        },
        {
            "name": head_key,
            "module": head_key,
            "inputs": [(feature_key, 0)],
        },
    ]

    scorer = Scorer(metrics=metric_names, customize_metric_funcs=extra_metric_funcs)

    return EmmentalTask(
        name=TASK_NAME,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=compute_loss,
        output_func=compute_output,
        scorer=scorer,
    )