Code example #1
    def __init__(
            self,
            name,
            input_module=IdentityModule(),
            middle_module=IdentityModule(),
            attention_module=IdentityModule(),
            head_module=IdentityModule(),
            output_hat_func=(lambda X: X["data"]),
            # Note: no sigmoid (target labels can be in any range)
            loss_hat_func=(
                lambda X, Y: F.mse_loss(X["data"].view(-1), Y.view(-1))),
            loss_multiplier=1.0,
            scorer=Scorer(standard_metrics=[]),
    ) -> None:
        """Construct a regression task.

        By default the raw head output is used as the prediction
        (no squashing nonlinearity) and the loss is a flattened MSE
        between predictions and gold targets. All modules default to
        identity pass-throughs; no standard metrics are scored.
        """
        # Zero-argument super() resolves to the same MRO entry as
        # super(RegressionTask, self) in Python 3; all wiring is
        # delegated to the base Task constructor unchanged.
        super().__init__(
            name, input_module, middle_module, attention_module,
            head_module, output_hat_func, loss_hat_func,
            loss_multiplier, scorer,
        )
Code example #2
    def __init__(
        self,
        name,
        input_module=IdentityModule(),
        middle_module=IdentityModule(),
        attention_module=IdentityModule(),
        head_module=IdentityModule(),
        output_hat_func=output_hat_func,
        loss_hat_func=categorical_cross_entropy,
        loss_multiplier=1.0,
        scorer=Scorer(standard_metrics=["accuracy"]),
        slice_head_type=None,
    ) -> None:
        """Construct a binary classification task.

        The head (when provided and not an identity) must emit a single
        logit, i.e. ``out_features == 1``.

        Args:
            slice_head_type: optional marker for slice-based heads; one
                of "ind" (indicator), "pred" (predictor), or None.

        Raises:
            ValueError: if the head's output dimension is not 1, or if
                ``slice_head_type`` is not one of the allowed values.
        """
        if (head_module and not isinstance(head_module, IdentityModule)
                and head_module.out_features != 1):
            raise ValueError(
                f"{self.__class__.__name__} must have an output dim 1.")

        super(BinaryClassificationTask, self).__init__(
            name,
            input_module,
            middle_module,
            attention_module,
            head_module,
            output_hat_func,
            loss_hat_func,
            loss_multiplier,
            scorer,
        )

        # Validate with an explicit raise rather than `assert`: asserts
        # are stripped under `python -O`, which would silently accept a
        # bad slice_head_type.
        if slice_head_type not in ("ind", "pred", None):
            raise ValueError(
                f"slice_head_type must be 'ind', 'pred', or None, "
                f"not {slice_head_type!r}.")
        # Additional attribute indicating the head type for slice tasks
        self.slice_head_type = slice_head_type
Code example #3
    def __init__(
            self,
            name,
            input_module=IdentityModule(),
            middle_module=IdentityModule(),
            attention_module=IdentityModule(),
            head_module=IdentityModule(),
            output_hat_func=(lambda X: F.softmax(X["data"], dim=1)),
            loss_hat_func=(
                lambda X, Y: F.cross_entropy(X["data"],
                                             Y.view(-1) - 1)),
            loss_multiplier=1.0,
            scorer=Scorer(standard_metrics=["accuracy"]),
    ) -> None:
        """Construct a (multi-class) classification task.

        The default prediction function applies a softmax over the head
        output; the default loss is cross-entropy, with gold labels
        shifted down by one (``Y.view(-1) - 1``), i.e. labels are
        expected to be 1-indexed. Accuracy is scored by default.
        """
        # Delegate construction to the base Task; zero-argument super()
        # is equivalent to super(ClassificationTask, self) in Python 3.
        super().__init__(
            name, input_module, middle_module, attention_module,
            head_module, output_hat_func, loss_hat_func,
            loss_multiplier, scorer,
        )
Code example #4
    def __init__(
        self,
        name,
        input_module=IdentityModule(),
        middle_module=IdentityModule(),
        head_module=IdentityModule(),
        output_hat_func=tokenwise_softmax,
        loss_hat_func=tokenwise_ce_loss,
        loss_multiplier=1.0,
        scorer=Scorer(custom_metric_funcs={tokenwise_accuracy: ["token_acc"]}),
    ) -> None:
        """Construct a token-level classification task.

        Defaults to token-wise softmax predictions, token-wise
        cross-entropy loss, and a token-accuracy custom metric. Note
        this variant takes no attention module.
        """
        # Collect the positional arguments once, then forward them
        # unchanged to the base constructor.
        base_args = (
            name,
            input_module,
            middle_module,
            head_module,
            output_hat_func,
            loss_hat_func,
            loss_multiplier,
            scorer,
        )
        super().__init__(*base_args)
Code example #5
File: glue_tasks.py — Project: swagnercarena/metal
def create_glue_tasks_payloads(task_names, skip_payloads=False, **kwargs):
    """Build MeTaL Task and Payload objects for the requested GLUE tasks.

    Args:
        task_names: non-empty list of task names; primary GLUE tasks
            (e.g. "COLA", "MNLI", "STSB") and/or auxiliary tasks
            (e.g. "THIRD", "BLEU", "SPACY_NER", "SPACY_POS").
        skip_payloads: if True, build only the Task objects and skip
            dataset loading / DataLoader / Payload creation.
        **kwargs: overrides merged recursively into ``task_defaults``.

    Returns:
        (tasks, payloads): the constructed Task objects (plus one slice
        task per configured slice), and one Payload per split of each
        primary task (empty when ``skip_payloads`` is True).
    """
    assert len(task_names) > 0

    config = recursive_merge_dicts(task_defaults, kwargs)

    # Fall back to a random seed, printing it so the run can be reproduced.
    if config["seed"] is None:
        config["seed"] = np.random.randint(1e6)
        print(f"Using random seed: {config['seed']}")
    set_seed(config["seed"])

    # share bert encoder for all tasks

    if config["encoder_type"] == "bert":
        bert_kwargs = config["bert_kwargs"]
        bert_model = BertRaw(config["bert_model"], **bert_kwargs)
        # NOTE(review): neck_dim is only assigned when the model name
        # contains "base" or "large"; any other bert_model string would
        # hit a NameError further down — confirm that is intended.
        if "base" in config["bert_model"]:
            neck_dim = 768
        elif "large" in config["bert_model"]:
            neck_dim = 1024
        input_module = bert_model
        pooler = bert_model.pooler if bert_kwargs["pooler"] else None
        cls_middle_module = BertExtractCls(pooler=pooler,
                                           dropout=config["dropout"])
    else:
        raise NotImplementedError

    # Create dict override dl_kwarg for specific task
    # e.g. {"STSB": {"batch_size": 2}}
    task_dl_kwargs = {}
    if config["task_dl_kwargs"]:
        # Parse "TASK.key.value[,TASK.key.value...]" triples.
        # NOTE(review): the comprehension variable shadows the outer
        # `config` dict — harmless in Python 3 (comprehensions have their
        # own scope) but confusing to read.
        task_configs_str = [
            tuple(config.split("."))
            for config in config["task_dl_kwargs"].split(",")
        ]
        for (task_name, kwarg_key, kwarg_val) in task_configs_str:
            # Values arrive as strings; only batch_size is coerced to int.
            if kwarg_key == "batch_size":
                kwarg_val = int(kwarg_val)
            task_dl_kwargs[task_name] = {kwarg_key: kwarg_val}

    tasks = []
    payloads = []
    for task_name in task_names:
        # If a flag is specified for attention, use it, otherwise use identity module
        if config["attention"]:
            print("Using soft attention head")
            attention_module = SoftAttentionModule(neck_dim)
        else:
            attention_module = IdentityModule()

        # Pull out names of auxiliary tasks to be dealt with in a second step
        # TODO: fix this logic for cases where auxiliary task for task_name has
        # its own payload
        has_payload = task_name not in config["auxiliary_task_dict"]

        # Note whether this task has auxiliary tasks that apply to it and require spacy
        run_spacy = False
        for aux_task, target_payloads in config["auxiliary_task_dict"].items():
            run_spacy = run_spacy or (task_name in target_payloads
                                      and aux_task in SPACY_TASKS
                                      and aux_task in task_names)

        # Override general dl kwargs with task-specific kwargs
        dl_kwargs = copy.deepcopy(config["dl_kwargs"])
        if task_name in task_dl_kwargs:
            dl_kwargs.update(task_dl_kwargs[task_name])

        # Each primary task has data_loaders to load
        if has_payload and not skip_payloads:
            # Preprocessed datasets are loaded from disk; otherwise they
            # are built from raw GLUE data.
            if config["preprocessed"]:
                datasets = load_glue_datasets(
                    dataset_name=task_name,
                    splits=config["splits"],
                    bert_vocab=config["bert_model"],
                    max_len=config["max_len"],
                    max_datapoints=config["max_datapoints"],
                    run_spacy=run_spacy,
                    verbose=True,
                )
            else:
                datasets = create_glue_datasets(
                    dataset_name=task_name,
                    splits=config["splits"],
                    bert_vocab=config["bert_model"],
                    max_len=config["max_len"],
                    max_datapoints=config["max_datapoints"],
                    generate_uids=kwargs.get("generate_uids", False),
                    run_spacy=run_spacy,
                    verbose=True,
                )
            # Wrap datasets with DataLoader objects
            data_loaders = create_glue_dataloaders(
                datasets,
                dl_kwargs=dl_kwargs,
                split_prop=config["split_prop"],
                splits=config["splits"],
                seed=config["seed"],
            )

        # PRIMARY TASKS: one branch per GLUE task, each wiring the shared
        # encoder to a task-specific head and scorer.
        if task_name == "COLA":
            scorer = Scorer(
                standard_metrics=["accuracy"],
                custom_metric_funcs={matthews_corr: ["matthews_corr"]},
            )
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
                scorer=scorer,
            )

        elif task_name == "SST2":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
            )

        elif task_name == "MNLI":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=MulticlassHead(neck_dim, 3),
                scorer=Scorer(standard_metrics=["accuracy"]),
            )

        elif task_name == "SNLI":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=MulticlassHead(neck_dim, 3),
                scorer=Scorer(standard_metrics=["accuracy"]),
            )

        elif task_name == "RTE":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
                scorer=Scorer(standard_metrics=["accuracy"]),
            )

        elif task_name == "WNLI":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
                scorer=Scorer(standard_metrics=["accuracy"]),
            )

        elif task_name == "QQP":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
                scorer=Scorer(
                    custom_metric_funcs={acc_f1: ["accuracy", "f1", "acc_f1"]
                                         }),
            )

        elif task_name == "MRPC":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
                scorer=Scorer(
                    custom_metric_funcs={acc_f1: ["accuracy", "f1", "acc_f1"]
                                         }),
            )

        elif task_name == "STSB":
            # STSB is a regression task scored by Pearson/Spearman
            # correlations rather than accuracy.
            scorer = Scorer(
                standard_metrics=[],
                custom_metric_funcs={
                    pearson_spearman: [
                        "pearson_corr",
                        "spearman_corr",
                        "pearson_spearman",
                    ]
                },
            )

            task = RegressionTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=RegressionHead(neck_dim),
                scorer=scorer,
            )

        elif task_name == "QNLI":
            task = ClassificationTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=BinaryHead(neck_dim),
                scorer=Scorer(standard_metrics=["accuracy"]),
            )

        # AUXILIARY TASKS

        elif task_name == "THIRD":
            # A toy task that predict which third of the sentence each token is in
            OUT_DIM = 3
            task = TokenClassificationTask(
                name="THIRD",
                input_module=input_module,
                attention_module=attention_module,
                head_module=BertTokenClassificationHead(neck_dim, OUT_DIM),
                loss_multiplier=config["auxiliary_loss_multiplier"],
            )

        elif task_name == "BLEU":
            # Sigmoid squashes predictions into [0, 1] to match the
            # (normalized) BLEU target range.
            task = RegressionTask(
                name=task_name,
                input_module=input_module,
                middle_module=cls_middle_module,
                attention_module=attention_module,
                head_module=RegressionHead(neck_dim),
                output_hat_func=torch.sigmoid,
                loss_hat_func=(lambda out, Y_gold: F.mse_loss(
                    torch.sigmoid(out), Y_gold)),
                scorer=Scorer(custom_metric_funcs={mse: ["mse"]}),
                loss_multiplier=config["auxiliary_loss_multiplier"],
            )

        elif task_name == "SPACY_NER":
            OUT_DIM = len(SPACY_TAGS["SPACY_NER"])
            task = TokenClassificationTask(
                name=task_name,
                input_module=input_module,
                attention_module=attention_module,
                head_module=BertTokenClassificationHead(neck_dim, OUT_DIM),
                loss_multiplier=config["auxiliary_loss_multiplier"],
            )

        elif task_name == "SPACY_POS":
            OUT_DIM = len(SPACY_TAGS["SPACY_POS"])
            task = TokenClassificationTask(
                name=task_name,
                input_module=input_module,
                attention_module=attention_module,
                head_module=BertTokenClassificationHead(neck_dim, OUT_DIM),
                loss_multiplier=config["auxiliary_loss_multiplier"],
            )

        else:
            msg = (f"Task name {task_name} was not recognized as a primary or "
                   f"auxiliary task.")
            raise Exception(msg)

        tasks.append(task)

        # Gather slice names
        slice_names = (config["slice_dict"].get(task_name, [])
                       if config["slice_dict"] else [])

        # Add a task for each slice
        for slice_name in slice_names:
            slice_task_name = f"{task_name}_slice:{slice_name}"
            slice_task = create_slice_task(task, slice_task_name)
            tasks.append(slice_task)

        if has_payload and not skip_payloads:
            # Create payloads (and add slices/auxiliary tasks as applicable)
            # `data_loaders` is guaranteed bound here: it was created above
            # under this same guard.
            for split, data_loader in data_loaders.items():
                payload_name = f"{task_name}_{split}"
                labels_to_tasks = {f"{task_name}_gold": task_name}
                payload = Payload(payload_name, data_loader, labels_to_tasks,
                                  split)

                # Add auxiliary label sets if applicable
                auxiliary_task_dict = config["auxiliary_task_dict"]
                for aux_task_name, target_payloads in auxiliary_task_dict.items(
                ):
                    # Only attach labels for auxiliary tasks that were both
                    # requested and targeted at this payload.
                    if aux_task_name in task_names and task_name in target_payloads:
                        aux_task_func = auxiliary_task_functions[aux_task_name]
                        payload = aux_task_func(payload)

                # Add a labelset slice to each split
                dataset = payload.data_loader.dataset
                for slice_name in slice_names:
                    slice_task_name = f"{task_name}_slice:{slice_name}"
                    slice_labels = create_slice_labels(
                        dataset,
                        base_task_name=task_name,
                        slice_name=slice_name)
                    labelset_slice_name = f"{task_name}_slice:{slice_name}"
                    payload.add_label_set(slice_task_name, labelset_slice_name,
                                          slice_labels)

                payloads.append(payload)

    return tasks, payloads