def main(args):
    pruning_method = args.pruning_method
    ampere_pruning_method = args.ampere_pruning_method
    threshold = args.threshold
    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    for name, tensor in model.items():
        # Copy tensors that are never pruned as-is.
        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "classifier" in name or "qa_output" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "bias" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        else:
            # Bake the learned mask into every prunable weight matrix;
            # the mask tensors themselves are dropped from the output state dict.
            if name.endswith(".weight"):
                pruned_model[name] = MaskedLinear.masked_weights_from_state_dict(
                    model, name, pruning_method, threshold, ampere_pruning_method
                )
            else:
                assert MaskedLinear.check_name(name)

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path),
            f"bertarized_{os.path.basename(model_name_or_path)}",
        )

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
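# Minimal sketch (not part of the original file) of an argparse entry point that
# would feed main() above. The flag names mirror the args.* attributes read by
# main(), but the flags themselves and their defaults are assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--pruning_method", type=str, required=True)
    parser.add_argument("--ampere_pruning_method", type=str, default="disabled")
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--target_model_path", type=str, default=None)
    main(parser.parse_args())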
def __init__(self, config):
    super().__init__()
    self.dense = MaskedLinear(
        config.intermediate_size,
        config.hidden_size,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
    )
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
def main(args):
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold
    ampere_pruning_method = args.ampere_pruning_method

    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cuda")

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if name.endswith(".weight"):
            # Apply the stored mask and count the surviving weights.
            weights = MaskedLinear.masked_weights_from_state_dict(
                st, name, pruning_method, threshold, ampere_pruning_method
            )
            mask_ones = (weights != 0).sum().item()
            print(
                name.ljust(60, " "),
                str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "),
                str(mask_ones),
            )
            remaining_count += mask_ones
        elif MaskedLinear.check_name(name):
            # Mask tensors themselves are not counted as model parameters.
            pass
        else:
            encoder_count += param.numel()
            if name.endswith(".weight") and ".".join(name.split(".")[:-1] + ["mask_scores"]) in st:
                pass
            else:
                remaining_count += param.numel()

    print("")
    print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)
def __init__(self, config):
    super().__init__()
    self.dense = MaskedLinear(
        config.hidden_size,
        config.intermediate_size,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
    )
    if isinstance(config.hidden_act, str):
        self.intermediate_act_fn = ACT2FN[config.hidden_act]
    else:
        self.intermediate_act_fn = config.hidden_act
def __init__(self, config):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (config.hidden_size, config.num_attention_heads)
        )
    self.output_attentions = config.output_attentions

    self.num_attention_heads = config.num_attention_heads
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    self.query = MaskedLinear(
        config.hidden_size,
        self.all_head_size,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
    )
    self.key = MaskedLinear(
        config.hidden_size,
        self.all_head_size,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
    )
    self.value = MaskedLinear(
        config.hidden_size,
        self.all_head_size,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
    )

    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
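# Sketch only (not part of the original file): how the masked query/key/value
# projections above are consumed in standard BERT-style scaled dot-product
# attention. `self_attn` is an instance of the self-attention module defined
# above; hidden_states is assumed to have shape (batch, seq_len, hidden_size).
import math

def example_scaled_dot_product_attention(self_attn, hidden_states):
    def split_heads(x):
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_shape = x.size()[:-1] + (self_attn.num_attention_heads, self_attn.attention_head_size)
        return x.view(*new_shape).permute(0, 2, 1, 3)

    q = split_heads(self_attn.query(hidden_states))
    k = split_heads(self_attn.key(hidden_states))
    v = split_heads(self_attn.value(hidden_states))

    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self_attn.attention_head_size)
    probs = self_attn.dropout(torch.softmax(scores, dim=-1))
    return torch.matmul(probs, v)  # (batch, heads, seq, head_size)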
def create_masked_linear(in_features, out_features, config, bias=True):
    ret = MaskedLinear(
        in_features=in_features,
        out_features=out_features,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
        mask_block_rows=config.mask_block_rows,
        mask_block_cols=config.mask_block_cols,
        ampere_pruning_method=config.ampere_pruning_method,
        ampere_mask_init=config.ampere_mask_init,
        ampere_mask_scale=config.ampere_mask_scale,
        shuffling_method=config.shuffling_method,
        in_shuffling_group=config.in_shuffling_group,
        out_shuffling_group=config.out_shuffling_group,
    )
    return ret
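# Hypothetical usage sketch (not part of the original file): the config fields
# below are only the ones create_masked_linear() actually reads; the values are
# made-up placeholders, not defaults taken from the repo.
from types import SimpleNamespace

example_config = SimpleNamespace(
    pruning_method="topK",
    mask_init="constant",
    mask_scale=0.0,
    mask_block_rows=32,
    mask_block_cols=32,
    ampere_pruning_method="disabled",
    ampere_mask_init="constant",
    ampere_mask_scale=0.0,
    shuffling_method="disabled",
    in_shuffling_group=4,
    out_shuffling_group=4,
)
# Builds a prunable replacement for nn.Linear(768, 3072).
intermediate_dense = create_masked_linear(768, 3072, example_config)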