def mask_global(args, model, eval_dataloader):
    """This method shows how to mask weights globally by magnitude (set the weights
    with the smallest L1 magnitude to zero) to test the effect on the network,
    iterating until the score drops below args.masking_threshold times the original score."""
    preds, labels = evaluate(args, model, eval_dataloader)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info(
        "Pruning: original score: %f, threshold: %f",
        original_score,
        original_score * args.masking_threshold,
    )

    logger.info("Finding global magnitude mask")
    model = model.to(torch.device("cpu"))
    new_mask = prune_model(model, amount=None, mode=prune.Identity)
    current_score = original_score
    total_steps = 10
    while total_steps != 0 and current_score >= original_score * args.masking_threshold:
        print("Current Score:", current_score)
        print("Original Score:", original_score)
        total_steps -= 1

        # Save the current mask before pruning further
        mask = {k: v.clone() for k, v in new_mask.items()}
        print(
            "Total pruned amount:",
            sum((1 - v).sum() for v in mask.values()) / sum(v.numel() for v in mask.values()),
        )

        model = model.to(torch.device("cpu"))
        new_mask = prune_model(model, amount=args.masking_amount, mode=prune.L1Unstructured)
        model = model.to(args.device)

        preds, labels = evaluate(args, model, eval_dataloader)
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info("Masking: current score: %f", current_score)

    total_masked = sum((1 - v).sum() for v in mask.values())
    total_elements = sum(v.numel() for v in mask.values())
    logger.info(f"Pruned {total_masked / total_elements * 100:.2f}% of weights")
    torch.save(mask, os.path.join(args.output_dir, "magnitude_mask.p"))
    return mask, float(total_masked / total_elements)
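# NOTE: `prune_model` is a helper defined elsewhere in this repository; its real
# implementation is not shown here. Purely as an illustration of the idea, the
# sketch below builds masks with torch.nn.utils.prune (using the module-level
# `torch` and `prune` imports this file already relies on): prune.Identity
# attaches all-ones masks as a no-op baseline, while prune.L1Unstructured zeroes
# the `amount` fraction of smallest-magnitude weights globally. The name
# `prune_model_sketch` and the choice to prune only nn.Linear weights are
# assumptions, not the actual code.
def prune_model_sketch(model, amount, mode):
    # Collect every Linear weight in the model as a pruning target.
    targets = [(m, "weight") for m in model.modules() if isinstance(m, torch.nn.Linear)]
    if mode is prune.Identity:
        # No-op pruning: attaches an all-ones weight_mask buffer to each target.
        for module, name in targets:
            prune.identity(module, name)
    else:
        # Global magnitude pruning: rank all target weights together and zero
        # the `amount` fraction with the smallest absolute value.
        prune.global_unstructured(targets, pruning_method=mode, amount=amount)
    # Return the accumulated binary masks keyed by module name; successive calls
    # compose with the masks left by earlier iterations.
    return {
        f"{name}.weight": module.weight_mask.detach().clone()
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
    }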
def prune_heads(args, model, eval_dataloader, head_mask):
    """This method shows how to prune heads (remove head weights entirely) based on
    the head importance scores, as described in Michel et al.
    (http://arxiv.org/abs/1905.10650)"""
    # Try pruning and measure test-time speedup.
    # Pruning is like masking, but here we actually remove the masked weights.
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_masking = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
    heads_to_prune = {}
    for layer in range(len(head_mask)):
        heads_to_mask = [h[0] for h in (1 - head_mask[layer].long()).nonzero().tolist()]
        heads_to_prune[layer] = heads_to_mask
    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()

    logger.info(f"{heads_to_prune}")
    model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_pruning = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percent)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    logger.info("Pruning: speed ratio (original timing / new timing): %f percent", original_time / new_time * 100)
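# Illustration only: a minimal, self-contained example of how a binary head mask
# turns into the {layer_index: [head_indices]} dict consumed by the Transformers
# `model.prune_heads()` call above. The 4-layer, 4-head mask and this helper are
# hypothetical and not part of the original script.
def _example_heads_to_prune():
    example_head_mask = torch.ones(4, 4)
    example_head_mask[1, 2] = 0.0  # drop head 2 of layer 1
    example_head_mask[3, 0] = 0.0  # drop head 0 of layer 3
    heads_to_prune = {
        layer: (1 - example_head_mask[layer].long()).nonzero().view(-1).tolist()
        for layer in range(example_head_mask.size(0))
    }
    # -> {0: [], 1: [2], 2: [], 3: [0]}; layers mapped to an empty list are left untouched.
    return heads_to_prune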
def evaluate(args, task_name, data_dir, model, tokenizer, head_mask=None):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if task_name == "mnli" else (task_name,)

    results = {}
    for eval_task in eval_task_names:
        eval_dataset = load_and_cache_examples(args, data_dir, eval_task, tokenizer, evaluate=True)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        if head_mask is not None:
            head_mask = torch.tensor(head_mask, device=args.device)
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs, head_mask=head_mask)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps

        processor = processors[eval_task]()
        label_list = processor.get_labels()
        if eval_task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK: label indices are swapped in the RoBERTa pretrained model
            label_list[1], label_list[2] = label_list[2], label_list[1]
        label_map = {int(i): label for i, label in enumerate(label_list)}

        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
            pred_outputs = [label_map[p] for p in preds]
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
            pred_outputs = [float(p) for p in preds]

        result = compute_metrics(eval_task, preds, out_label_ids)
        for k, v in result.items():
            if "mnli" in eval_task:
                k = f"{eval_task}_{k}"
            results[k] = v

        if "predictions" not in results:
            results["predictions"] = {}
        if "mnli" in eval_task:
            results["predictions"][eval_task] = pred_outputs
        else:
            results["predictions"] = pred_outputs

    return results
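# Illustration only: the `head_mask` accepted above is a float array/tensor of
# shape (num_hidden_layers, num_attention_heads); 1.0 keeps a head, 0.0 zeroes
# its output in the corresponding attention layer. The 12x12 shape below assumes
# a BERT-base-style model; this helper is not part of the original script.
def _example_head_mask():
    head_mask = np.ones((12, 12), dtype=np.float32)
    head_mask[0, :6] = 0.0  # silence the first six heads of layer 0
    # results = evaluate(args, task_name, data_dir, model, tokenizer, head_mask=head_mask)
    return head_mask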
def mask_heads_mlps(args, model, eval_dataloader):
    """This method shows how to mask attention heads and MLP layers (set them to zero)
    to test the effect on the network, based on importance scores, as described in
    Michel et al. (http://arxiv.org/abs/1905.10650)"""
    head_importance, mlp_importance, preds, labels = compute_heads_mlps_importance(args, model, eval_dataloader)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info(
        "Pruning: original score: %f, threshold: %f",
        original_score,
        original_score * args.masking_threshold,
    )

    new_mlp_mask = torch.ones_like(mlp_importance)
    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    logger.info("Finding head and MLP masks")
    best_score = original_score
    current_score = best_score
    iteration = 0
    while current_score >= original_score * args.masking_threshold:
        best_score = current_score

        # New head mask
        head_mask = new_head_mask.clone()  # save current head mask
        if args.save_mask_all_iterations:
            np.save(os.path.join(args.output_dir, f"head_mask_{iteration}.npy"), head_mask.detach().cpu().numpy())
            np.save(
                os.path.join(args.output_dir, f"head_importance_{iteration}.npy"),
                head_importance.detach().cpu().numpy(),
            )

        # Heads from least important to most - keep only not-masked heads
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        # Mask heads
        selected_heads_to_mask = []
        for head in current_heads_to_mask:
            if len(selected_heads_to_mask) == num_to_mask or head_importance.view(-1)[head.item()] == float("Inf"):
                break
            layer_idx = head.item() // model.bert.config.num_attention_heads
            head_idx = head.item() % model.bert.config.num_attention_heads
            new_head_mask[layer_idx][head_idx] = 0.0
            selected_heads_to_mask.append(head.item())

        if not selected_heads_to_mask:
            break
        logger.info("Heads to mask: %s", str(selected_heads_to_mask))
        # new_head_mask = new_head_mask.view_as(head_mask)
        print_2d_tensor(new_head_mask)

        # New MLP mask
        mlp_mask = new_mlp_mask.clone()  # save current mlp mask
        if args.save_mask_all_iterations:
            np.save(os.path.join(args.output_dir, f"mlp_mask_{iteration}.npy"), mlp_mask.detach().cpu().numpy())
            np.save(
                os.path.join(args.output_dir, f"mlp_importance_{iteration}.npy"),
                mlp_importance.detach().cpu().numpy(),
            )
        iteration += 1

        # MLPs from least important to most - keep only not-masked MLPs
        mlp_importance[mlp_mask == 0.0] = float("Inf")
        current_mlps_to_mask = mlp_importance.sort()[1]
        mlp_to_mask = current_mlps_to_mask[0]
        if mlp_importance[mlp_to_mask] == float("Inf"):
            break
        new_mlp_mask[mlp_to_mask] = 0.0
        logger.info("MLP layer to mask: %s", str(current_mlps_to_mask[0]))
        print_1d_tensor(new_mlp_mask)

        # Compute metric and head/MLP importance again
        head_importance, mlp_importance, preds, labels = compute_heads_mlps_importance(
            args, model, eval_dataloader, head_mask=new_head_mask, mlp_mask=new_mlp_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info(
            "MLP Masking: current score: %f, remaining mlps %d (%.1f percent)",
            current_score,
            new_mlp_mask.sum(),
            new_mlp_mask.sum() / new_mlp_mask.numel() * 100,
        )
        logger.info(
            "Head Masking: current score: %f, remaining heads %d (%.1f percent)",
            current_score,
            new_head_mask.sum(),
            new_head_mask.sum() / new_head_mask.numel() * 100,
        )

    logger.info("Finding additional head masks")
    current_score = best_score
    new_head_mask = head_mask
    # Only heads
    while current_score >= original_score * args.masking_threshold:
        # New head mask
        head_mask = new_head_mask.clone()  # save current head mask
        if args.save_mask_all_iterations:
            np.save(os.path.join(args.output_dir, f"head_mask_{iteration}.npy"), head_mask.detach().cpu().numpy())
            np.save(
                os.path.join(args.output_dir, f"head_importance_{iteration}.npy"),
                head_importance.detach().cpu().numpy(),
            )
        iteration += 1
        best_score = current_score

        # Heads from least important to most - keep only not-masked heads
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]

        # Mask heads
        selected_heads_to_mask = []
        for head in current_heads_to_mask:
            if (
                len(selected_heads_to_mask) == num_to_mask // 2
                or head_importance.view(-1)[head.item()] == float("Inf")
            ):
                break
            layer_idx = head.item() // model.bert.config.num_attention_heads
            head_idx = head.item() % model.bert.config.num_attention_heads
            new_head_mask[layer_idx][head_idx] = 0.0
            selected_heads_to_mask.append(head.item())

        if not selected_heads_to_mask:
            break
        logger.info("Heads to mask: %s", str(selected_heads_to_mask))
        print_2d_tensor(new_head_mask)

        # Compute metric and head/MLP importance again
        head_importance, mlp_importance, preds, labels = compute_heads_mlps_importance(
            args, model, eval_dataloader, head_mask=new_head_mask, mlp_mask=mlp_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info(
            "Head Masking: current score: %f, remaining heads %d (%.1f percent)",
            current_score,
            new_head_mask.sum(),
            new_head_mask.sum() / new_head_mask.numel() * 100,
        )

    logger.info("Finding additional MLP masks")
    current_score = best_score
    new_mlp_mask = mlp_mask
    while current_score >= original_score * args.masking_threshold:
        best_score = current_score

        # New MLP mask
        mlp_mask = new_mlp_mask.clone()  # save current mlp mask
        if args.save_mask_all_iterations:
            np.save(os.path.join(args.output_dir, f"mlp_mask_{iteration}.npy"), mlp_mask.detach().cpu().numpy())
            np.save(
                os.path.join(args.output_dir, f"mlp_importance_{iteration}.npy"),
                mlp_importance.detach().cpu().numpy(),
            )
        iteration += 1

        # MLPs from least important to most - keep only not-masked MLPs
        mlp_importance[mlp_mask == 0.0] = float("Inf")
        current_mlps_to_mask = mlp_importance.sort()[1]
        mlp_to_mask = current_mlps_to_mask[0]
        if mlp_importance[mlp_to_mask] == float("Inf"):
            break
        new_mlp_mask[mlp_to_mask] = 0.0
        logger.info("MLP layer to mask: %s", str(current_mlps_to_mask[0]))
        print_1d_tensor(new_mlp_mask)

        # Compute metric and head/MLP importance again
        head_importance, mlp_importance, preds, labels = compute_heads_mlps_importance(
            args, model, eval_dataloader, head_mask=head_mask, mlp_mask=new_mlp_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
        logger.info(
            "MLP Masking: current score: %f, remaining mlps %d (%.1f percent)",
            current_score,
            new_mlp_mask.sum(),
            new_mlp_mask.sum() / new_mlp_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    logger.info("Final mlp mask")
    print_1d_tensor(mlp_mask)

    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
    np.save(os.path.join(args.output_dir, "mlp_mask.npy"), mlp_mask.detach().cpu().numpy())

    return head_mask, mlp_mask
def evaluate(args, model, tokenizer, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir) if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        file_name = "eval_results.txt"
        if eval_task == "mnli-mm":
            file_name = "eval_results_mm.txt"
        output_eval_file = os.path.join(eval_output_dir, prefix, file_name)
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results