def get_dataset_and_dataloader_of_heuristic(
        self,
        mode: str,
        heuristic: str,
        batch_size: int,
        random: bool) -> Tuple[SubsetDataset, torch.utils.data.DataLoader]:

    if mode not in ["train", "eval"]:
        raise ValueError(f"Unrecognized mode {mode}")

    if mode == "train":
        dataset = self._hans_train_dataset
    else:
        dataset = self._hans_eval_dataset

    if dataset is None:
        raise ValueError("`dataset` is None")

    indices = self.get_indices_of_heuristic(
        mode=mode, heuristic=heuristic)

    heuristic_dataset = SubsetDataset(dataset=dataset, indices=indices)
    heuristic_dataloader = misc_utils.get_dataloader(
        dataset=heuristic_dataset,
        batch_size=batch_size,
        random=random)

    return heuristic_dataset, heuristic_dataloader
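# Illustrative usage sketch (added for exposition; `helper` and the chosen
# heuristic value are assumptions, not part of the original code). HANS
# heuristics are one of "lexical_overlap", "subsequence", or "constituent".
def _example_heuristic_dataloader(helper: "HansHelper") -> None:
    subset, subset_loader = helper.get_dataset_and_dataloader_of_heuristic(
        mode="eval",
        heuristic="lexical_overlap",
        batch_size=32,
        random=False)
    # Every batch drawn from `subset_loader` contains only examples
    # matching the chosen heuristic
    print(f"Heuristic subset size: {len(subset)}")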
def prepare_scattered_inputs_and_indices(
        device_ids: List[int],
        dataset: torch.utils.data.Dataset,
        indices_to_include: Optional[List[int]] = None,
) -> Tuple[List[List[Any]], List[List[int]]]:
    """Scatter the data onto multiple devices."""
    indices_list = []
    inputs_collections_list = []
    instance_dataloader = misc_utils.get_dataloader(
        dataset=dataset, batch_size=1)

    for index, train_inputs in enumerate(tqdm(instance_dataloader)):
        # Skip this index when a subset to include is specified
        # and the index falls outside of it
        if (indices_to_include is not None) and (index not in indices_to_include):
            continue

        indices_list.append(index)
        inputs_collections_list.append(train_inputs)

    scattered_inputs, scattered_indices = scatter_inputs_and_indices(
        Xs=inputs_collections_list,
        indices=indices_list,
        device_ids=device_ids)

    return scattered_inputs, scattered_indices
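# Illustrative usage sketch (assumed, not in the original module): scatter
# the first 100 training examples across two GPUs. The i-th sub-list of
# each returned value corresponds to `device_ids[i]`.
def _example_scatter(train_dataset: torch.utils.data.Dataset) -> None:
    scattered_inputs, scattered_indices = prepare_scattered_inputs_and_indices(
        device_ids=[0, 1],
        dataset=train_dataset,
        indices_to_include=list(range(100)))
    for device_id, device_indices in zip([0, 1], scattered_indices):
        print(f"cuda:{device_id} received {len(device_indices)} examples")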
def prepare_small_dataloaders(
        dataset: torch.utils.data.Dataset,
        random: bool,
        batch_size: int,
        num_datasets: int,
        num_examples_per_dataset: int) -> List[torch.utils.data.DataLoader]:
    """Only pass to child processes the data we will really use."""
    examples = []
    total_num_examples = batch_size * num_datasets * num_examples_per_dataset
    if random:
        # Sample without replacement
        indices = np.random.choice(
            len(dataset),
            size=total_num_examples,
            replace=False)
    else:
        indices = list(range(total_num_examples))

    for index in indices:
        examples.append(dataset[index])

    dataloaders = []
    for i in range(num_datasets):
        start_index = i * batch_size * num_examples_per_dataset
        end_index = (i + 1) * batch_size * num_examples_per_dataset
        new_dataset = SimpleDataset(examples[start_index:end_index])
        dataloader = misc_utils.get_dataloader(
            dataset=new_dataset,
            batch_size=batch_size,
            # Shuffling inside the dataloader is immaterial when the
            # examples above were already sampled randomly
            random=random)
        dataloaders.append(dataloader)

    return dataloaders
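# Worked example (hypothetical numbers, added for exposition): with
# batch_size=8, num_datasets=4, and num_examples_per_dataset=16, the
# function draws 8 * 4 * 16 = 512 examples in total, and each of the
# 4 returned dataloaders yields 16 batches of 8 examples, so a child
# process never needs to receive the full dataset.
def _example_small_dataloaders(dataset: torch.utils.data.Dataset) -> None:
    dataloaders = prepare_small_dataloaders(
        dataset=dataset,
        random=True,
        batch_size=8,
        num_datasets=4,
        num_examples_per_dataset=16)
    assert len(dataloaders) == 4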
def one_experiment(
        model: torch.nn.Module,
        train_dataset: GlueDataset,
        test_inputs: Dict[str, torch.Tensor],
        batch_size: int,
        random: bool,
        n_gpu: int,
        device: torch.device,
        damp: float,
        scale: float,
        num_samples: int,
) -> List[torch.Tensor]:

    params_filter = [
        n for n, p in model.named_parameters()
        if not p.requires_grad]

    weight_decay_ignores = [
        "bias",
        "LayerNorm.weight"] + [
        n for n, p in model.named_parameters()
        if not p.requires_grad]

    # Make sure each dataloader is re-initialized
    batch_train_data_loader = misc_utils.get_dataloader(
        dataset=train_dataset,
        batch_size=batch_size,
        random=random)

    s_test = compute_s_test(
        n_gpu=n_gpu,
        device=device,
        model=model,
        test_inputs=test_inputs,
        train_data_loaders=[batch_train_data_loader],
        params_filter=params_filter,
        weight_decay=constants.WEIGHT_DECAY,
        weight_decay_ignores=weight_decay_ignores,
        damp=damp,
        scale=scale,
        num_samples=num_samples)

    return [X.cpu() for X in s_test]
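# Illustrative usage sketch (assumed): draw several independent estimates
# of s_test for one evaluation example to gauge the variance of the
# stochastic estimator; the damp/scale values mirror those used in
# `run_full_influence_functions` below.
def _example_one_experiment(model, train_dataset, test_inputs) -> None:
    estimates = [
        one_experiment(
            model=model,
            train_dataset=train_dataset,
            test_inputs=test_inputs,
            batch_size=128,
            random=True,
            n_gpu=1,
            device=torch.device("cuda"),
            damp=5e-3,
            scale=1e4,
            num_samples=1000)
        for _ in range(3)]
    # Each entry is one independent draw of the s_test vectors
    print(f"Collected {len(estimates)} s_test estimates")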
def create_FAISS_index(
        train_task_name: str,
        trained_on_task_name: str,
) -> faiss_utils.FAISSIndex:
    if train_task_name not in ["mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train task {train_task_name}")

    if trained_on_task_name not in ["mnli-2", "hans"]:
        raise ValueError(f"Unrecognized trained-on task {trained_on_task_name}")

    if trained_on_task_name == "mnli-2":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI2_MODEL_PATH)

    if trained_on_task_name == "hans":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.HANS_MODEL_PATH)

    train_dataset, _ = misc_utils.create_datasets(
        task_name=train_task_name,
        tokenizer=tokenizer)

    # BERT-base [CLS] features are 768-dimensional
    faiss_index = faiss_utils.FAISSIndex(768, "Flat")

    model.cuda()
    device = model.device
    train_batch_data_loader = misc_utils.get_dataloader(
        dataset=train_dataset,
        batch_size=128,
        random=False)

    for inputs in tqdm(train_batch_data_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        features = misc_utils.compute_BERT_CLS_feature(model, **inputs)
        features = features.cpu().detach().numpy()
        faiss_index.add(features)

    return faiss_index
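# Illustrative usage sketch (assumed): build the index over the MNLI-2
# training features and query it with a dummy feature vector. The `search`
# call uses the same signature as in `compute_influences_simplified` below.
def _example_faiss_index() -> None:
    faiss_index = create_FAISS_index(
        train_task_name="mnli-2",
        trained_on_task_name="mnli-2")
    # A dummy stand-in for the [CLS] feature of one test example;
    # FAISS expects float32 inputs
    queries = np.random.rand(1, 768).astype(np.float32)
    distances, indices = faiss_index.search(k=10, queries=queries)
    # `indices` holds positions in the training set, sorted by distance
    print(indices)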
def compute_influences_simplified(
        k: int,
        faiss_index: faiss_utils.FAISSIndex,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        train_dataset: torch.utils.data.Dataset,
        use_parallel: bool,
        s_test_damp: float,
        s_test_scale: float,
        s_test_num_samples: int,
        device_ids: Optional[List[int]] = None,
        precomputed_s_test: Optional[List[torch.FloatTensor]] = None,
        faiss_index_use_mean_features_as_query: bool = False,
) -> Dict[int, float]:
    # Note: `faiss_index.search` returns neighbor indices already
    # sorted by distance to the query
    params_filter = [
        n for n, p in model.named_parameters()
        if not p.requires_grad]

    weight_decay_ignores = [
        "bias",
        "LayerNorm.weight"] + [
        n for n, p in model.named_parameters()
        if not p.requires_grad]

    if faiss_index is not None:
        features = misc_utils.compute_BERT_CLS_feature(model, **inputs)
        features = features.cpu().detach().numpy()
        if faiss_index_use_mean_features_as_query:
            # We use the mean embedding as the final query here
            features = features.mean(axis=0, keepdims=True)
        KNN_distances, KNN_indices = faiss_index.search(
            k=k, queries=features)
    else:
        KNN_indices = None

    if not use_parallel:
        model.cuda()
        batch_train_data_loader = misc_utils.get_dataloader(
            train_dataset,
            batch_size=1,
            random=True)
        instance_train_data_loader = misc_utils.get_dataloader(
            train_dataset,
            batch_size=1,
            random=False)
        influences, _, _ = nn_influence_utils.compute_influences(
            n_gpu=1,
            device=torch.device("cuda"),
            batch_train_data_loader=batch_train_data_loader,
            instance_train_data_loader=instance_train_data_loader,
            model=model,
            test_inputs=inputs,
            params_filter=params_filter,
            weight_decay=constants.WEIGHT_DECAY,
            weight_decay_ignores=weight_decay_ignores,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            train_indices_to_include=KNN_indices,
            precomputed_s_test=precomputed_s_test)
    else:
        if device_ids is None:
            raise ValueError("`device_ids` cannot be None")

        influences, _ = parallel.compute_influences_parallel(
            # Avoid clash with the main process
            device_ids=device_ids,
            train_dataset=train_dataset,
            batch_size=1,
            model=model,
            test_inputs=inputs,
            params_filter=params_filter,
            weight_decay=constants.WEIGHT_DECAY,
            weight_decay_ignores=weight_decay_ignores,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            train_indices_to_include=KNN_indices,
            return_s_test=False,
            debug=False)

    return influences
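# Illustrative usage sketch (assumed): score the training examples for a
# single test input, restricted to its k nearest training neighbors, then
# inspect the extremes of the ranking. The sign convention of the scores
# is that of `nn_influence_utils.compute_influences`.
def _example_influences(model, train_dataset, test_inputs, faiss_index) -> None:
    influences = compute_influences_simplified(
        k=1000,
        faiss_index=faiss_index,
        model=model,
        inputs=test_inputs,
        train_dataset=train_dataset,
        use_parallel=False,
        s_test_damp=5e-3,
        s_test_scale=1e4,
        s_test_num_samples=1000)
    # Sort the (train_index, influence_score) pairs by score
    ranked = sorted(influences.items(), key=lambda kv: kv[1])
    print("Lowest-scored:", ranked[:3])
    print("Highest-scored:", ranked[-3:])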
def main(
        mode: str,
        num_examples_to_test: int = 5,
        num_repetitions: int = 4,
) -> List[Dict[str, Any]]:

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    train_dataset, eval_dataset = misc_utils.create_datasets(
        task_name="mnli",
        tokenizer=task_tokenizer)

    eval_instance_data_loader = misc_utils.get_dataloader(
        dataset=eval_dataset,
        batch_size=1,
        random=False)

    output_mode = glue_output_modes["mnli"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)
        return compute_metrics_fn

    # Most of these arguments are placeholders and are not really
    # used at all, so ignore their exact values.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(
            output_dir="./tmp-output",
            per_device_train_batch_size=128,
            per_device_eval_batch_size=128,
            learning_rate=5e-5,
            logging_steps=100),
        data_collator=default_data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn("mnli"),
    )

    task_model.cuda()
    num_examples_tested = 0
    output_collections = []
    for test_index, test_inputs in enumerate(eval_instance_data_loader):
        if num_examples_tested >= num_examples_to_test:
            break

        # Skip the example when we only want correct predictions but the
        # prediction is incorrect, or vice versa
        prediction_is_correct = misc_utils.is_prediction_correct(
            trainer=trainer,
            model=task_model,
            inputs=test_inputs)

        if mode == "only-correct" and prediction_is_correct is False:
            continue

        if mode == "only-incorrect" and prediction_is_correct is True:
            continue

        for k, v in test_inputs.items():
            if isinstance(v, torch.Tensor):
                test_inputs[k] = v.to(torch.device("cuda"))

        # With batch-size 128, 1500 iterations are enough
        for num_samples in range(700, 1300 + 1, 100):  # 7 choices
            for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:  # 8 choices
                for repetition in range(num_repetitions):
                    print(f"Running #{test_index} "
                          f"N={num_samples} "
                          f"B={batch_size} "
                          f"R={repetition} takes ...", end=" ")

                    with Timer() as timer:
                        s_test = one_experiment(
                            model=task_model,
                            train_dataset=train_dataset,
                            test_inputs=test_inputs,
                            batch_size=batch_size,
                            random=True,
                            n_gpu=1,
                            device=torch.device("cuda"),
                            damp=constants.DEFAULT_INFLUENCE_HPARAMS[
                                "mnli"]["mnli"]["damp"],
                            scale=constants.DEFAULT_INFLUENCE_HPARAMS[
                                "mnli"]["mnli"]["scale"],
                            num_samples=num_samples)

                    time_elapsed = timer.elapsed
                    print(f"{time_elapsed:.2f} seconds")

                    outputs = {
                        "test_index": test_index,
                        "num_samples": num_samples,
                        "batch_size": batch_size,
                        "repetition": repetition,
                        "s_test": s_test,
                        "time_elapsed": time_elapsed,
                        "correct": prediction_is_correct,
                    }
                    output_collections.append(outputs)
                    remote_utils.save_and_mirror_scp_to_remote(
                        object_to_save=outputs,
                        file_name=f"stest.{mode}.{num_examples_to_test}."
                                  f"{test_index}.{num_samples}."
                                  f"{batch_size}.{repetition}.pth")

        num_examples_tested += 1

    return output_collections
def imitator_main(mode: str, num_examples_to_test: int) -> List[Dict[str, Any]]:
    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    task_tokenizer, task_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    imitator_tokenizer, imitator_model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_IMITATOR_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(
        task_name="mnli",
        tokenizer=task_tokenizer)

    task_model.cuda()
    imitator_model.cuda()
    if task_model.training or imitator_model.training:
        raise ValueError("One of the models is in training mode")
    print(task_model.device, imitator_model.device)

    # Most of these arguments are placeholders and are not really
    # used at all, so ignore their exact values.
    trainer = transformers.Trainer(
        model=task_model,
        args=TrainingArguments(
            output_dir="./tmp-output",
            per_device_train_batch_size=128,
            per_device_eval_batch_size=128,
            learning_rate=5e-5,
            logging_steps=100),
    )

    eval_instance_data_loader = misc_utils.get_dataloader(
        mnli_eval_dataset,
        batch_size=1,
        data_collator=default_data_collator)

    train_inputs_collections = torch.load(
        constants.MNLI_TRAIN_INPUT_COLLECTIONS_PATH)

    inputs_by_label: Dict[str, List[int]] = defaultdict(list)
    for i in range(len(train_inputs_collections)):
        label = mnli_train_dataset.label_list[
            train_inputs_collections[i]["labels"]]
        inputs_by_label[label].append(i)

    outputs_collections = []
    for i, test_inputs in enumerate(eval_instance_data_loader):
        if mode == "only-correct" and i not in CORRECT_INDICES[:num_examples_to_test]:
            continue
        if mode == "only-incorrect" and i not in INCORRECT_INDICES[:num_examples_to_test]:
            continue

        start_time = time.time()
        for using_ground_truth in [True, False]:
            outputs = run_one_imitator_experiment(
                task_model=task_model,
                imitator_model=imitator_model,
                test_inputs=test_inputs,
                trainer=trainer,
                train_dataset=mnli_train_dataset,
                train_inputs_collections=train_inputs_collections,
                inputs_by_label=inputs_by_label,
                finetune_using_ground_truth_label=using_ground_truth)

            outputs["index"] = i
            outputs_collections.append(outputs)

        end_time = time.time()
        print(f"#{len(outputs_collections)}: "
              f"Elapsed {(end_time - start_time) / 60:.2f} minutes")

    torch.save(
        outputs_collections,
        f"imitator_experiments.{mode}.{num_examples_to_test}.pt")

    return outputs_collections
def run_full_influence_functions(
        mode: str,
        num_examples_to_test: int,
        s_test_num_samples: int = 1000) -> Dict[int, Dict[str, Any]]:

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    tokenizer, model = misc_utils.create_tokenizer_and_model(
        constants.MNLI_MODEL_PATH)

    (mnli_train_dataset,
     mnli_eval_dataset) = misc_utils.create_datasets(
        task_name="mnli",
        tokenizer=tokenizer)

    batch_train_data_loader = misc_utils.get_dataloader(
        mnli_train_dataset,
        batch_size=128,
        random=True)

    instance_train_data_loader = misc_utils.get_dataloader(
        mnli_train_dataset,
        batch_size=1,
        random=False)

    eval_instance_data_loader = misc_utils.get_dataloader(
        dataset=mnli_eval_dataset,
        batch_size=1,
        random=False)

    output_mode = glue_output_modes["mnli"]

    def build_compute_metrics_fn(task_name: str):
        def compute_metrics_fn(p):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            return glue_compute_metrics(task_name, preds, p.label_ids)
        return compute_metrics_fn

    # Most of these arguments are placeholders and are not really
    # used at all, so ignore their exact values.
    trainer = transformers.Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./tmp-output",
            per_device_train_batch_size=128,
            per_device_eval_batch_size=128,
            learning_rate=5e-5,
            logging_steps=100),
        data_collator=default_data_collator,
        train_dataset=mnli_train_dataset,
        eval_dataset=mnli_eval_dataset,
        compute_metrics=build_compute_metrics_fn("mnli"),
    )

    params_filter = [
        n for n, p in model.named_parameters()
        if not p.requires_grad]

    weight_decay_ignores = [
        "bias",
        "LayerNorm.weight"] + [
        n for n, p in model.named_parameters()
        if not p.requires_grad]

    model.cuda()
    num_examples_tested = 0
    outputs_collections = {}
    for test_index, test_inputs in enumerate(eval_instance_data_loader):
        if num_examples_tested >= num_examples_to_test:
            break

        # Skip the example when we only want correct predictions but the
        # prediction is incorrect, or vice versa
        prediction_is_correct = misc_utils.is_prediction_correct(
            trainer=trainer,
            model=model,
            inputs=test_inputs)

        if mode == "only-correct" and prediction_is_correct is False:
            continue

        if mode == "only-incorrect" and prediction_is_correct is True:
            continue

        with Timer() as timer:
            influences, _, s_test = nn_influence_utils.compute_influences(
                n_gpu=1,
                device=torch.device("cuda"),
                batch_train_data_loader=batch_train_data_loader,
                instance_train_data_loader=instance_train_data_loader,
                model=model,
                test_inputs=test_inputs,
                params_filter=params_filter,
                weight_decay=constants.WEIGHT_DECAY,
                weight_decay_ignores=weight_decay_ignores,
                s_test_damp=5e-3,
                s_test_scale=1e4,
                s_test_num_samples=s_test_num_samples,
                train_indices_to_include=None,
                s_test_iterations=1,
                precomputed_s_test=None)

        outputs = {
            "test_index": test_index,
            "influences": influences,
            "s_test": s_test,
            "time": timer.elapsed,
            "correct": prediction_is_correct,
        }
        num_examples_tested += 1
        outputs_collections[test_index] = outputs

        remote_utils.save_and_mirror_scp_to_remote(
            object_to_save=outputs,
            file_name=f"KNN-recall.{mode}.{num_examples_to_test}.{test_index}.pth")

        print(f"Status: #{test_index} | "
              f"{num_examples_tested} / {num_examples_to_test}")

    return outputs_collections
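# Illustrative usage sketch (assumed): run the full (non-kNN) influence
# computation over a few incorrectly predicted examples and inspect the
# per-example wall-clock cost stored under "time".
def _example_run_full_influence_functions() -> None:
    outputs = run_full_influence_functions(
        mode="only-incorrect",
        num_examples_to_test=5,
        s_test_num_samples=1000)
    for test_index, record in outputs.items():
        print(test_index, record["time"], len(record["influences"]))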
def main(
        mode: str,
        train_task_name: str,
        eval_task_name: str,
        num_eval_to_collect: int,
        use_parallel: bool = True,
        kNN_k: Optional[int] = None,
        hans_heuristic: Optional[str] = None,
        trained_on_task_name: Optional[str] = None,
) -> List[Dict[str, Union[int, Dict[int, float]]]]:

    if train_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized train task {train_task_name}")

    if eval_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized eval task {eval_task_name}")

    if trained_on_task_name is None:
        # The task the model was trained on can be different from
        # `train_task_name`, which determines the dataset on which
        # the influence values are computed.
        trained_on_task_name = train_task_name

    if trained_on_task_name not in ["mnli", "mnli-2", "hans"]:
        raise ValueError(f"Unrecognized trained-on task {trained_on_task_name}")

    if mode not in ["only-correct", "only-incorrect"]:
        raise ValueError(f"Unrecognized mode {mode}")

    if kNN_k is None:
        kNN_k = DEFAULT_KNN_K

    # `trained_on_task_name` determines the model to load
    if trained_on_task_name == "mnli":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI_MODEL_PATH)

    if trained_on_task_name == "mnli-2":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.MNLI2_MODEL_PATH)

    if trained_on_task_name == "hans":
        tokenizer, model = misc_utils.create_tokenizer_and_model(
            constants.HANS_MODEL_PATH)

    train_dataset, _ = misc_utils.create_datasets(
        task_name=train_task_name,
        tokenizer=tokenizer)

    _, eval_dataset = misc_utils.create_datasets(
        task_name=eval_task_name,
        tokenizer=tokenizer)

    faiss_index = influence_helpers.load_faiss_index(
        trained_on_task_name=trained_on_task_name,
        train_task_name=train_task_name)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir="./tmp-output",
            per_device_train_batch_size=128,
            per_device_eval_batch_size=128,
            learning_rate=5e-5,
            logging_steps=100),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    if eval_task_name in ["mnli", "mnli-2"]:
        eval_instance_data_loader = misc_utils.get_dataloader(
            dataset=eval_dataset,
            batch_size=1,
            random=False)

    if eval_task_name == "hans":
        if hans_heuristic is None:
            raise ValueError("`hans_heuristic` cannot be None for now")

        hans_helper = HansHelper(
            hans_train_dataset=None,
            hans_eval_dataset=eval_dataset)

        _, eval_instance_data_loader = hans_helper.get_dataset_and_dataloader_of_heuristic(
            mode="eval",
            heuristic=hans_heuristic,
            batch_size=1,
            random=False)

    # Partition the evaluation examples by whether the model
    # predicted them correctly
    correct_input_collections = []
    incorrect_input_collections = []
    for index, test_inputs in enumerate(eval_instance_data_loader):
        logits, labels, step_eval_loss = misc_utils.predict(
            trainer=trainer,
            model=model,
            inputs=test_inputs)
        if logits.argmax(axis=-1).item() != labels.item():
            incorrect_input_collections.append((index, test_inputs))
        else:
            correct_input_collections.append((index, test_inputs))

    # Other settings are not supported as of now
    if mode == "only-incorrect":
        input_collections = incorrect_input_collections
    else:
        input_collections = correct_input_collections

    (s_test_damp,
     s_test_scale,
     s_test_num_samples) = influence_helpers.select_s_test_config(
        trained_on_task_name=trained_on_task_name,
        train_task_name=train_task_name,
        eval_task_name=eval_task_name)

    influences_collections = []
    for index, inputs in input_collections[:num_eval_to_collect]:
        print(f"#{index}")
        influences = influence_helpers.compute_influences_simplified(
            k=kNN_k,
            faiss_index=faiss_index,
            model=model,
            inputs=inputs,
            train_dataset=train_dataset,
            use_parallel=use_parallel,
            s_test_damp=s_test_damp,
            s_test_scale=s_test_scale,
            s_test_num_samples=s_test_num_samples,
            device_ids=[0, 1, 2, 3],
            precomputed_s_test=None)

        influences_collections.append({
            "index": index,
            "influences": influences,
        })

    remote_utils.save_and_mirror_scp_to_remote(
        object_to_save=influences_collections,
        file_name=(f"visualization"
                   f".{mode}.{num_eval_to_collect}"
                   f".{train_task_name}-{eval_task_name}"
                   f"-{hans_heuristic}-{trained_on_task_name}"
                   f".{kNN_k}.{use_parallel}.pth"))

    return influences_collections