Example #1
def main(args):
    pruning_method = args.pruning_method
    ampere_pruning_method = args.ampere_pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    for name, tensor in model.items():
        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "classifier" in name or "qa_output" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "bias" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        else:
            if name.endswith(".weight"):
                pruned_model[
                    name] = MaskedLinear.masked_weights_from_state_dict(
                        model, name, pruning_method, threshold,
                        ampere_pruning_method)
            else:
                assert (MaskedLinear.check_name(name))

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path),
            f"bertarized_{os.path.basename(model_name_or_path)}")

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model,
               os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
Example #2
    def __init__(self, config):
        super().__init__()
        # Prunable projection from intermediate_size back to hidden_size,
        # followed by LayerNorm and dropout.
        self.dense = MaskedLinear(
            config.intermediate_size,
            config.hidden_size,
            pruning_method=config.pruning_method,
            mask_init=config.mask_init,
            mask_scale=config.mask_scale,
        )
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #3
def main(args):
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold
    ampere_pruning_method = args.ampere_pruning_method

    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"),
                    map_location="cuda")

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if name.endswith(".weight"):
            # Recompute the effective (masked) weights and count the entries that survive pruning.
            weights = MaskedLinear.masked_weights_from_state_dict(
                st, name, pruning_method, threshold, ampere_pruning_method)
            mask_ones = (weights != 0).sum().item()
            print(
                name.ljust(60, " "),
                str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "),
                str(mask_ones))

            encoder_count += param.numel()
            remaining_count += mask_ones
        elif MaskedLinear.check_name(name):
            # Auxiliary masking tensors (e.g. mask scores) are not encoder weights: skip them.
            pass
        else:
            # Biases and other unmasked parameters are kept in full.
            encoder_count += param.numel()
            remaining_count += param.numel()

    print("")
    print("Remaining Weights (global) %: ",
          100 * remaining_count / encoder_count)
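For intuition, with the "topK" pruning method the masked weights are obtained by keeping the fraction of weights whose learned mask scores are highest and zeroing the rest. The sketch below illustrates that idea in isolation; it is an assumption about the semantics, not the actual MaskedLinear.masked_weights_from_state_dict implementation, and the helper name is hypothetical.

import torch

def topk_masked_weights(weight, mask_scores, threshold):
    # Keep the `threshold` fraction of entries with the highest mask scores; zero the rest.
    k = max(1, int(threshold * mask_scores.numel()))
    mask = torch.zeros_like(mask_scores)
    topk_idx = torch.topk(mask_scores.flatten(), k).indices
    mask.view(-1)[topk_idx] = 1.0
    return weight * mask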
Example #4
    def __init__(self, config):
        super().__init__()
        # Prunable projection from hidden_size up to intermediate_size,
        # followed by the configured activation function.
        self.dense = MaskedLinear(
            config.hidden_size,
            config.intermediate_size,
            pruning_method=config.pruning_method,
            mask_init=config.mask_init,
            mask_scale=config.mask_scale,
        )
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
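This mirrors the standard BERT intermediate sub-layer, only with the dense projection swapped for a prunable MaskedLinear. The forward pass is not part of the snippet; assuming it follows the usual BERT pattern, it would look roughly like this:

    def forward(self, hidden_states):
        # Project up to the intermediate size, then apply the configured activation.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states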
Example #5
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = MaskedLinear(
            config.hidden_size,
            self.all_head_size,
            pruning_method=config.pruning_method,
            mask_init=config.mask_init,
            mask_scale=config.mask_scale,
        )
        self.key = MaskedLinear(
            config.hidden_size,
            self.all_head_size,
            pruning_method=config.pruning_method,
            mask_init=config.mask_init,
            mask_scale=config.mask_scale,
        )
        self.value = MaskedLinear(
            config.hidden_size,
            self.all_head_size,
            pruning_method=config.pruning_method,
            mask_init=config.mask_init,
            mask_scale=config.mask_scale,
        )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
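Query, key and value are the usual BERT self-attention projections, here made prunable. The companion helper that reshapes the projected tensors into per-head views is not shown in the example; assuming the standard BERT-style implementation, it would look like this:

    def transpose_for_scores(self, x):
        # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)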
def create_masked_linear(in_features, out_features, config, bias=True):
    ret = MaskedLinear(
        in_features=in_features,
        out_features=out_features,
        pruning_method=config.pruning_method,
        mask_init=config.mask_init,
        mask_scale=config.mask_scale,
        mask_block_rows=config.mask_block_rows,
        mask_block_cols=config.mask_block_cols,
        ampere_pruning_method=config.ampere_pruning_method,
        ampere_mask_init=config.ampere_mask_init,
        ampere_mask_scale=config.ampere_mask_scale,
        shuffling_method=config.shuffling_method,
        in_shuffling_group=config.in_shuffling_group,
        out_shuffling_group=config.out_shuffling_group,
    )
    return ret
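This factory bundles every masking hyper-parameter read from the config, so call sites only have to pass the layer dimensions. A hypothetical call inside a module's __init__ (assuming the same config object as above) could be:

    self.dense = create_masked_linear(config.hidden_size, config.intermediate_size, config)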