Code Example #1
    def get_dataset_and_dataloader_of_heuristic(
            self,
            mode: str,
            heuristic: str,
            batch_size: int,
            random: bool) -> Tuple[SubsetDataset,
                                   torch.utils.data.DataLoader]:

        if mode not in ["train", "eval"]:
            raise ValueError(f"Unrecognized mode {mode}")

        if mode == "train":
            dataset = self._hans_train_dataset
        else:
            dataset = self._hans_eval_dataset

        if dataset is None:
            raise ValueError("`dataset` is None")

        indices = self.get_indices_of_heuristic(
            mode=mode, heuristic=heuristic)

        heuristic_dataset = SubsetDataset(dataset=dataset, indices=indices)
        heuristic_dataloader = misc_utils.get_dataloader(
            dataset=heuristic_dataset,
            batch_size=batch_size,
            random=random)

        return heuristic_dataset, heuristic_dataloader
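
For reference, a hypothetical call to this method, following the usage pattern
in Code Example #10 (the HansHelper construction and the HANS heuristic name
"lexical_overlap" are assumptions based on that example and the HANS dataset):

# Hypothetical usage sketch; assumes a HANS eval dataset is available
hans_helper = HansHelper(hans_train_dataset=None,
                         hans_eval_dataset=eval_dataset)
heuristic_dataset, heuristic_dataloader = \
    hans_helper.get_dataset_and_dataloader_of_heuristic(
        mode="eval",
        heuristic="lexical_overlap",  # one of the three HANS heuristics
        batch_size=1,
        random=False)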
Code Example #2
def prepare_scattered_inputs_and_indices(
    device_ids: List[int],
    dataset: torch.utils.data.Dataset,
    indices_to_include: Optional[List[int]] = None,
) -> Tuple[List[List[Any]], List[List[int]]]:
    """Scatter the data into devices"""

    indices_list = []
    inputs_collections_list = []
    instance_dataloader = misc_utils.get_dataloader(dataset=dataset,
                                                    batch_size=1)
    for index, train_inputs in enumerate(tqdm(instance_dataloader)):

        # Skip this index when a subset to include is specified
        # and the index is not part of it
        if (indices_to_include is not None
                and index not in indices_to_include):
            continue

        indices_list.append(index)
        inputs_collections_list.append(train_inputs)

    scattered_inputs, scattered_indices = scatter_inputs_and_indices(
        Xs=inputs_collections_list,
        indices=indices_list,
        device_ids=device_ids)

    return scattered_inputs, scattered_indices
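
scatter_inputs_and_indices is not shown in this example; below is a minimal
self-contained sketch of the contiguous-chunk splitting it presumably performs
(the function name, signature, and chunking strategy are all assumptions):

from typing import Any, List, Tuple

def scatter_inputs_and_indices_sketch(
        Xs: List[Any],
        indices: List[int],
        device_ids: List[int]) -> Tuple[List[List[Any]], List[List[int]]]:
    # Split the examples into one contiguous chunk per device
    num_devices = len(device_ids)
    chunk_size = (len(Xs) + num_devices - 1) // num_devices
    scattered_Xs = [Xs[i:i + chunk_size]
                    for i in range(0, len(Xs), chunk_size)]
    scattered_indices = [indices[i:i + chunk_size]
                         for i in range(0, len(indices), chunk_size)]
    return scattered_Xs, scattered_indices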
Code Example #3
def prepare_small_dataloaders(
        dataset: torch.utils.data.Dataset, random: bool, batch_size: int,
        num_datasets: int,
        num_examples_per_dataset: int) -> List[torch.utils.data.DataLoader]:
    """Only pass to child processes the data we will really use"""

    examples = []
    total_num_examples = batch_size * num_datasets * num_examples_per_dataset

    if random:
        indices = np.random.choice(
            len(dataset),
            size=total_num_examples,
            # Sample without replacement
            replace=False)
    else:
        indices = list(range(total_num_examples))

    for index in indices:
        example = dataset[index]
        examples.append(example)

    dataloaders = []
    for i in range(num_datasets):
        start_index = i * batch_size * num_examples_per_dataset
        end_index = (i + 1) * batch_size * num_examples_per_dataset
        new_dataset = SimpleDataset(examples[start_index:end_index])
        dataloader = misc_utils.get_dataloader(
            dataset=new_dataset,
            batch_size=batch_size,
            # Shuffling here only affects ordering within each small
            # dataset; the examples were already selected above
            random=random)
        dataloaders.append(dataloader)

    return dataloaders
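
The slicing above partitions the selected examples into num_datasets disjoint,
equally sized chunks; a quick self-contained check of the boundary arithmetic:

batch_size, num_datasets, num_examples_per_dataset = 4, 3, 2
chunk = batch_size * num_examples_per_dataset  # 8 examples per dataset
total = chunk * num_datasets                   # 24 examples overall
bounds = [(i * chunk, (i + 1) * chunk) for i in range(num_datasets)]
print(bounds)  # [(0, 8), (8, 16), (16, 24)]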
Code Example #4
def one_experiment(
    model: torch.nn.Module,
    train_dataset: GlueDataset,
    test_inputs: Dict[str, torch.Tensor],
    batch_size: int,
    random: bool,
    n_gpu: int,
    device: torch.device,
    damp: float,
    scale: float,
    num_samples: int,
) -> List[torch.Tensor]:

    params_filter = [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    weight_decay_ignores = ["bias", "LayerNorm.weight"] + [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    # Make sure each dataloader is re-initialized
    batch_train_data_loader = misc_utils.get_dataloader(dataset=train_dataset,
                                                        batch_size=batch_size,
                                                        random=random)

    s_test = compute_s_test(n_gpu=n_gpu,
                            device=device,
                            model=model,
                            test_inputs=test_inputs,
                            train_data_loaders=[batch_train_data_loader],
                            params_filter=params_filter,
                            weight_decay=constants.WEIGHT_DECAY,
                            weight_decay_ignores=weight_decay_ignores,
                            damp=damp,
                            scale=scale,
                            num_samples=num_samples)

    return [X.cpu() for X in s_test]
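
The params_filter / weight_decay_ignores pattern above collects the names of
frozen parameters so they can be excluded downstream; a self-contained
illustration on a toy model:

import torch

model = torch.nn.Linear(4, 2)
model.bias.requires_grad = False  # freeze the bias

params_filter = [
    n for n, p in model.named_parameters() if not p.requires_grad
]
print(params_filter)  # ['bias']

# Frozen parameters are also excluded from weight decay,
# together with biases and LayerNorm weights
weight_decay_ignores = ["bias", "LayerNorm.weight"] + params_filter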
Code Example #5
def create_FAISS_index(
    train_task_name: str,
    trained_on_task_name: str,
) -> faiss_utils.FAISSIndex:
    if train_task_name not in ["mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train task {train_task_name}")

    if trained_on_task_name not in ["mnli-2", "hans"]:
        raise ValueError(f"Unrecognized task {trained_on_task_name}")

    if trained_on_task_name == "mnli-2":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI2_MODEL_PATH)
    else:  # trained_on_task_name == "hans"
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.HANS_MODEL_PATH)

    train_dataset, _ = misc_utils.create_datasets(task_name=train_task_name,
                                                  tokenizer=tokenizer)

    faiss_index = faiss_utils.FAISSIndex(768, "Flat")

    model.cuda()
    device = model.device
    train_batch_data_loader = misc_utils.get_dataloader(dataset=train_dataset,
                                                        batch_size=128,
                                                        random=False)

    for inputs in tqdm(train_batch_data_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        features = misc_utils.compute_BERT_CLS_feature(model, **inputs)
        features = features.cpu().detach().numpy()
        faiss_index.add(features)

    return faiss_index
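
faiss_utils.FAISSIndex(768, "Flat") presumably wraps the standard FAISS
factory API; a sketch of the equivalent raw faiss calls, assuming the
wrapper's add and search map directly onto the library:

import faiss
import numpy as np

index = faiss.index_factory(768, "Flat")  # exact (brute-force) L2 index

features = np.random.rand(16, 768).astype(np.float32)
index.add(features)  # e.g., BERT [CLS] features, one row per example

queries = np.random.rand(2, 768).astype(np.float32)
distances, indices = index.search(queries, 5)  # 5 nearest neighbors each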
Code Example #6
def compute_influences_simplified(
    k: int,
    faiss_index: faiss_utils.FAISSIndex,
    model: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    train_dataset: torch.utils.data.Dataset,
    use_parallel: bool,
    s_test_damp: float,
    s_test_scale: float,
    s_test_num_samples: int,
    device_ids: Optional[List[int]] = None,
    precomputed_s_test: Optional[List[torch.FloatTensor]] = None,
    faiss_index_use_mean_features_as_query: bool = False,
) -> Dict[int, float]:

    # Note: FAISS returns the KNN indices already sorted by distance

    params_filter = [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    weight_decay_ignores = ["bias", "LayerNorm.weight"] + [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    if faiss_index is not None:
        features = misc_utils.compute_BERT_CLS_feature(model, **inputs)
        features = features.cpu().detach().numpy()

        if faiss_index_use_mean_features_as_query:
            # Use the mean embedding over the batch as the single query
            features = features.mean(axis=0, keepdims=True)

        KNN_distances, KNN_indices = faiss_index.search(k=k, queries=features)
    else:
        KNN_indices = None

    if not use_parallel:
        model.cuda()
        batch_train_data_loader = misc_utils.get_dataloader(train_dataset,
                                                            batch_size=1,
                                                            random=True)

        instance_train_data_loader = misc_utils.get_dataloader(train_dataset,
                                                               batch_size=1,
                                                               random=False)

        influences, _, _ = nn_influence_utils.compute_influences(
            n_gpu=1,
            device=torch.device("cuda"),
            batch_train_data_loader=batch_train_data_loader,
            instance_train_data_loader=instance_train_data_loader,
            model=model,
            test_inputs=inputs,
            params_filter=params_filter,
            weight_decay=constants.WEIGHT_DECAY,
            weight_decay_ignores=weight_decay_ignores,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            train_indices_to_include=KNN_indices,
            precomputed_s_test=precomputed_s_test)
    else:
        if device_ids is None:
            raise ValueError("`device_ids` cannot be None")

        influences, _ = parallel.compute_influences_parallel(
            # Avoid clash with main process
            device_ids=device_ids,
            train_dataset=train_dataset,
            batch_size=1,
            model=model,
            test_inputs=inputs,
            params_filter=params_filter,
            weight_decay=constants.WEIGHT_DECAY,
            weight_decay_ignores=weight_decay_ignores,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            train_indices_to_include=KNN_indices,
            return_s_test=False,
            debug=False)

    return influences
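
When faiss_index_use_mean_features_as_query is set, the batch of per-example
[CLS] features collapses to a single query vector; keepdims preserves the 2-D
shape that the FAISS search expects:

import numpy as np

features = np.random.rand(8, 768).astype(np.float32)  # one row per example
query = features.mean(axis=0, keepdims=True)
print(query.shape)  # (1, 768) -- still 2-D, a single query row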
Code Example #7
def main(
    mode: str,
    num_examples_to_test: int = 5,
    num_repetitions: int = 4,
) -> List[Dict[str, Any]]:

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)
    train_dataset, eval_dataset = misc_utils.create_datasets(
        task_name="mnli", tokenizer=task_tokenizer)
    eval_instance_data_loader = misc_utils.get_dataloader(dataset=eval_dataset,
                                                          batch_size=1,
                                                          random=False)

    output_mode = glue_output_modes["mnli"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Most of these arguments are placeholders
    # and are not really used at all, so ignore
    # the exact values of these.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
        data_collator=default_data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn("mnli"),
    )

    task_model.cuda()
    num_examples_tested = 0
    output_collections = []
    for test_index, test_inputs in enumerate(eval_instance_data_loader):
        if num_examples_tested >= num_examples_to_test:
            break

        # Skip if we only want correctly predicted examples but the
        # prediction is incorrect, or vice versa
        prediction_is_correct = misc_utils.is_prediction_correct(
            trainer=trainer, model=task_model, inputs=test_inputs)

        if mode == "only-correct" and prediction_is_correct is False:
            continue

        if mode == "only-incorrect" and prediction_is_correct is True:
            continue

        for k, v in test_inputs.items():
            if isinstance(v, torch.Tensor):
                test_inputs[k] = v.to(torch.device("cuda"))

        # With batch-size 128, 1500 iterations are enough
        for num_samples in range(700, 1300 + 1, 100):  # 7 choices
            for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:  # 8 choices
                for repetition in range(num_repetitions):
                    print(
                        f"Running #{test_index} "
                        f"N={num_samples} "
                        f"B={batch_size} "
                        f"R={repetition} takes ...",
                        end=" ")
                    with Timer() as timer:
                        s_test = one_experiment(
                            model=task_model,
                            train_dataset=train_dataset,
                            test_inputs=test_inputs,
                            batch_size=batch_size,
                            random=True,
                            n_gpu=1,
                            device=torch.device("cuda"),
                            damp=constants.DEFAULT_INFLUENCE_HPARAMS["mnli"]
                            ["mnli"]["damp"],
                            scale=constants.DEFAULT_INFLUENCE_HPARAMS["mnli"]
                            ["mnli"]["scale"],
                            num_samples=num_samples)
                        time_elapsed = timer.elapsed
                        print(f"{time_elapsed:.2f} seconds")

                    outputs = {
                        "test_index": test_index,
                        "num_samples": num_samples,
                        "batch_size": batch_size,
                        "repetition": repetition,
                        "s_test": s_test,
                        "time_elapsed": time_elapsed,
                        "correct": prediction_is_correct,
                    }
                    output_collections.append(outputs)
                    remote_utils.save_and_mirror_scp_to_remote(
                        object_to_save=outputs,
                        file_name=f"stest.{mode}.{num_examples_to_test}."
                        f"{test_index}.{num_samples}."
                        f"{batch_size}.{repetition}.pth")

        num_examples_tested += 1

    return output_collections
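
For planning runtimes, the grid above amounts to 7 x 8 x num_repetitions timed
s_test computations per tested example:

num_samples_choices = len(range(700, 1300 + 1, 100))     # 7
batch_size_choices = len([1, 2, 4, 8, 16, 32, 64, 128])  # 8
num_repetitions = 4
print(num_samples_choices * batch_size_choices * num_repetitions)  # 224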
Code Example #8
def imitator_main(mode: str,
                  num_examples_to_test: int) -> List[Dict[str, Any]]:
    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    imitator_tokenizer, imitator_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_IMITATOR_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(task_name="mnli",
                                                     tokenizer=task_tokenizer)

    task_model.cuda()
    imitator_model.cuda()
    if task_model.training or imitator_model.training:
        raise ValueError("One of the models is in training mode")
    print(task_model.device, imitator_model.device)

    # Most of these arguments are placeholders
    # and are not really used at all, so ignore
    # the exact values of these.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
    )

    eval_instance_data_loader = misc_utils.get_dataloader(
        mnli_eval_dataset, batch_size=1, data_collator=default_data_collator)

    train_inputs_collections = torch.load(
        constants.MNLI_TRAIN_INPUT_COLLECTIONS_PATH)

    inputs_by_label: Dict[str, List[int]] = defaultdict(list)
    for i in range(len(train_inputs_collections)):
        label = mnli_train_dataset.label_list[train_inputs_collections[i]
                                              ["labels"]]
        inputs_by_label[label].append(i)

    outputs_collections = []
    for i, test_inputs in enumerate(eval_instance_data_loader):
        if mode == "only-correct" and i not in CORRECT_INDICES[:
                                                               num_examples_to_test]:
            continue
        if mode == "only-incorrect" and i not in INCORRECT_INDICES[:
                                                                   num_examples_to_test]:
            continue

        start_time = time.time()
        for using_ground_truth in [True, False]:
            outputs = run_one_imitator_experiment(
                task_model=task_model,
                imitator_model=imitator_model,
                test_inputs=test_inputs,
                trainer=trainer,
                train_dataset=mnli_train_dataset,
                train_inputs_collections=train_inputs_collections,
                inputs_by_label=inputs_by_label,
                finetune_using_ground_truth_label=using_ground_truth)
            outputs["index"] = i
            outputs_collections.append(outputs)

        end_time = time.time()
        print(f"#{len(outputs_collections)}/{len(outputs_collections)}: "
              f"Elapsed {(end_time - start_time) / 60:.2f}")

    torch.save(outputs_collections,
               f"imitator_experiments.{mode}.{num_examples_to_test}.pt")

    return outputs_collections
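
The inputs_by_label bucketing above groups training-example indices by their
gold label; a minimal self-contained version of the same pattern (the toy
label ids are illustrative, using the MNLI label order):

from collections import defaultdict

label_list = ["contradiction", "entailment", "neutral"]
train_labels = [0, 2, 1, 0, 1]  # toy label ids, one per training example

inputs_by_label = defaultdict(list)
for i, label_id in enumerate(train_labels):
    inputs_by_label[label_list[label_id]].append(i)

print(dict(inputs_by_label))
# {'contradiction': [0, 3], 'neutral': [1], 'entailment': [2, 4]}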
Code Example #9
def run_full_influence_functions(
        mode: str,
        num_examples_to_test: int,
        s_test_num_samples: int = 1000) -> Dict[int, Dict[str, Any]]:

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    tokenizer, model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(task_name="mnli",
                                                     tokenizer=tokenizer)

    batch_train_data_loader = misc_utils.get_dataloader(mnli_train_dataset,
                                                        batch_size=128,
                                                        random=True)

    instance_train_data_loader = misc_utils.get_dataloader(mnli_train_dataset,
                                                           batch_size=1,
                                                           random=False)

    eval_instance_data_loader = misc_utils.get_dataloader(
        dataset=mnli_eval_dataset, batch_size=1, random=False)

    output_mode = glue_output_modes["mnli"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Most of these arguments are placeholders
    # and are not really used at all, so ignore
    # the exact values of these.
    trainer = transformers.Trainer(
        model=model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
        data_collator=default_data_collator,
        train_dataset=mnli_train_dataset,
        eval_dataset=mnli_eval_dataset,
        compute_metrics=build_compute_metrics_fn("mnli"),
    )

    params_filter = [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    weight_decay_ignores = ["bias", "LayerNorm.weight"] + [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    model.cuda()
    num_examples_tested = 0
    outputs_collections = {}
    for test_index, test_inputs in enumerate(eval_instance_data_loader):
        if num_examples_tested >= num_examples_to_test:
            break

        # Skip if we only want correctly predicted examples but the
        # prediction is incorrect, or vice versa
        prediction_is_correct = misc_utils.is_prediction_correct(
            trainer=trainer, model=model, inputs=test_inputs)

        if mode == "only-correct" and prediction_is_correct is False:
            continue

        if mode == "only-incorrect" and prediction_is_correct is True:
            continue

        with Timer() as timer:
            influences, _, s_test = nn_influence_utils.compute_influences(
                n_gpu=1,
                device=torch.device("cuda"),
                batch_train_data_loader=batch_train_data_loader,
                instance_train_data_loader=instance_train_data_loader,
                model=model,
                test_inputs=test_inputs,
                params_filter=params_filter,
                weight_decay=constants.WEIGHT_DECAY,
                weight_decay_ignores=weight_decay_ignores,
                s_test_damp=5e-3,
                s_test_scale=1e4,
                s_test_num_samples=s_test_num_samples,
                train_indices_to_include=None,
                s_test_iterations=1,
                precomputed_s_test=None)

            outputs = {
                "test_index": test_index,
                "influences": influences,
                "s_test": s_test,
                "time": timer.elapsed,
                "correct": prediction_is_correct,
            }
            num_examples_tested += 1
            outputs_collections[test_index] = outputs

            remote_utils.save_and_mirror_scp_to_remote(
                object_to_save=outputs,
                file_name=(f"KNN-recall.{mode}."
                           f"{num_examples_to_test}.{test_index}.pth"))
            print(
                f"Status: #{test_index} | {num_examples_tested} / {num_examples_to_test}"
            )

    return outputs_collections
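
The s_test_damp and s_test_scale arguments control the stochastic
inverse-Hessian-vector-product estimate. The internals of
nn_influence_utils.compute_influences are not shown here, but a toy sketch of
the LiSSA-style recursion such code typically implements (an assumption about
the internals, following Koh & Liang's formulation), verified against a
direct solve:

import numpy as np

def lissa_inverse_hvp(hvp_fn, v, damp, scale, num_samples):
    # Iterates h <- v + (1 - damp) * h - hvp_fn(h) / scale, which is a
    # contraction when scale is large enough relative to the Hessian
    h = v.copy()
    for _ in range(num_samples):
        h = v + (1 - damp) * h - hvp_fn(h) / scale
    return h / scale  # approximates H^{-1} v

# Toy check with an explicit positive-definite Hessian
H = np.diag([2.0, 0.5])
v = np.array([1.0, 1.0])
approx = lissa_inverse_hvp(lambda x: H @ x, v,
                           damp=0.0, scale=10.0, num_samples=1000)
print(approx, np.linalg.solve(H, v))  # both close to [0.5, 2.0]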
Code Example #10
def main(
    mode: str,
    train_task_name: str,
    eval_task_name: str,
    num_eval_to_collect: int,
    use_parallel: bool = True,
    kNN_k: Optional[int] = None,
    hans_heuristic: Optional[str] = None,
    trained_on_task_name: Optional[str] = None,
) -> List[Dict[str, Union[int, Dict[int, float]]]]:

    if train_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train task {train_task_name}")

    if eval_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized eval task {eval_task_name}")

    if trained_on_task_name is None:
        # The task the model was trained on
        # can be different from `train_task_name`
        # which is used to determine on which the
        # influence values will be computed.
        trained_on_task_name = train_task_name

    if trained_on_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized task {trained_on_task_name}")

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    if kNN_k is None:
        kNN_k = DEFAULT_KNN_K

    # `trained_on_task_name` determines the model to load
    if trained_on_task_name == "mnli":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI_MODEL_PATH)
    elif trained_on_task_name == "mnli-2":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI2_MODEL_PATH)
    else:  # trained_on_task_name == "hans"
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.HANS_MODEL_PATH)

    train_dataset, _ = misc_utils.create_datasets(task_name=train_task_name,
                                                  tokenizer=tokenizer)

    _, eval_dataset = misc_utils.create_datasets(task_name=eval_task_name,
                                                 tokenizer=tokenizer)

    faiss_index = influence_helpers.load_faiss_index(
        trained_on_task_name=trained_on_task_name,
        train_task_name=train_task_name)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    if eval_task_name in ["mnli", "mnli-2"]:
        eval_instance_data_loader = misc_utils.get_dataloader(
            dataset=eval_dataset, batch_size=1, random=False)
    else:  # eval_task_name == "hans"
        if hans_heuristic is None:
            raise ValueError("`hans_heuristic` cannot be None for now")

        hans_helper = HansHelper(hans_train_dataset=None,
                                 hans_eval_dataset=eval_dataset)

        _, eval_instance_data_loader = hans_helper.get_dataset_and_dataloader_of_heuristic(
            mode="eval", heuristic=hans_heuristic, batch_size=1, random=False)

    # Split evaluation examples by whether the model predicted them correctly
    correct_input_collections = []
    incorrect_input_collections = []
    for index, test_inputs in enumerate(eval_instance_data_loader):
        logits, labels, step_eval_loss = misc_utils.predict(trainer=trainer,
                                                            model=model,
                                                            inputs=test_inputs)
        if logits.argmax(axis=-1).item() != labels.item():
            incorrect_input_collections.append((index, test_inputs))
        else:
            correct_input_collections.append((index, test_inputs))

    if mode == "only-incorrect":
        input_collections = incorrect_input_collections
    else:
        input_collections = correct_input_collections

    # Other configurations are not supported for now
    (s_test_damp, s_test_scale,
     s_test_num_samples) = influence_helpers.select_s_test_config(
         trained_on_task_name=trained_on_task_name,
         train_task_name=train_task_name,
         eval_task_name=eval_task_name)

    influences_collections = []
    for index, inputs in input_collections[:num_eval_to_collect]:
        print(f"#{index}")
        influences = influence_helpers.compute_influences_simplified(
            k=kNN_k,
            faiss_index=faiss_index,
            model=model,
            inputs=inputs,
            train_dataset=train_dataset,
            use_parallel=use_parallel,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            device_ids=[0, 1, 2, 3],
            precomputed_s_test=None)

        influences_collections.append({
            "index": index,
            "influences": influences,
        })

    remote_utils.save_and_mirror_scp_to_remote(
        object_to_save=influences_collections,
        file_name=(f"visualization"
                   f".{mode}.{num_eval_to_collect}"
                   f".{train_task_name}-{eval_task_name}"
                   f"-{hans_heuristic}-{trained_on_task_name}"
                   f".{kNN_k}.{use_parallel}.pth"))

    return influences_collections
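
A hypothetical invocation of this driver, with illustrative argument values
only (the HANS heuristic name is one of the dataset's three standard
heuristics):

influences_collections = main(
    mode="only-incorrect",
    train_task_name="mnli-2",
    eval_task_name="hans",
    num_eval_to_collect=10,
    use_parallel=True,
    hans_heuristic="lexical_overlap",
    trained_on_task_name="mnli-2")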