def main(args):
    """Write a "bertarized" checkpoint whose weights have the pruning masks applied.

    Loads ``pytorch_model.bin`` from ``args.model_name_or_path`` and builds a new
    state dict:

    * tensors whose name contains "embeddings", "LayerNorm", "pooler",
      "classifier", "qa_output" or "bias" are copied unchanged;
    * every other ``*.weight`` tensor is replaced by the result of
      ``MaskedLinear.masked_weights_from_state_dict``;
    * any remaining tensor must be recognised by ``MaskedLinear.check_name`` and
      is left out of the output.

    The new state dict is saved as ``pytorch_model.bin`` inside
    ``args.target_model_path``. That path defaults to a ``bertarized_<name>``
    sibling of the source directory. If the target directory does not exist yet,
    it is first created as a copy of the source directory.

    Args:
        args: namespace providing ``pruning_method``, ``ampere_pruning_method``,
            ``threshold``, ``model_name_or_path`` and ``target_model_path``.

    Raises:
        ValueError: if the checkpoint contains a tensor that is neither copied,
            nor a ``*.weight``, nor recognised by ``MaskedLinear.check_name``.
    """
    pruning_method = args.pruning_method
    ampere_pruning_method = args.ampere_pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    # Tensors whose name contains any of these substrings are kept as-is.
    copied_markers = ("embeddings", "LayerNorm", "pooler", "classifier", "qa_output", "bias")

    for name, tensor in model.items():
        if any(marker in name for marker in copied_markers):
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif name.endswith(".weight"):
            pruned_model[name] = MaskedLinear.masked_weights_from_state_dict(
                model, name, pruning_method, threshold, ampere_pruning_method
            )
        elif not MaskedLinear.check_name(name):
            # Use an explicit raise rather than ``assert``: it still fires under
            # ``python -O``. Without it, an unexpected tensor would be silently
            # dropped from the output checkpoint.
            raise ValueError(
                f"Unexpected parameter {name!r} in checkpoint: it is not a weight "
                "and is not recognised by MaskedLinear.check_name"
            )
        # Otherwise: an auxiliary pruning tensor (presumably mask scores), which is
        # deliberately not written to the bertarized checkpoint.

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path),
            f"bertarized_{os.path.basename(model_name_or_path)}",
        )

    if not os.path.isdir(target_model_path):
        # Copy the whole source directory (config, tokenizer files, ...); its
        # pytorch_model.bin is overwritten just below.
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
def main(args):
    """Print how many encoder parameters survive pruning, per tensor and globally.

    Loads ``pytorch_model.bin`` from ``args.serialization_dir`` and only
    considers tensors whose name contains ``"encoder"``:

    * ``*.weight`` tensors are masked with
      ``MaskedLinear.masked_weights_from_state_dict``. Their non-zero entries
      count as remaining and their full size counts towards the encoder total.
      A per-tensor line with the remaining percentage is printed for each.
    * tensors recognised by ``MaskedLinear.check_name`` (auxiliary pruning
      tensors) are excluded from both counts.
    * every other encoder tensor is counted as fully remaining.

    Finally, the global remaining percentage (remaining / total) is printed.

    Args:
        args: namespace providing ``serialization_dir``, ``pruning_method``,
            ``threshold`` and ``ampere_pruning_method``.

    Raises:
        ValueError: if no countable encoder parameters are found, in which case
            the global percentage is undefined.
    """
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold
    ampere_pruning_method = args.ampere_pruning_method

    # Load onto the GPU when one is present, but fall back to CPU. A hard-coded
    # "cuda" map_location raises on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location=device)

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if name.endswith(".weight"):
            weights = MaskedLinear.masked_weights_from_state_dict(
                st, name, pruning_method, threshold, ampere_pruning_method
            )
            mask_ones = (weights != 0).sum().item()
            print(
                name.ljust(60, " "),
                str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "),
                str(mask_ones),
            )
            remaining_count += mask_ones
            # The full tensor size belongs in the denominator. Otherwise the
            # global ratio compares surviving weights against non-weight
            # tensors only and overshoots 100%.
            encoder_count += param.numel()
        elif MaskedLinear.check_name(name):
            # Auxiliary pruning tensor, not a model parameter: not counted.
            continue
        else:
            # Unmasked encoder parameters (e.g. biases) are kept in full.
            encoder_count += param.numel()
            remaining_count += param.numel()

    print("")
    if encoder_count == 0:
        raise ValueError(f"No encoder parameters found in {serialization_dir}")
    print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)