Example #1
 def __init__(self, sparse_args, device, cache_dir, logit_names,
              teacher_constructor):
     # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
     # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
     self.sparse_args = sparse_args
     self.patcher_context = PatcherContext()
     self.teacher_constructor = teacher_constructor
     self.teacher = self.create_teacher(device, cache_dir)
     self.logit_names = logit_names
Example #2
 def __init__(self, sparse_args, device, cache_dir, model_name_or_path, logit_names, teacher_constructor):
     # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
     # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
     self.sparse_args = sparse_args
     self.patcher_context = PatcherContext()
     self.teacher_constructor = teacher_constructor
     self.device = device
     self.cache_dir = cache_dir
     self.teacher = None
     self.layer_head_mask = self.create_head_rewind_info(device, cache_dir)
     self.logit_names = logit_names
     self.model_name_or_path = model_name_or_path
     config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
     self.model_structure = struct_from_config(config.__class__)
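
A minimal construction sketch for this second variant (hedged: the sparse_args object and the task wiring are assumptions; only the signature above is taken from the example):

    from transformers import AutoModelForSequenceClassification

    coordinator = ModelPatchingCoordinator(
        sparse_args=sparse_args,          # pruning/distillation hyper-parameters, defined elsewhere
        device="cuda",
        cache_dir="checkpoints",
        model_name_or_path="bert-base-uncased",
        logit_names=["logits"],           # ["start_logits", "end_logits"] for QA tasks
        teacher_constructor=AutoModelForSequenceClassification,
    )
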
Example #3
    def test_patch_module_tied_attention(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

        parameters = LinearPruningParameters(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
        )

        context = PatcherContext()

        p_attention = JointPruningModulePatcher(context, parameters, "attention")
        p_dense = LinearPruningModulePatcher(context, parameters)

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)

        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 48})
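
The expected totals follow from the geometry of bert-base-uncased; a quick sanity check of the arithmetic (the per-layer module names mirror the module_patchers dict above):

    layers = 12                # bert-base-uncased encoder layers
    linears_per_layer = 6      # query, key, value, att_dense, interm_dense, output_dense
    assert layers * linears_per_layer == 72      # patched modules; also one "ampere_mask" each

    # The joint patcher ties query/key/value to a single block mask per layer,
    # while each of the three dense linears keeps its own:
    assert layers * (1 + 3) == 48                # "mask" entries
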
Example #4
    def test_patch_module_ampere(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

        parameters = LinearPruningArgs(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
            min_elements=0.005,
        )

        context = PatcherContext()

        p = LinearPruningModulePatcher(context, parameters, self.MODEL_STRUCTURE)

        module_patchers = dict(query=p, key=p, value=p, att_dense=p, interm_dense=p, output_dense=p)

        patcher = LinearModelPatcher(module_patchers, self.MODEL_STRUCTURE)
        patcher.patch(model)

        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 72})
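
In contrast with the tied-attention test above, here every linear gets the same independent LinearPruningModulePatcher, so the block-mask count matches the module count:

    layers = 12
    assert layers * 6 == 72    # one "mask" and one "ampere_mask" per patched linear
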
Example #5
    def test_patch_tiedattention_line_pruning(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

        parameters_attention = LinearPruningArgs(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
            min_elements=0.005,
        )

        parameters_dense = LinearPruningArgs(
            method="topK", submethod="1d", ampere_method="annealing", block_rows=32, block_cols=32, min_elements=0.005
        )

        context = PatcherContext()

        p_attention = JointPruningModulePatcher(context, parameters_attention, self.MODEL_STRUCTURE, suffix=".attention")
        p_dense = ChannelPruningModulePatcher(context, parameters_dense, self.MODEL_STRUCTURE, suffix="dense")

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = LinearModelPatcher(module_patchers, self.MODEL_STRUCTURE)
        patcher.patch(model)

        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        # Debug output: per-key counts, then the context modules themselves
        for k, v in key_sizes.items():
            print(k, v)

        for k, v in context.context_modules.items():
            print(k, v)
        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 12, "mask_1d": 48})
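
Reading the final assertion: the joint patcher again ties query/key/value to one block mask per layer (12 in total), while the 1d ChannelPruningModulePatcher registers "mask_1d" entries for the dense linears, four per layer (exactly how the channel masks are shared across att_dense, interm_dense and output_dense is inferred from the totals, not shown in the test):

    layers = 12
    assert layers * 1 == 12    # joint QKV block masks
    assert layers * 4 == 48    # 1d channel masks for the dense linears
    assert layers * 6 == 72    # one ampere mask per patched linear
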
Example #6
class ModelPatchingCoordinator:
    MODEL_STRUCTURE = BertStructure

    def __init__(self, sparse_args, device, cache_dir, logit_names,
                 teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
        # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.teacher = self.create_teacher(device, cache_dir)
        self.logit_names = logit_names

    def parse_pruning_method(self, method):
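        # e.g. "topK:1d" -> ("topK", "1d"); a bare "topK" -> ("topK", "default")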
        parts = method.split(":")
        if len(parts) == 2:
            return parts
        elif len(parts) == 1:
            return parts[0], "default"
        else:
            raise RuntimeError("Could not parse pruning method")

    def log(self):
        logs = {}
        for k, v in self.patcher_context.enumerate_context_data():
            logs[k] = v

        return logs

    def create_teacher(self, device, cache_dir):
        sparse_args = self.sparse_args

        if sparse_args.distil_teacher_name_or_path is not None:
            assert sparse_args.distil_alpha_ce > 0.0
            assert sparse_args.distil_alpha_ce + sparse_args.distil_alpha_teacher > 0.0

            model_config = AutoConfig.from_pretrained(
                sparse_args.distil_teacher_name_or_path, cache_dir=cache_dir)

            teacher = self.teacher_constructor.from_pretrained(
                sparse_args.distil_teacher_name_or_path,
                from_tf=bool(
                    ".ckpt" in sparse_args.distil_teacher_name_or_path),
                config=model_config,
                cache_dir=cache_dir,
            )
            print(teacher)
            teacher.to(device)
        else:
            teacher = None

        return teacher

    def schedule_threshold(
        self,
        step: int = -1,
        total_step: int = -1,
        warmup_steps: int = -1,
        training: bool = False,
    ):
        sparse_args = self.sparse_args

        initial_threshold = sparse_args.initial_threshold
        final_threshold = sparse_args.final_threshold
        initial_warmup = sparse_args.initial_warmup
        final_warmup = sparse_args.final_warmup
        final_lambda = sparse_args.regularization_final_lambda
        initial_ampere_temperature = sparse_args.initial_ampere_temperature
        final_ampere_temperature = sparse_args.final_ampere_temperature

        if training:
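            # Cubic sparsity schedule: hold the initial values through the initial
            # warmup, hold the final values past the final warmup, and interpolate
            # threshold and ampere temperature with a cubic (mul_coeff ** 3)
            # factor in between.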
            if step <= initial_warmup * warmup_steps:
                threshold = initial_threshold
                ampere_temperature = initial_ampere_temperature
            elif step > (total_step - final_warmup * warmup_steps):
                threshold = final_threshold
                ampere_temperature = final_ampere_temperature
            else:
                spars_warmup_steps = initial_warmup * warmup_steps
                spars_schedu_steps = (final_warmup +
                                      initial_warmup) * warmup_steps
                mul_coeff = 1 - (step - spars_warmup_steps) / (
                    total_step - spars_schedu_steps)
                threshold = final_threshold + (
                    initial_threshold - final_threshold) * (mul_coeff**3)
                ampere_temperature = final_ampere_temperature + (
                    initial_ampere_temperature -
                    final_ampere_temperature) * (mul_coeff**3)
        else:
            threshold = final_threshold
            ampere_temperature = final_ampere_temperature

        regu_lambda = final_lambda * threshold / final_threshold

        context_data = dict(threshold=threshold,
                            regu_lambda=regu_lambda,
                            ampere_temperature=ampere_temperature)

        def interp(a, b, interpf):
            return a * interpf + (1.0 - interpf) * b

        if hasattr(sparse_args,
                   "layer_norm_patch") and sparse_args.layer_norm_patch:
            if training:
                interpf = 0.0
                layer_norm_patch_steps = sparse_args.layer_norm_patch_steps
                if step < layer_norm_patch_steps:
                    interpf = 1.0 - (step / layer_norm_patch_steps)

                delta = interp(sparse_args.layer_norm_patch_start_delta, 1.0,
                               interpf)
                mix = interpf

                context_data["layernorm_to_nonorm_delta"] = delta
                context_data["layernorm_to_nonorm_mix"] = mix
            else:
                context_data["layernorm_to_nonorm_delta"] = 1.0
                context_data["layernorm_to_nonorm_mix"] = 0.0

        if hasattr(sparse_args, "gelu_patch") and sparse_args.gelu_patch:
            if training:
                interpf = 0.0
                gelu_patch_steps = sparse_args.gelu_patch_steps
                if step < gelu_patch_steps:
                    interpf = 1.0 - (step / gelu_patch_steps)

                context_data["gelu_to_relu_mix"] = interpf
            else:
                context_data["gelu_to_relu_mix"] = 0.0

        self.patcher_context.set_context_data_dict(context_data)

    def regularization_loss(self, model: nn.Module):
        # Return regularization, lambda, and information on the network sparsity
        mode = self.sparse_args.regularization

        info = {}

        regul_modes = ["l1", "l0"]
        if mode in regul_modes:
            threshold = self.patcher_context.get_context_data("threshold")

        for name, module in model.named_modules():
            module_regu = 0
            module_nnz_info = {"nnz": 0, "numel": 0}
            nummod = 1
            if mode not in regul_modes:
                if isinstance(module, nn.Linear):
                    weight = module.weight
                    module_nnz_info["nnz"] = (weight != 0).sum()
                    module_nnz_info["numel"] = weight.numel()
                else:
                    continue
            elif isinstance(module, GenericLinearPruningContextModule):
                module_regu = module.regularization(mode)
            elif isinstance(module, MaskedLinear):
                module_nnz_info = module.get_sparsity_info()
                nummod = 0
            else:
                continue
            # TEMPORARY : use model info to perform this dispatch
            if not hasattr(self.sparse_args, "attention_output_with_dense"
                           ) or self.sparse_args.attention_output_with_dense:
                layer_names = ["key", "query", "value"]
                key = "dense"
                for ln in layer_names:
                    if ln in name:
                        key = "attention"
            else:
                key = "attention" if "attention" in name else "dense"

            if key not in info:
                info[key] = defaultdict(float)

            key_info = info[key]
            key_info["regu"] += module_regu
            key_info["nummod"] += nummod

            for k, v in module_nnz_info.items():
                key_info[k] += float(v)

        if mode not in regul_modes:
            lamb = 0
            lambdas = dict(attention=0, dense=0)
        else:
            lamb = self.patcher_context.get_context_data("regu_lambda")

            lambdas = dict(attention=self.sparse_args.attention_lambda * 0.5,
                           dense=self.sparse_args.dense_lambda * 0.5)

        info["total"] = defaultdict(float)

        for key, value in info.items():
            if key == "total":
                continue
            for k, v in value.items():
                if k == "numel" or "nnz" in k:
                    info["total"][k] += v

        for key, value in info.items():
            if value["numel"] != 0:
                # No patching (no pruning) -> no information on nnz -> dense linear layers
                value["nnz_perc"] = value["nnz"] / value["numel"]
            else:
                value["nnz_perc"] = 1.0
            for k in "nnz", "numel":
                if k in value:
                    del value[k]
            if key == "total":
                continue
            if value["nummod"] != 0:
                value["regu_loss"] = value["regu"] * lambdas[key] / value[
                    "nummod"]
                info["total"]["regu_loss"] += value["regu_loss"]
            for k in "regu", "nummod":
                if k in value:
                    del value[k]

        return info["total"]["regu_loss"], lamb, info

    def distil_loss_combine(self, ce_loss, model_inputs, model_outputs):
        sparse_args = self.sparse_args
        teacher = self.teacher

        if teacher is None:
            return ce_loss, 0.0

        temperature = sparse_args.distil_temperature

        with torch.no_grad():
            teacher_outputs = teacher(
                input_ids=model_inputs["input_ids"],
                token_type_ids=model_inputs["token_type_ids"],
                attention_mask=model_inputs["attention_mask"],
            )

        loss_logits = 0
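        # Soft-target distillation: KL divergence between the temperature-softened
        # teacher and student distributions, scaled by temperature ** 2 so gradient
        # magnitudes stay comparable across temperatures (Hinton et al., 2015).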
        for logit_name in self.logit_names:
            logits_stu = model_outputs[logit_name]
            logits_tea = teacher_outputs[logit_name]

            loss_logits_part = nn_functional.kl_div(
                input=nn_functional.log_softmax(logits_stu / temperature,
                                                dim=-1),
                target=nn_functional.softmax(logits_tea / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature**2)

            loss_logits += loss_logits_part

        loss_logits /= len(self.logit_names)

        loss = sparse_args.distil_alpha_teacher * loss_logits + sparse_args.distil_alpha_ce * ce_loss

        return loss, loss_logits

    def create_optimizer_groups(self, model, args, sparse_args):
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = [
            "bias", "LayerNorm.weight", "NoNorm.weight", "LayerNorm.bias",
            "NoNorm.bias"
        ]

        mask_params = []
        no_decay_params = []
        decay_params = []

        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "mask_score" in n:
                mask_params.append(p)
            elif any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_grouped_parameters = [
            {
                "params": mask_params,
                "lr": sparse_args.mask_scores_learning_rate,
            },
            {
                "params": no_decay_params,
                "lr": args.learning_rate,
                "weight_decay": 0.0,
            },
            {
                "params": decay_params,
                "lr": args.learning_rate,
                "weight_decay": args.weight_decay,
            },
        ]

        return optimizer_grouped_parameters

    def patch_model(self, model, trial=None):
        layers_count = model.config.num_hidden_layers
        sparse_args = self.sparse_args
        attention_pruning_method_parts = self.parse_pruning_method(
            sparse_args.attention_pruning_method)

        if hasattr(sparse_args, "bias_mask"):
            bias_mask = sparse_args.bias_mask
        else:
            bias_mask = False

        if hasattr(sparse_args, "linear_min_parameters"):
            linear_min_parameters = sparse_args.linear_min_parameters
        else:
            linear_min_parameters = 0.005

        patcher_context = self.patcher_context

        if attention_pruning_method_parts[0] != "disabled" or sparse_args.ampere_pruning_method != "disabled":
            args_attention = LinearPruningArgs(
                method=attention_pruning_method_parts[0],
                submethod=attention_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.attention_block_rows,
                block_cols=sparse_args.attention_block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )

            args_attention_t = LinearPruningArgs(
                method=attention_pruning_method_parts[0],
                submethod=attention_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.attention_block_cols,
                block_cols=sparse_args.attention_block_rows,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )
            if args_attention.submethod == "joint":
                p_attention = JointPruningModulePatcher(patcher_context,
                                                        args_attention,
                                                        suffix=".attention")
                p_attention_t = JointPruningModulePatcher(patcher_context,
                                                          args_attention_t,
                                                          suffix=".attention")
            else:
                p_attention = LinearPruningModulePatcher(
                    patcher_context, args_attention)
                p_attention_t = LinearPruningModulePatcher(
                    patcher_context, args_attention_t)
        else:
            p_attention = None
            p_attention_t = None

        dense_pruning_method_parts = self.parse_pruning_method(
            sparse_args.dense_pruning_method)

        if dense_pruning_method_parts[0] != "disabled" or sparse_args.ampere_pruning_method != "disabled":
            args_dense = LinearPruningArgs(
                method=dense_pruning_method_parts[0],
                submethod=dense_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.dense_block_rows,
                block_cols=sparse_args.dense_block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )
            if args_dense.submethod.startswith("1d"):
                p_dense = ChannelPruningModulePatcher(patcher_context,
                                                      args_dense,
                                                      self.MODEL_STRUCTURE,
                                                      suffix="dense")
            else:
                p_dense = LinearPruningModulePatcher(patcher_context,
                                                     args_dense)
        else:
            p_dense = None

        if not hasattr(sparse_args, "attention_output_with_dense"
                       ) or sparse_args.attention_output_with_dense:
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        if hasattr(sparse_args, "layer_norm_patch"):
            layer_norm_patch = sparse_args.layer_norm_patch
        else:
            layer_norm_patch = False

        if hasattr(sparse_args, "gelu_patch"):
            gelu_patch = sparse_args.gelu_patch
        else:
            gelu_patch = False

        patcher = BertLinearModelPatcher(module_patchers)

        patcher.patch(model)

        patched_count = 0
        if attention_pruning_method_parts[0] != "disabled" or sparse_args.ampere_pruning_method != "disabled":
            patched_count += 4 * layers_count

        if dense_pruning_method_parts[0] != "disabled" or sparse_args.ampere_pruning_method != "disabled":
            patched_count += 2 * layers_count

        assert (patcher.stats["patched"] == patched_count)
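        # i.e. 4 attention-side linears (query/key/value/att_dense) plus 2 FFN
        # linears (interm_dense/output_dense) per layer, for each enabled group.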

        if layer_norm_patch:

            def schedule_callback():
                mix = self.patcher_context.get_context_data(
                    "layernorm_to_nonorm_mix")
                delta = self.patcher_context.get_context_data(
                    "layernorm_to_nonorm_delta")
                return dict(mix=mix, delta=delta)

            layer_norm_patcher = Layer2NoNormPatcher(
                schedule_callback=schedule_callback)
            layer_norm_patcher.patch(model)
            layer_norm_patched_count = 2 * layers_count + 1
            assert (layer_norm_patcher.stats["patched"] ==
                    layer_norm_patched_count)

        if gelu_patch:

            def schedule_callback():
                mix = self.patcher_context.get_context_data("gelu_to_relu_mix")
                return dict(mix=mix)

            gelu_patcher = GeLU2ReLUModelPatcher(
                schedule_callback=schedule_callback)
            gelu_patcher.patch(model)
            gelu_patcher_count = layers_count
            assert (gelu_patcher.stats["patched"] == gelu_patcher_count)

        return patcher

    def compile_model(self, model):
        self.schedule_threshold()
        compiler = MaskedLinearModelCompiler()
        compiler.patch(model)

        if hasattr(self.sparse_args,
                   "layer_norm_patch") and self.sparse_args.layer_norm_patch:
            nnc = NoNormCompiler()
            nnc.patch(model)
            model.config.layer_norm_type = "no_norm"

        if hasattr(self.sparse_args,
                   "gelu_patch") and self.sparse_args.gelu_patch:
            model.config.hidden_act = "relu"

        pruner = BertHeadsPruner(model)
        removed_heads, total_heads = pruner.run()
        return removed_heads, total_heads
Example #7
class ModelPatchingCoordinator:
    MODEL_STRUCTURE = BertStructure

    def __init__(self, sparse_args, device, cache_dir, logit_names,
                 teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
        # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.teacher = self.create_teacher(device, cache_dir)
        self.logit_names = logit_names

    def parse_pruning_method(self, method):
        parts = method.split(":")
        if len(parts) == 2:
            return parts
        elif len(parts) == 1:
            return parts[0], "default"
        else:
            raise RuntimeError("Could not parse pruning method")

    def patch_model(self, model, trial):
        raise NotImplementedError("Implement in subclass")

    def log(self):
        logs = {}
        for k, v in self.patcher_context.enumerate_context_data():
            logs[k] = v

        return logs

    def create_teacher(self, device, cache_dir):
        sparse_args = self.sparse_args

        if sparse_args.distil_teacher_name_or_path is not None:
            assert sparse_args.distil_alpha_ce > 0.0
            assert sparse_args.distil_alpha_ce + sparse_args.distil_alpha_teacher > 0.0

            model_config = AutoConfig.from_pretrained(
                sparse_args.distil_teacher_name_or_path, cache_dir=cache_dir)

            teacher = self.teacher_constructor.from_pretrained(
                sparse_args.distil_teacher_name_or_path,
                from_tf=bool(
                    ".ckpt" in sparse_args.distil_teacher_name_or_path),
                config=model_config,
                cache_dir=cache_dir,
            )
            print(teacher)
            teacher.to(device)
        else:
            teacher = None

        return teacher

    def schedule_threshold(
        self,
        step: int = -1,
        total_step: int = -1,
        warmup_steps: int = -1,
        training: bool = False,
    ):
        sparse_args = self.sparse_args

        initial_threshold = sparse_args.initial_threshold
        final_threshold = sparse_args.final_threshold
        initial_warmup = sparse_args.initial_warmup
        final_warmup = sparse_args.final_warmup
        final_lambda = sparse_args.regularization_final_lambda
        initial_ampere_temperature = sparse_args.initial_ampere_temperature
        final_ampere_temperature = sparse_args.final_ampere_temperature

        if training:
            if step <= initial_warmup * warmup_steps:
                threshold = initial_threshold
                ampere_temperature = initial_ampere_temperature
            elif step > (total_step - final_warmup * warmup_steps):
                threshold = final_threshold
                ampere_temperature = final_ampere_temperature
            else:
                spars_warmup_steps = initial_warmup * warmup_steps
                spars_schedu_steps = (final_warmup +
                                      initial_warmup) * warmup_steps
                mul_coeff = 1 - (step - spars_warmup_steps) / (
                    total_step - spars_schedu_steps)
                threshold = final_threshold + (
                    initial_threshold - final_threshold) * (mul_coeff**3)
                ampere_temperature = final_ampere_temperature + (
                    initial_ampere_temperature -
                    final_ampere_temperature) * (mul_coeff**3)
        else:
            threshold = final_threshold
            ampere_temperature = final_ampere_temperature

        regu_lambda = final_lambda * threshold / final_threshold

        context_data = dict(
            threshold=threshold,
            ampere_temperature=ampere_temperature,
            regu_lambda=regu_lambda,
        )

        self.patcher_context.set_context_data_dict(context_data)

    def regularization_loss(self, model: nn.Module):
        # Return regularization, lambda, and information on the network sparsity
        mode = self.sparse_args.regularization

        info = {}

        regul_modes = ["l1", "l0"]
        if mode in regul_modes:
            threshold = self.patcher_context.get_context_data("threshold")

        for name, module in model.named_modules():
            if mode not in regul_modes:
                if isinstance(module, nn.Linear):
                    weight = module.weight
                    module_regu = 0
                    module_nnz = (weight != 0).sum()
                    numel = weight.numel()
                else:
                    continue
            elif isinstance(module, PatcherContextModule):
                param = module.mask_scores
                numel = param.numel()

                if mode == "l1":
                    module_regu = torch.norm(torch.sigmoid(param), p=1) / numel
                    module_nnz = (torch.sigmoid(param) >
                                  threshold).sum().item()
                elif mode == "l0":
                    assert (False)
                    module_regu = torch.sigmoid(
                        param - 2 / 3 * np.log(0.1 / 1.1)).sum() / numel
                    module_nnz = (
                        torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)) >
                        threshold).sum().item()
                else:
                    assert (False)
            else:
                continue
            # TEMPORARY : use model info to perform this dispatch
            if not hasattr(self.sparse_args, "attention_output_with_dense"
                           ) or self.sparse_args.attention_output_with_dense:
                layer_names = ["key", "query", "value"]
                key = "dense"
                for ln in layer_names:
                    if ln in name:
                        key = "attention"
            else:
                key = "attention" if "attention" in name else "dense"

            if key not in info:
                info[key] = defaultdict(float)

            key_info = info[key]
            key_info["regu"] += module_regu
            key_info["nnz"] += float(module_nnz)
            key_info["numel"] += numel
            key_info["nummod"] += 1

        if mode not in regul_modes:
            lamb = 0
            lambdas = dict(attention=0, dense=0)
        else:
            lamb = self.patcher_context.get_context_data("regu_lambda")

            lambdas = dict(attention=self.sparse_args.attention_lambda * 0.5,
                           dense=self.sparse_args.dense_lambda * 0.5)

        info["total"] = defaultdict(float)

        for key, value in info.items():
            if key == "total":
                continue
            for k, v in value.items():
                if k in ["numel", "nnz"]:
                    info["total"][k] += v

        for key, value in info.items():
            value["nnz_perc"] = value["nnz"] / value["numel"]
            del value["nnz"]
            del value["numel"]
            if key == "total":
                continue
            value["regu_loss"] = value["regu"] * lambdas[key] / value["nummod"]
            info["total"]["regu_loss"] += value["regu_loss"]
            del value["regu"]
            del value["nummod"]

        return info["total"]["regu_loss"], lamb, info

    def distil_loss_combine(self, ce_loss, model_inputs, model_outputs):
        sparse_args = self.sparse_args
        teacher = self.teacher

        if teacher is None:
            return ce_loss, 0.0

        temperature = sparse_args.distil_temperature

        with torch.no_grad():
            teacher_outputs = teacher(
                input_ids=model_inputs["input_ids"],
                token_type_ids=model_inputs["token_type_ids"],
                attention_mask=model_inputs["attention_mask"],
            )

        loss_logits = 0
        for logit_name in self.logit_names:
            logits_stu = model_outputs[logit_name]
            logits_tea = teacher_outputs[logit_name]

            loss_logits_part = nn_functional.kl_div(
                input=nn_functional.log_softmax(logits_stu / temperature,
                                                dim=-1),
                target=nn_functional.softmax(logits_tea / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature**2)

            loss_logits += loss_logits_part

        loss_logits /= len(self.logit_names)

        loss = sparse_args.distil_alpha_teacher * loss_logits + sparse_args.distil_alpha_ce * ce_loss

        return loss, loss_logits

    def create_optimizer_groups(self, model, args, sparse_args):
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]

        mask_params = []
        no_decay_params = []
        decay_params = []

        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "mask_score" in n:
                mask_params.append(p)
            elif any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_grouped_parameters = [
            {
                "params": mask_params,
                "lr": sparse_args.mask_scores_learning_rate,
            },
            {
                "params": no_decay_params,
                "lr": args.learning_rate,
                "weight_decay": 0.0,
            },
            {
                "params": decay_params,
                "lr": args.learning_rate,
                "weight_decay": args.weight_decay,
            },
        ]

        return optimizer_grouped_parameters

    def compile_model(self, model):
        self.schedule_threshold()
        compiler = MaskedLinearModelCompiler()
        compiler.patch(model)

    def patch_model(self, model, trial):
        assert trial is None or len(trial.params) == 0
        attention_pruning_method_parts = self.parse_pruning_method(
            self.sparse_args.attention_pruning_method)

        if hasattr(self.sparse_args, "bias_mask"):
            bias_mask = self.sparse_args.bias_mask
        else:
            bias_mask = False

        args_attention = LinearPruningArgs(
            method=attention_pruning_method_parts[0],
            submethod=attention_pruning_method_parts[1],
            ampere_method=self.sparse_args.ampere_pruning_method,
            block_rows=self.sparse_args.attention_block_rows,
            block_cols=self.sparse_args.attention_block_cols,
            bias_mask=bias_mask)

        args_attention_t = LinearPruningArgs(
            method=attention_pruning_method_parts[0],
            submethod=attention_pruning_method_parts[1],
            ampere_method=self.sparse_args.ampere_pruning_method,
            block_rows=self.sparse_args.attention_block_cols,
            block_cols=self.sparse_args.attention_block_rows,
            bias_mask=bias_mask)

        dense_pruning_method_parts = self.parse_pruning_method(
            self.sparse_args.dense_pruning_method)

        args_dense = LinearPruningArgs(
            method=dense_pruning_method_parts[0],
            submethod=dense_pruning_method_parts[1],
            ampere_method=self.sparse_args.ampere_pruning_method,
            block_rows=self.sparse_args.dense_block_rows,
            block_cols=self.sparse_args.dense_block_cols,
            bias_mask=bias_mask)

        patcher_context = self.patcher_context

        if args_attention.submethod == "joint":
            p_attention = JointPruningModulePatcher(patcher_context,
                                                    args_attention,
                                                    suffix=".attention")
            p_attention_t = JointPruningModulePatcher(patcher_context,
                                                      args_attention_t,
                                                      suffix=".attention")
        else:
            p_attention = LinearPruningModulePatcher(patcher_context,
                                                     args_attention)
            p_attention_t = LinearPruningModulePatcher(patcher_context,
                                                       args_attention_t)

        if args_dense.submethod.startswith("1d"):
            p_dense = ChannelPruningModulePatcher(patcher_context,
                                                  args_dense,
                                                  self.MODEL_STRUCTURE,
                                                  suffix="dense")
        else:
            p_dense = LinearPruningModulePatcher(patcher_context, args_dense)

        if not hasattr(self.sparse_args, "attention_output_with_dense"
                       ) or self.sparse_args.attention_output_with_dense:
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)
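        # 12 layers x 6 patched linears = 72 for bert-base; the modulo keeps the
        # check valid for deeper stacks.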
        assert ((patcher.stats["patched"] % 72) == 0)

        return patcher
Example #8
class ModelPatchingCoordinator:

    def __init__(self, sparse_args, device, cache_dir, model_name_or_path, logit_names, teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
        # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.device = device
        self.cache_dir = cache_dir
        self.teacher = None
        self.layer_head_mask = self.create_head_rewind_info(device, cache_dir)
        self.logit_names = logit_names
        self.model_name_or_path = model_name_or_path
        config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model_structure = struct_from_config(config.__class__)

    def parse_pruning_method(self, method):
        parts = method.split(":")
        if len(parts) == 2:
            return parts
        elif len(parts) == 1:
            return parts[0], "default"
        else:
            raise RuntimeError("Could not parse pruning method")

    def log(self):
        logs = {}
        for k, v in self.patcher_context.enumerate_context_data():
            logs[k] = v

        return logs

    def create_teacher(self):
        if self.teacher is not None:
            return self.teacher

        device = self.device
        cache_dir = self.cache_dir

        sparse_args = self.sparse_args

        if sparse_args.distil_teacher_name_or_path is not None:
            assert sparse_args.distil_alpha_ce > 0.0
            assert sparse_args.distil_alpha_ce + sparse_args.distil_alpha_teacher > 0.0

            model_config = AutoConfig.from_pretrained(sparse_args.distil_teacher_name_or_path, cache_dir=cache_dir)

            teacher = self.teacher_constructor.from_pretrained(
                sparse_args.distil_teacher_name_or_path,
                from_tf=bool(".ckpt" in sparse_args.distil_teacher_name_or_path),
                config=model_config,
                cache_dir=cache_dir,
            )
            teacher.to(device)
            self.teacher = teacher

        return self.teacher


    def create_head_rewind_info(self, device, cache_dir):
        if not hasattr(self.sparse_args, "rewind_model_name_or_path"):
            return None

        rewind_model_name_or_path = self.sparse_args.rewind_model_name_or_path
        if rewind_model_name_or_path is None:
            return None
        else:
            rewind_config = AutoConfig.from_pretrained(rewind_model_name_or_path, cache_dir=cache_dir)

            return head_mask(rewind_config, device)

    def schedule_threshold(
        self,
        step: int = -1,
        total_step: int = -1,
        warmup_steps: int = -1,
        training: bool = False,
        compile: bool = False,
    ):
        sparse_args = self.sparse_args

        initial_threshold = sparse_args.initial_threshold
        final_threshold = sparse_args.final_threshold
        initial_warmup = sparse_args.initial_warmup
        final_warmup = sparse_args.final_warmup
        final_lambda = sparse_args.regularization_final_lambda
        initial_ampere_temperature = sparse_args.initial_ampere_temperature
        final_ampere_temperature = sparse_args.final_ampere_temperature

        if not training:
            step -= 1

        eval_with_current_patch_params = (
            hasattr(sparse_args, "eval_with_current_patch_params")
            and sparse_args.eval_with_current_patch_params
        )
        use_scheduler = training or eval_with_current_patch_params

        if compile:
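            # Recover the scheduler position (warmup, global step, total steps)
            # from the checkpoint's saved trainer state, so that compilation
            # reproduces the threshold and temperature reached during training.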
            if use_scheduler:
                base_path = Path(self.model_name_or_path)
                training_args = torch.load(str(base_path / "training_args.bin"))
                warmup_steps = training_args.warmup_steps

                with (base_path / "trainer_state.json").open() as f:
                    trainer_state = json.load(f)

                step = trainer_state["global_step"]
                total_step = trainer_state["max_steps"]

        if use_scheduler:
            if step <= initial_warmup * warmup_steps:
                mul_coeff = 1.0
                threshold = initial_threshold
                ampere_temperature = initial_ampere_temperature
            elif step > (total_step - final_warmup * warmup_steps):
                mul_coeff = 0.0
                threshold = final_threshold
                ampere_temperature = final_ampere_temperature
            else:
                spars_warmup_steps = initial_warmup * warmup_steps
                spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
                mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
                threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
                ampere_temperature = final_ampere_temperature + (
                    initial_ampere_temperature - final_ampere_temperature
                ) * (mul_coeff ** 3)
        else:
            mul_coeff = 0.0
            threshold = final_threshold
            ampere_temperature = final_ampere_temperature

        regu_lambda = final_lambda * threshold / final_threshold

        context_data = dict(
            threshold=threshold,
            regu_lambda=regu_lambda,
            ampere_temperature=ampere_temperature,
            progress=1.0 - mul_coeff,
        )

        def interp(a, b, interpf):
            return a * interpf + (1.0 - interpf) * b

        if hasattr(sparse_args, "layer_norm_patch") and sparse_args.layer_norm_patch:
            if use_scheduler:
                interpf = 0.0
                layer_norm_patch_steps = sparse_args.layer_norm_patch_steps
                if step < layer_norm_patch_steps:
                    interpf = 1.0 - (step / layer_norm_patch_steps)

                delta = interp(sparse_args.layer_norm_patch_start_delta, 1.0, interpf)
                mix = interpf

                context_data["layernorm_to_nonorm_delta"] = delta
                context_data["layernorm_to_nonorm_mix"] = mix
            else:
                context_data["layernorm_to_nonorm_delta"] = 1.0
                context_data["layernorm_to_nonorm_mix"] = 0.0

        if hasattr(sparse_args, "gelu_patch") and sparse_args.gelu_patch:
            if use_scheduler:
                interpf = 0.0
                gelu_patch_steps = sparse_args.gelu_patch_steps
                if step < gelu_patch_steps:
                    interpf = 1.0 - (step / gelu_patch_steps)

                context_data["gelu_to_relu_mix"] = interpf
            else:
                context_data["gelu_to_relu_mix"] = 0.0

        self.patcher_context.set_context_data_dict(context_data)

    def regularization_loss(self, model: nn.Module):
        # Return regularization, lambda, and information on the network sparsity
        mode = self.sparse_args.regularization

        info = {}

        regul_modes = ["l1", "l0"]
        if mode in regul_modes:
            threshold = self.patcher_context.get_context_data("threshold")

        for name, module in model.named_modules():
            module_regu = 0
            module_nnz_info = {"nnz":0, "numel":0}
            nummod = 1
            if mode not in regul_modes:
                if isinstance(module, nn.Linear):
                    weight = module.weight
                    module_nnz_info["nnz"] = (weight != 0).sum()
                    module_nnz_info["numel"] = weight.numel()
                else:
                    continue
            elif isinstance(module, GenericLinearPruningContextModule):
                module_regu = module.regularization(mode)
            elif isinstance(module, MaskedLinear):
                module_nnz_info = module.get_sparsity_info()
                nummod = 0
            elif hasattr(module, "regularization"):
                module_regu = module.regularization()
                if hasattr(module, "get_sparsity_info"):
                    module_nnz_info = module.get_sparsity_info()
            else:
                continue

            key = "decoder_" if self.model_structure.is_decoder(name) else ""
            exclude_att_dense = not hasattr(self.sparse_args, "attention_output_with_dense") or self.sparse_args.attention_output_with_dense
            key += "attention" if self.model_structure.is_attention(name, exclude_att_dense=exclude_att_dense) else "dense"

            if key not in info:
                info[key] = defaultdict(float)

            key_info = info[key]
            key_info["regu"] += module_regu
            key_info["nummod"] += nummod

            for k, v in module_nnz_info.items():
                key_info[k] += float(v)

        if mode not in regul_modes:
            lamb = 0
            lambdas = {k: 0 for k in info.keys()}
        else:
            lamb = self.patcher_context.get_context_data("regu_lambda")
            lambdas = {}
            n = len(info)
            for k in info.keys():
                if k.endswith('attention'):
                    if k.startswith('decoder'):
                        if self.sparse_args.decoder_attention_lambda is None:
                            self.sparse_args.decoder_attention_lambda = self.sparse_args.attention_lambda
                        lambdas[k] = self.sparse_args.decoder_attention_lambda / n
                    else:
                        lambdas[k] = self.sparse_args.attention_lambda / n
                else:
                    if k.startswith('decoder'):
                        if self.sparse_args.decoder_dense_lambda is None:
                            self.sparse_args.decoder_dense_lambda = self.sparse_args.dense_lambda
                        lambdas[k] = self.sparse_args.decoder_dense_lambda / n
                    else:
                        lambdas[k] = self.sparse_args.dense_lambda / n

        info["total"] = defaultdict(float)

        for key, value in info.items():
            if key == "total":
                continue
            for k, v in value.items():
                if k == "numel" or "nnz" in k:
                    info["total"][k] += v

        for key, value in info.items():
            if value["numel"] != 0:
                # No patching (no pruning) -> no information on nnz -> dense linear layers
                value["nnz_perc"] = value["nnz"] / value["numel"]
            else:
                value["nnz_perc"] = 1.0
            for k in "nnz", "numel":
                if k in value:
                    del value[k]
            if key == "total":
                continue
            if value["nummod"] != 0:
                value["regu_loss"] = value["regu"] * lambdas[key] / value["nummod"]
                info["total"]["regu_loss"] += value["regu_loss"]
            for k in "regu", "nummod":
                if k in value:
                    del value[k]

        return info["total"]["regu_loss"], lamb, info

    def distil_loss_combine(self, ce_loss, model_inputs, model_outputs):
        sparse_args = self.sparse_args
        teacher = self.create_teacher()

        if teacher is None:
            return ce_loss, 0.0

        temperature = sparse_args.distil_temperature

        teacher_inputs_ = model_inputs.copy()
        if "labels" in teacher_inputs_:
            del teacher_inputs_["labels"]

        teacher_inputs = {}
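        # Detach and clone so the teacher forward pass shares no autograd graph
        # (or storage) with the student's inputs.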
        for k, v in teacher_inputs_.items():
            teacher_inputs[k] = v.detach().clone()

        with torch.no_grad():
            teacher_outputs = teacher(**teacher_inputs)

        loss_logits = 0
        for logit_name in self.logit_names:
            logits_stu = model_outputs[logit_name]
            logits_tea = teacher_outputs[logit_name].detach().clone()

            loss_logits_part = nn_functional.kl_div(
                input=nn_functional.log_softmax(logits_stu / temperature, dim=-1),
                target=nn_functional.softmax(logits_tea / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature ** 2)

            loss_logits = loss_logits + loss_logits_part

        loss_logits = loss_logits / len(self.logit_names)

        loss = sparse_args.distil_alpha_teacher * loss_logits + sparse_args.distil_alpha_ce * ce_loss

        return loss, loss_logits


    def create_optimizer_groups(self, model, args, sparse_args):
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight", "NoNorm.weight", "layer_norm.weight", "layernorm_embedding.weight",
                    "final_layer_norm.weight"]
        mask_params = []
        no_decay_params = []
        decay_params = []

        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "mask_score" in n:
                mask_params.append(p)
            elif any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_grouped_parameters = [
            {
                "params": mask_params,
                "lr": sparse_args.mask_scores_learning_rate,
            },
            {
                "params": no_decay_params,
                "lr": args.learning_rate,
                "weight_decay": 0.0,
            },
            {
                "params": decay_params,
                "lr": args.learning_rate,
                "weight_decay": args.weight_decay,
            },
        ]

        return optimizer_grouped_parameters


    def patch_model(self, model, trial = None):
        layers_count = model.config.num_hidden_layers
        sparse_args = self.sparse_args

        device = model.device

        attention_pruning_method_parts = self.parse_pruning_method(sparse_args.attention_pruning_method)

        if hasattr(sparse_args, "bias_mask"):
            bias_mask = sparse_args.bias_mask
        else:
            bias_mask = False

        if hasattr(sparse_args, "linear_min_parameters"):
            linear_min_parameters = sparse_args.linear_min_parameters
        else:
            linear_min_parameters = 0.005

        patcher_context = self.patcher_context

        if attention_pruning_method_parts[0] != "disabled" or sparse_args.ampere_pruning_method != "disabled":
            args_attention = LinearPruningArgs(
                method=attention_pruning_method_parts[0],
                submethod=attention_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.attention_block_rows,
                block_cols=sparse_args.attention_block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )

            args_attention_t = LinearPruningArgs(
                method=attention_pruning_method_parts[0],
                submethod=attention_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.attention_block_cols,
                block_cols=sparse_args.attention_block_rows,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )

            if args_attention.submethod == "joint":
                p_attention = JointPruningModulePatcher(
                    patcher_context, args_attention, model_structure=self.model_structure, suffix=".attention"
                )
                p_attention_t = JointPruningModulePatcher(
                    patcher_context, args_attention_t, model_structure=self.model_structure, suffix=".attention"
                )
            else:
                p_attention = LinearPruningModulePatcher(patcher_context,
                                                         args_attention,
                                                         model_structure=self.model_structure,
                                                         row_additive_mask=self.layer_head_mask)
                p_attention_t = LinearPruningModulePatcher(patcher_context,
                                                           args_attention_t,
                                                           model_structure=self.model_structure,
                                                           col_additive_mask=self.layer_head_mask)
        else:
            p_attention = None
            p_attention_t = None

        dense_pruning_method_parts = self.parse_pruning_method(sparse_args.dense_pruning_method)

        if dense_pruning_method_parts[0] != "disabled" or sparse_args.ampere_pruning_method != "disabled":
            args_dense = LinearPruningArgs(
                method=dense_pruning_method_parts[0],
                submethod=dense_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.dense_block_rows,
                block_cols=sparse_args.dense_block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )
            if args_dense.submethod.startswith("1d"):
                p_dense = ChannelPruningModulePatcher(
                    patcher_context, args_dense, model_structure=self.model_structure, suffix="dense"
                )
            else:
                p_dense = LinearPruningModulePatcher(patcher_context, args_dense, model_structure=self.model_structure)
        else:
            p_dense = None

        if not hasattr(sparse_args, "attention_output_with_dense") or sparse_args.attention_output_with_dense:
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
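            # The encoder_decoder_* entries cover cross-attention linears in
            # seq2seq models; encoder-only models simply have no matching modules.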
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            encoder_decoder_query=p_attention,
            encoder_decoder_key=p_attention,
            encoder_decoder_value=p_attention,
            encoder_decoder_att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        if hasattr(sparse_args, "layer_norm_patch"):
            layer_norm_patch = sparse_args.layer_norm_patch
        else:
            layer_norm_patch = False

        if hasattr(sparse_args, "gelu_patch"):
            gelu_patch = sparse_args.gelu_patch
        else:
            gelu_patch = False

        patcher = LinearModelPatcher(module_patchers, model_structure=self.model_structure)

        patcher.patch(model)
        model = model.to(device)  # TODO: change this by making sure the mask_scores are located at the right place.

        self.stats = {}
        self.stats["main"] = patcher.stats

        if layer_norm_patch:
            def schedule_callback():
                mix = self.patcher_context.get_context_data("layernorm_to_nonorm_mix")
                delta = self.patcher_context.get_context_data("layernorm_to_nonorm_delta")
                return dict(mix=mix, delta=delta)

            layer_norm_patcher = Layer2NoNormPatcher(schedule_callback=schedule_callback)
            layer_norm_patcher.patch(model)
            self.stats["layer_norm"] = layer_norm_patcher.stats

        if gelu_patch:
            def schedule_callback():
                mix = self.patcher_context.get_context_data("gelu_to_relu_mix")
                return dict(mix=mix)

            gelu_patcher = GeLU2ReLUModelPatcher(schedule_callback=schedule_callback)
            gelu_patcher.patch(model)
            self.stats["gelu"] = gelu_patcher.stats

        return patcher


    def compile_model(self, model):
        self.schedule_threshold(compile=True)
        compiler = MaskedLinearModelCompiler()
        compiler.patch(model)

        if hasattr(self.sparse_args, "layer_norm_patch") and self.sparse_args.layer_norm_patch:
            nnc = NoNormCompiler()
            nnc.patch(model)
            model.config.layer_norm_type = "no_norm"

        if hasattr(self.sparse_args, "gelu_patch") and self.sparse_args.gelu_patch:
            model.config.hidden_act = "relu"

        pruner = BertHeadsPruner(model)
        removed_heads, total_heads = pruner.run()
        return removed_heads, total_heads
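
Taken together, a typical lifecycle for this last variant might look as follows (a schematic sketch: model, batch, the argument objects, step counts, and the way regu_loss and lamb combine are assumptions, not taken from the examples above):

    from transformers import AutoModelForSequenceClassification

    coordinator = ModelPatchingCoordinator(
        sparse_args=sparse_args,          # assumed pruning/distillation config
        device="cuda",
        cache_dir="checkpoints",
        model_name_or_path="bert-base-uncased",
        logit_names=["logits"],
        teacher_constructor=AutoModelForSequenceClassification,
    )

    coordinator.patch_model(model)        # insert mask modules into the linears
    groups = coordinator.create_optimizer_groups(model, args, sparse_args)

    for step in range(total_steps):       # schematic training loop
        coordinator.schedule_threshold(step, total_steps, warmup_steps, training=True)
        outputs = model(**batch)
        loss, _ = coordinator.distil_loss_combine(outputs.loss, batch, outputs)
        regu_loss, lamb, info = coordinator.regularization_loss(model)
        (loss + lamb * regu_loss).backward()   # assumed combination of the two terms

    coordinator.compile_model(model)      # fold masks in and prune empty heads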