Example #1
def create_FAISS_index(
    train_task_name: str,
    trained_on_task_name: str,
) -> faiss_utils.FAISSIndex:
    if train_task_name not in ["mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train_task_name {train_task_name}")

    if trained_on_task_name not in ["mnli-2", "hans"]:
        raise ValueError(
            f"Unrecognized trained_on_task_name {trained_on_task_name}")

    if trained_on_task_name == "mnli-2":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI2_MODEL_PATH)

    if trained_on_task_name == "hans":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.HANS_MODEL_PATH)

    train_dataset, _ = misc_utils.create_datasets(task_name=train_task_name,
                                                  tokenizer=tokenizer)

    faiss_index = faiss_utils.FAISSIndex(768, "Flat")

    model.cuda()
    device = model.device
    train_batch_data_loader = misc_utils.get_dataloader(dataset=train_dataset,
                                                        batch_size=128,
                                                        random=False)

    for inputs in tqdm(train_batch_data_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        features = misc_utils.compute_BERT_CLS_feature(model, **inputs)
        features = features.cpu().detach().numpy()
        faiss_index.add(features)

    return faiss_index
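
A minimal usage sketch for create_FAISS_index, assuming the surrounding repo modules (faiss_utils, misc_utils, constants) are importable and a CUDA device is available; the argument values are illustrative only.

# Hypothetical driver code: index HANS training examples using
# features from the model trained on MNLI-2.
faiss_index = create_FAISS_index(train_task_name="hans",
                                 trained_on_task_name="mnli-2")
# The returned object wraps a 768-dimensional "Flat" FAISS index
# over BERT [CLS] features of the training set.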
Example #2
def main(
    train_task_name: str,
    train_heuristic: str,
    eval_heuristics: Optional[List[str]] = None,
    num_replicas: Optional[int] = None,
    use_parallel: bool = True,
    version: Optional[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:

    if train_task_name not in ["mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train_task_name {train_task_name}")

    if eval_heuristics is None:
        eval_heuristics = DEFAULT_EVAL_HEURISTICS

    if num_replicas is None:
        num_replicas = DEFAULT_NUM_REPLICAS

    if version not in ["new-only-z", "new-only-ztest", "new-z-and-ztest"]:
        raise ValueError(f"Unrecognized version {version}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI2_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(task_name="mnli-2",
                                                     tokenizer=task_tokenizer)

    (hans_train_dataset,
     hans_eval_dataset) = misc_utils.create_datasets(task_name="hans",
                                                     tokenizer=task_tokenizer)

    if train_task_name == "mnli-2":
        train_dataset = mnli_train_dataset

    if train_task_name == "hans":
        train_dataset = hans_train_dataset

    (s_test_damp, s_test_scale,
     s_test_num_samples) = influence_helpers.select_s_test_config(
         trained_on_task_name="mnli-2",
         train_task_name=train_task_name,
         eval_task_name="hans",
     )

    hans_helper = HansHelper(hans_train_dataset=hans_train_dataset,
                             hans_eval_dataset=hans_eval_dataset)

    # We run the model trained on MNLI-2
    # but calculate influences on the HANS dataset.
    faiss_index = influence_helpers.load_faiss_index(
        trained_on_task_name="mnli-2", train_task_name=train_task_name)

    output_mode = glue_output_modes["mnli-2"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Most of these arguments are placeholders and are not
    # actually used, so their exact values do not matter.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
    )

    output_collections: Dict[str, List] = defaultdict(list)

    if version == "old":
        raise ValueError("Deprecated")

    else:
        NUM_STEPS = 10
        num_total_experiments = (len(EXPERIMENT_TYPES) * num_replicas *
                                 len(VERSION_2_NUM_DATAPOINTS_CHOICES) *
                                 len(VERSION_2_LEARNING_RATE_CHOICES) *
                                 NUM_STEPS)

        with tqdm(total=num_total_experiments) as pbar:
            for experiment_type in EXPERIMENT_TYPES:
                for replica_index in range(num_replicas):

                    (hans_eval_heuristic_inputs, hans_eval_heuristic_raw_inputs
                     ) = hans_helper.sample_batch_of_heuristic(
                         mode="eval",
                         heuristic=train_heuristic,
                         size=EVAL_HEURISTICS_SAMPLE_BATCH_SIZE,
                         return_raw_data=True)

                    misc_utils.move_inputs_to_device(
                        inputs=hans_eval_heuristic_inputs,
                        device=task_model.device)

                    for version_2_num_datapoints in VERSION_2_NUM_DATAPOINTS_CHOICES:
                        for version_2_learning_rate in VERSION_2_LEARNING_RATE_CHOICES:

                            # The model will be used for multiple
                            # steps so `deepcopy` it here.
                            _model = deepcopy(task_model)
                            for step in range(NUM_STEPS):
                                outputs_one_experiment, _model = one_experiment(
                                    use_parallel=use_parallel,
                                    train_heuristic=train_heuristic,
                                    eval_heuristics=eval_heuristics,
                                    experiment_type=experiment_type,
                                    hans_helper=hans_helper,
                                    train_dataset=train_dataset,
                                    task_model=_model,
                                    faiss_index=faiss_index,
                                    s_test_damp=s_test_damp,
                                    s_test_scale=s_test_scale,
                                    s_test_num_samples=s_test_num_samples,
                                    trainer=trainer,
                                    version=version,
                                    version_2_num_datapoints=version_2_num_datapoints,
                                    version_2_learning_rate=version_2_learning_rate,
                                    hans_eval_heuristic_inputs=hans_eval_heuristic_inputs,
                                    hans_eval_heuristic_raw_inputs=hans_eval_heuristic_raw_inputs,
                                )

                                output_collections[
                                    f"{experiment_type}-"
                                    f"{replica_index}-"
                                    f"{version_2_num_datapoints}-"
                                    f"{version_2_learning_rate}-"].append(
                                        outputs_one_experiment)

                                pbar.update(1)
                                pbar.set_description(
                                    f"{experiment_type} #{replica_index}")

        torch.save(
            output_collections, f"hans-augmentation-{version}."
            f"{train_task_name}."
            f"{train_heuristic}."
            f"{num_replicas}."
            f"{use_parallel}.pth")

    return output_collections
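
A rough usage sketch for this entry point; the HANS heuristic name below is an assumption based on the standard HANS heuristics, and the remaining arguments fall back to DEFAULT_EVAL_HEURISTICS and DEFAULT_NUM_REPLICAS.

# Hypothetical invocation: augmentation experiments driven by the
# lexical_overlap heuristic with the "new-only-z" variant.
results = main(train_task_name="mnli-2",
               train_heuristic="lexical_overlap",
               version="new-only-z")
# `results` maps "{experiment_type}-{replica_index}-..." keys to lists of
# per-step experiment outputs; a copy is also written via torch.save.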
Example #3
def main(
    mode: str,
    num_examples_to_test: int = 5,
    num_repetitions: int = 4,
) -> List[Dict[str, Any]]:

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)
    train_dataset, eval_dataset = misc_utils.create_datasets(
        task_name="mnli", tokenizer=task_tokenizer)
    eval_instance_data_loader = misc_utils.get_dataloader(dataset=eval_dataset,
                                                          batch_size=1,
                                                          random=False)

    output_mode = glue_output_modes["mnli"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Most of these arguments are placeholders and are not
    # actually used, so their exact values do not matter.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
        data_collator=default_data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn("mnli"),
    )

    task_model.cuda()
    num_examples_tested = 0
    output_collections = []
    for test_index, test_inputs in enumerate(eval_instance_data_loader):
        if num_examples_tested >= num_examples_to_test:
            break

        # Skip when we only want correctly-predicted examples but the
        # prediction is incorrect, or vice versa.
        prediction_is_correct = misc_utils.is_prediction_correct(
            trainer=trainer, model=task_model, inputs=test_inputs)

        if mode == "only-correct" and prediction_is_correct is False:
            continue

        if mode == "only-incorrect" and prediction_is_correct is True:
            continue

        for k, v in test_inputs.items():
            if isinstance(v, torch.Tensor):
                test_inputs[k] = v.to(torch.device("cuda"))

        # with batch-size 128, 1500 iterations is enough
        for num_samples in range(700, 1300 + 1, 100):  # 7 choices
            for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:  # 8 choices
                for repetition in range(num_repetitions):
                    print(
                        f"Running #{test_index} "
                        f"N={num_samples} "
                        f"B={batch_size} "
                        f"R={repetition} takes ...",
                        end=" ")
                    with Timer() as timer:
                        s_test = one_experiment(
                            model=task_model,
                            train_dataset=train_dataset,
                            test_inputs=test_inputs,
                            batch_size=batch_size,
                            random=True,
                            n_gpu=1,
                            device=torch.device("cuda"),
                            damp=constants.DEFAULT_INFLUENCE_HPARAMS["mnli"]
                            ["mnli"]["damp"],
                            scale=constants.DEFAULT_INFLUENCE_HPARAMS["mnli"]
                            ["mnli"]["scale"],
                            num_samples=num_samples)
                        time_elapsed = timer.elapsed
                        print(f"{time_elapsed:.2f} seconds")

                    outputs = {
                        "test_index": test_index,
                        "num_samples": num_samples,
                        "batch_size": batch_size,
                        "repetition": repetition,
                        "s_test": s_test,
                        "time_elapsed": time_elapsed,
                        "correct": prediction_is_correct,
                    }
                    output_collections.append(outputs)
                    remote_utils.save_and_mirror_scp_to_remote(
                        object_to_save=outputs,
                        file_name=f"stest.{mode}.{num_examples_to_test}."
                        f"{test_index}.{num_samples}."
                        f"{batch_size}.{repetition}.pth")

        num_examples_tested += 1

    return output_collections
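
A small usage sketch for the timing sweep above, assuming a single CUDA device and the checkpoint at constants.MNLI_MODEL_PATH.

# Hypothetical run: time s_test estimation on 5 correctly-predicted
# evaluation examples, 4 repetitions per (num_samples, batch_size) pair.
timing_outputs = main(mode="only-correct",
                      num_examples_to_test=5,
                      num_repetitions=4)
# Each entry records test_index, num_samples, batch_size, repetition,
# the s_test estimate, elapsed time, and prediction correctness.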
Example #4
def imitator_main(mode: str,
                  num_examples_to_test: int) -> List[Dict[str, Any]]:
    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    imitator_tokenizer, imitator_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_IMITATOR_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(task_name="mnli",
                                                     tokenizer=task_tokenizer)

    task_model.cuda()
    imitator_model.cuda()
    if task_model.training or imitator_model.training:
        raise ValueError("One of the models is in training mode")
    print(task_model.device, imitator_model.device)

    # Most of these arguments are placeholders and are not
    # actually used, so their exact values do not matter.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
    )

    eval_instance_data_loader = misc_utils.get_dataloader(
        mnli_eval_dataset, batch_size=1, data_collator=default_data_collator)

    train_inputs_collections = torch.load(
        constants.MNLI_TRAIN_INPUT_COLLECTIONS_PATH)

    inputs_by_label: Dict[str, List[int]] = defaultdict(list)
    for i in range(len(train_inputs_collections)):
        label = mnli_train_dataset.label_list[
            train_inputs_collections[i]["labels"]]
        inputs_by_label[label].append(i)

    outputs_collections = []
    for i, test_inputs in enumerate(eval_instance_data_loader):
        if mode == "only-correct" and i not in CORRECT_INDICES[:
                                                               num_examples_to_test]:
            continue
        if mode == "only-incorrect" and i not in INCORRECT_INDICES[:
                                                                   num_examples_to_test]:
            continue

        start_time = time.time()
        for using_ground_truth in [True, False]:
            outputs = run_one_imitator_experiment(
                task_model=task_model,
                imitator_model=imitator_model,
                test_inputs=test_inputs,
                trainer=trainer,
                train_dataset=mnli_train_dataset,
                train_inputs_collections=train_inputs_collections,
                inputs_by_label=inputs_by_label,
                finetune_using_ground_truth_label=using_ground_truth)
            outputs["index"] = i
            outputs_collections.append(outputs)

        end_time = time.time()
        print(f"#{len(outputs_collections)}/{len(outputs_collections)}: "
              f"Elapsed {(end_time - start_time) / 60:.2f}")

    torch.save(outputs_collections,
               f"imiator_experiments.{mode}.{num_examples_to_test}.pt")

    return outputs_collections
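
A hedged sketch of how the imitator experiment above might be launched; the value of num_examples_to_test is illustrative, and outputs are also saved to an imitator_experiments.*.pt file.

# Hypothetical invocation: run the imitator comparison on the first
# 10 indices in CORRECT_INDICES.
imitator_outputs = imitator_main(mode="only-correct",
                                 num_examples_to_test=10)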
Example #5
def run_full_influence_functions(
        mode: str,
        num_examples_to_test: int,
        s_test_num_samples: int = 1000) -> Dict[int, Dict[str, Any]]:

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    tokenizer, model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(task_name="mnli",
                                                     tokenizer=tokenizer)

    batch_train_data_loader = misc_utils.get_dataloader(mnli_train_dataset,
                                                        batch_size=128,
                                                        random=True)

    instance_train_data_loader = misc_utils.get_dataloader(mnli_train_dataset,
                                                           batch_size=1,
                                                           random=False)

    eval_instance_data_loader = misc_utils.get_dataloader(
        dataset=mnli_eval_dataset, batch_size=1, random=False)

    output_mode = glue_output_modes["mnli"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Most of these arguments are placeholders and are not
    # actually used, so their exact values do not matter.
    trainer = transformers.Trainer(
        model=model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
        data_collator=default_data_collator,
        train_dataset=mnli_train_dataset,
        eval_dataset=mnli_eval_dataset,
        compute_metrics=build_compute_metrics_fn("mnli"),
    )

    params_filter = [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    weight_decay_ignores = ["bias", "LayerNorm.weight"] + [
        n for n, p in model.named_parameters() if not p.requires_grad
    ]

    model.cuda()
    num_examples_tested = 0
    outputs_collections = {}
    for test_index, test_inputs in enumerate(eval_instance_data_loader):
        if num_examples_tested >= num_examples_to_test:
            break

        # Skip when we only want correctly-predicted examples but the
        # prediction is incorrect, or vice versa.
        prediction_is_correct = misc_utils.is_prediction_correct(
            trainer=trainer, model=model, inputs=test_inputs)

        if mode == "only-correct" and prediction_is_correct is False:
            continue

        if mode == "only-incorrect" and prediction_is_correct is True:
            continue

        with Timer() as timer:
            influences, _, s_test = nn_influence_utils.compute_influences(
                n_gpu=1,
                device=torch.device("cuda"),
                batch_train_data_loader=batch_train_data_loader,
                instance_train_data_loader=instance_train_data_loader,
                model=model,
                test_inputs=test_inputs,
                params_filter=params_filter,
                weight_decay=constants.WEIGHT_DECAY,
                weight_decay_ignores=weight_decay_ignores,
                s_test_damp=5e-3,
                s_test_scale=1e4,
                s_test_num_samples=s_test_num_samples,
                train_indices_to_include=None,
                s_test_iterations=1,
                precomputed_s_test=None)

            outputs = {
                "test_index": test_index,
                "influences": influences,
                "s_test": s_test,
                "time": timer.elapsed,
                "correct": prediction_is_correct,
            }
            num_examples_tested += 1
            outputs_collections[test_index] = outputs

            remote_utils.save_and_mirror_scp_to_remote(
                object_to_save=outputs,
                file_name=(f"KNN-recall.{mode}."
                           f"{num_examples_to_test}.{test_index}.pth"))
            print(
                f"Status: #{test_index} | {num_examples_tested} / {num_examples_to_test}"
            )

    return outputs_collections
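
A minimal sketch of calling the full influence-function computation above, assuming a CUDA device; s_test_num_samples keeps its default of 1000.

# Hypothetical run: compute influences for 3 incorrectly-predicted
# evaluation examples; results are keyed by evaluation index.
full_outputs = run_full_influence_functions(mode="only-incorrect",
                                            num_examples_to_test=3)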
def main(
    mode: str,
    train_task_name: str,
    eval_task_name: str,
    num_eval_to_collect: int,
    use_parallel: bool = True,
    kNN_k: Optional[int] = None,
    hans_heuristic: Optional[str] = None,
    trained_on_task_name: Optional[str] = None,
) -> List[Dict[str, Union[int, Dict[int, float]]]]:

    if train_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train_task_name {train_task_name}")

    if eval_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized eval_task_name {eval_task_name}")

    if trained_on_task_name is None:
        # The task the model was trained on can differ from
        # `train_task_name`, which determines the dataset on
        # which influence values are computed.
        trained_on_task_name = train_task_name

    if trained_on_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(
            f"Unrecognized trained_on_task_name {trained_on_task_name}")

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    if kNN_k is None:
        kNN_k = DEFAULT_KNN_K

    # `trained_on_task_name` determines the model to load
    if trained_on_task_name in ["mnli"]:
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI_MODEL_PATH)

    if trained_on_task_name in ["mnli-2"]:
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI2_MODEL_PATH)

    if trained_on_task_name in ["hans"]:
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.HANS_MODEL_PATH)

    train_dataset, _ = misc_utils.create_datasets(task_name=train_task_name,
                                                  tokenizer=tokenizer)

    _, eval_dataset = misc_utils.create_datasets(task_name=eval_task_name,
                                                 tokenizer=tokenizer)

    faiss_index = influence_helpers.load_faiss_index(
        trained_on_task_name=trained_on_task_name,
        train_task_name=train_task_name)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="./tmp-output",
                               per_device_train_batch_size=128,
                               per_device_eval_batch_size=128,
                               learning_rate=5e-5,
                               logging_steps=100),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    if eval_task_name in ["mnli", "mnli-2"]:
        eval_instance_data_loader = misc_utils.get_dataloader(
            dataset=eval_dataset, batch_size=1, random=False)

    if eval_task_name in ["hans"]:
        if hans_heuristic is None:
            raise ValueError("`hans_heuristic` cannot be None for now")

        hans_helper = HansHelper(hans_train_dataset=None,
                                 hans_eval_dataset=eval_dataset)

        _, eval_instance_data_loader = hans_helper.get_dataset_and_dataloader_of_heuristic(
            mode="eval", heuristic=hans_heuristic, batch_size=1, random=False)

    # Split evaluation data points by whether the model's prediction is correct
    correct_input_collections = []
    incorrect_input_collections = []
    for index, test_inputs in enumerate(eval_instance_data_loader):
        logits, labels, step_eval_loss = misc_utils.predict(trainer=trainer,
                                                            model=model,
                                                            inputs=test_inputs)
        if logits.argmax(axis=-1).item() != labels.item():
            incorrect_input_collections.append((index, test_inputs))
        else:
            correct_input_collections.append((index, test_inputs))

    if mode == "only-incorrect":
        input_collections = incorrect_input_collections
    else:
        input_collections = correct_input_collections

    # Other settings are not supported as of now
    (s_test_damp, s_test_scale,
     s_test_num_samples) = influence_helpers.select_s_test_config(
         trained_on_task_name=trained_on_task_name,
         train_task_name=train_task_name,
         eval_task_name=eval_task_name)

    influences_collections = []
    for index, inputs in input_collections[:num_eval_to_collect]:
        print(f"#{index}")
        influences = influence_helpers.compute_influences_simplified(
            k=kNN_k,
            faiss_index=faiss_index,
            model=model,
            inputs=inputs,
            train_dataset=train_dataset,
            use_parallel=use_parallel,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            device_ids=[0, 1, 2, 3],
            precomputed_s_test=None)

        influences_collections.append({
            "index": index,
            "influences": influences,
        })

    remote_utils.save_and_mirror_scp_to_remote(
        object_to_save=influences_collections,
        file_name=(f"visualization"
                   f".{mode}.{num_eval_to_collect}"
                   f".{train_task_name}-{eval_task_name}"
                   f"-{hans_heuristic}-{trained_on_task_name}"
                   f".{kNN_k}.{use_parallel}.pth"))

    return influences_collections
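
Finally, a hedged usage sketch for the kNN-based influence computation above; the HANS heuristic name is an assumption, and use_parallel=True expects the four GPUs listed in device_ids.

# Hypothetical invocation: model trained on MNLI-2, influences computed
# over MNLI-2 training data, evaluated on HANS lexical_overlap examples.
influences = main(mode="only-incorrect",
                  train_task_name="mnli-2",
                  eval_task_name="hans",
                  num_eval_to_collect=10,
                  hans_heuristic="lexical_overlap",
                  trained_on_task_name="mnli-2")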