Example #1
0
def main(args):
    """Strip mask scores from a fine-pruned checkpoint and save the result.

    Loads ``pytorch_model.bin`` from ``args.model_name_or_path``, copies every
    non-prunable tensor through unchanged, materializes the pruned weights for
    each prunable ``.weight`` tensor via
    ``MaskedLinear.masked_weights_from_state_dict``, and writes the resulting
    state dict into ``args.target_model_path`` (created as a copy of the source
    folder when it does not exist yet).

    Args:
        args: Namespace carrying ``pruning_method``, ``ampere_pruning_method``,
            ``threshold``, ``model_name_or_path`` and ``target_model_path``
            (the latter may be ``None`` to derive a ``bertarized_*`` sibling
            folder automatically).
    """
    pruning_method = args.pruning_method
    ampere_pruning_method = args.ampere_pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    for name, tensor in model.items():
        # Embeddings, LayerNorms, pooler, task heads and all biases are not
        # pruned -- copy them through verbatim. (Previously three identical
        # elif branches; merged into one condition.)
        if any(key in name
               for key in ("embeddings", "LayerNorm", "pooler", "classifier",
                           "qa_output")) or "bias" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif name.endswith(".weight"):
            # Prunable weight: apply the learned mask to get the final tensor.
            pruned_model[
                name] = MaskedLinear.masked_weights_from_state_dict(
                    model, name, pruning_method, threshold,
                    ampere_pruning_method)
        else:
            # Anything left must be a mask-related auxiliary tensor, which is
            # intentionally dropped from the exported model. Raise explicitly
            # (a bare `assert` would be stripped under `python -O`).
            if not MaskedLinear.check_name(name):
                raise AssertionError(f"Unexpected tensor in state dict: {name}")

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path),
            f"bertarized_{os.path.basename(model_name_or_path)}")

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model,
               os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
def main(args):
    """Report per-layer and global remaining-weight statistics for a
    fine-pruned checkpoint.

    NOTE(review): this redefines ``main`` from earlier in the file; only the
    last definition is visible at import time -- confirm this is intended.

    Loads ``pytorch_model.bin`` from ``args.serialization_dir`` and, for every
    encoder tensor, prints the fraction of weights that survive pruning, then
    prints the global remaining-weight percentage over the whole encoder.

    Args:
        args: Namespace carrying ``serialization_dir``, ``pruning_method``,
            ``threshold`` and ``ampere_pruning_method``.
    """
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold
    ampere_pruning_method = args.ampere_pruning_method

    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"),
                    map_location="cuda")

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if name.endswith(".weight"):
            weights = MaskedLinear.masked_weights_from_state_dict(
                st, name, pruning_method, threshold, ampere_pruning_method)
            mask_ones = (weights != 0).sum().item()
            print(
                name.ljust(60, " "),
                str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "),
                str(mask_ones))

            remaining_count += mask_ones
            # BUGFIX: weight tensors must count toward the encoder total,
            # otherwise the global percentage divides by the wrong
            # denominator (and can exceed 100%).
            encoder_count += param.numel()
        elif MaskedLinear.check_name(name):
            # Mask bookkeeping tensors are dropped at export time, so they do
            # not count toward either total.
            pass
        else:
            # Non-prunable encoder params (biases, LayerNorm, ...) are kept
            # in full. The original inner `if name.endswith(".weight")` here
            # was unreachable (this is the else of that exact test) and has
            # been removed.
            encoder_count += param.numel()
            remaining_count += param.numel()

    print("")
    print("Remaining Weights (global) %: ",
          100 * remaining_count / encoder_count)