Example 1
    def test_patch_module_tied_attention(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

        parameters = LinearPruningParameters(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
        )

        context = PatcherContext()

        p_attention = JointPruningModulePatcher(context, parameters, "attention")
        p_dense = LinearPruningModulePatcher(context, parameters)

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)

        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 48})
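
The asserted stats follow from the bert-base-uncased geometry: 12 layers, each with six patched linears. Below is a minimal sketch of that arithmetic, assuming the joint patcher shares one mask across query/key/value within each layer while every patched linear gets its own Ampere mask under ampere_method="annealing":

    # Sketch of the expected stats (assumptions stated above; the numbers come
    # only from the model geometry, not from the library itself).
    num_layers = 12            # bert-base-uncased
    linears_per_layer = 6      # query, key, value, att_dense, interm_dense, output_dense
    patched = num_layers * linears_per_layer   # 72 patched modules
    ampere_masks = patched                     # 72: one Ampere mask per patched linear
    masks = num_layers * (1 + 3)               # 48: one shared attention mask + three dense masks per layer
    print(patched, ampere_masks, masks)        # 72 72 48
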
Example 2
    def patch_model(self, model, trial=None):
        layers_count = model.config.num_hidden_layers
        sparse_args = self.sparse_args
        attention_pruning_method_parts = self.parse_pruning_method(
            sparse_args.attention_pruning_method)

        if hasattr(sparse_args, "bias_mask"):
            bias_mask = sparse_args.bias_mask
        else:
            bias_mask = False

        if hasattr(sparse_args, "linear_min_parameters"):
            linear_min_parameters = sparse_args.linear_min_parameters
        else:
            linear_min_parameters = 0.005

        patcher_context = self.patcher_context

        # Attention Q/K/V patchers are configured unless both the attention
        # pruning method and the Ampere method are disabled.
        if (attention_pruning_method_parts[0] != "disabled"
                or sparse_args.ampere_pruning_method != "disabled"):
            args_attention = LinearPruningArgs(
                method=attention_pruning_method_parts[0],
                submethod=attention_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.attention_block_rows,
                block_cols=sparse_args.attention_block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )

            args_attention_t = LinearPruningArgs(
                method=attention_pruning_method_parts[0],
                submethod=attention_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.attention_block_cols,
                block_cols=sparse_args.attention_block_rows,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )
            if args_attention.submethod == "joint":
                p_attention = JointPruningModulePatcher(patcher_context,
                                                        args_attention,
                                                        suffix=".attention")
                p_attention_t = JointPruningModulePatcher(patcher_context,
                                                          args_attention_t,
                                                          suffix=".attention")
            else:
                p_attention = LinearPruningModulePatcher(
                    patcher_context, args_attention)
                p_attention_t = LinearPruningModulePatcher(
                    patcher_context, args_attention_t)
        else:
            p_attention = None
            p_attention_t = None

        dense_pruning_method_parts = self.parse_pruning_method(
            sparse_args.dense_pruning_method)

        if (dense_pruning_method_parts[0] != "disabled"
                or sparse_args.ampere_pruning_method != "disabled"):
            args_dense = LinearPruningArgs(
                method=dense_pruning_method_parts[0],
                submethod=dense_pruning_method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=sparse_args.dense_block_rows,
                block_cols=sparse_args.dense_block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )
            if args_dense.submethod.startswith("1d"):
                p_dense = ChannelPruningModulePatcher(patcher_context,
                                                      args_dense,
                                                      self.MODEL_STRUCTURE,
                                                      suffix="dense")
            else:
                p_dense = LinearPruningModulePatcher(patcher_context,
                                                     args_dense)
        else:
            p_dense = None

        # The attention-output dense layer uses the dense patcher by default;
        # otherwise it falls back to the transposed attention patcher.
        if (not hasattr(sparse_args, "attention_output_with_dense")
                or sparse_args.attention_output_with_dense):
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        if hasattr(sparse_args, "layer_norm_patch"):
            layer_norm_patch = sparse_args.layer_norm_patch
        else:
            layer_norm_patch = False

        if hasattr(sparse_args, "gelu_patch"):
            gelu_patch = sparse_args.gelu_patch
        else:
            gelu_patch = False

        patcher = BertLinearModelPatcher(module_patchers)

        patcher.patch(model)

        patched_count = 0
        if (attention_pruning_method_parts[0] != "disabled"
                or sparse_args.ampere_pruning_method != "disabled"):
            patched_count += 4 * layers_count

        if (dense_pruning_method_parts[0] != "disabled"
                or sparse_args.ampere_pruning_method != "disabled"):
            patched_count += 2 * layers_count

        assert (patcher.stats["patched"] == patched_count)

        if layer_norm_patch:

            def schedule_callback():
                mix = self.patcher_context.get_context_data(
                    "layernorm_to_nonorm_mix")
                delta = self.patcher_context.get_context_data(
                    "layernorm_to_nonorm_delta")
                return dict(mix=mix, delta=delta)

            layer_norm_patcher = Layer2NoNormPatcher(
                schedule_callback=schedule_callback)
            layer_norm_patcher.patch(model)
            layer_norm_patched_count = 2 * layers_count + 1
            assert (layer_norm_patcher.stats["patched"] ==
                    layer_norm_patched_count)

        if gelu_patch:

            def schedule_callback():
                mix = self.patcher_context.get_context_data("gelu_to_relu_mix")
                return dict(mix=mix)

            gelu_patcher = GeLU2ReLUModelPatcher(
                schedule_callback=schedule_callback)
            gelu_patcher.patch(model)
            gelu_patcher_count = layers_count
            assert (gelu_patcher.stats["patched"] == gelu_patcher_count)

        return patcher
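
For reference, here is a hypothetical stand-in listing every sparse_args field this patch_model reads. The container type and the values are illustrative assumptions, as is the "method:submethod" string format handed to parse_pruning_method (the code only requires that parsing yields at least two parts):

    from types import SimpleNamespace

    # Hypothetical sparse_args: field names mirror the attribute accesses in
    # patch_model above; the values and the "method:submethod" format are
    # illustrative assumptions, not defaults from the surrounding project.
    sparse_args = SimpleNamespace(
        attention_pruning_method="topK:default",  # parsed into (method, submethod)
        dense_pruning_method="topK:1d",           # a submethod starting with "1d" selects ChannelPruningModulePatcher
        ampere_pruning_method="disabled",
        attention_block_rows=32,
        attention_block_cols=32,
        dense_block_rows=1,
        dense_block_cols=1,
        bias_mask=True,
        linear_min_parameters=0.005,
        attention_output_with_dense=True,         # att_dense reuses the dense patcher
        layer_norm_patch=False,                   # skip the Layer2NoNormPatcher step
        gelu_patch=False,                         # skip the GeLU2ReLUModelPatcher step
    )
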
Example 3
    def patch_model(self, model, trial):
        assert trial is None or len(trial.params) == 0
        attention_pruning_method_parts = self.parse_pruning_method(
            self.sparse_args.attention_pruning_method)

        if hasattr(self.sparse_args, "bias_mask"):
            bias_mask = self.sparse_args.bias_mask
        else:
            bias_mask = False

        args_attention = LinearPruningArgs(
            method=attention_pruning_method_parts[0],
            submethod=attention_pruning_method_parts[1],
            ampere_method=self.sparse_args.ampere_pruning_method,
            block_rows=self.sparse_args.attention_block_rows,
            block_cols=self.sparse_args.attention_block_cols,
            bias_mask=bias_mask)

        args_attention_t = LinearPruningArgs(
            method=attention_pruning_method_parts[0],
            submethod=attention_pruning_method_parts[1],
            ampere_method=self.sparse_args.ampere_pruning_method,
            block_rows=self.sparse_args.attention_block_cols,
            block_cols=self.sparse_args.attention_block_rows,
            bias_mask=bias_mask)

        dense_pruning_method_parts = self.parse_pruning_method(
            self.sparse_args.dense_pruning_method)

        args_dense = LinearPruningArgs(
            method=dense_pruning_method_parts[0],
            submethod=dense_pruning_method_parts[1],
            ampere_method=self.sparse_args.ampere_pruning_method,
            block_rows=self.sparse_args.dense_block_rows,
            block_cols=self.sparse_args.dense_block_cols,
            bias_mask=bias_mask)

        patcher_context = self.patcher_context

        if args_attention.submethod == "joint":
            p_attention = JointPruningModulePatcher(patcher_context,
                                                    args_attention,
                                                    suffix=".attention")
            p_attention_t = JointPruningModulePatcher(patcher_context,
                                                      args_attention_t,
                                                      suffix=".attention")
        else:
            p_attention = LinearPruningModulePatcher(patcher_context,
                                                     args_attention)
            p_attention_t = LinearPruningModulePatcher(patcher_context,
                                                       args_attention_t)

        if args_dense.submethod.startswith("1d"):
            p_dense = ChannelPruningModulePatcher(patcher_context,
                                                  args_dense,
                                                  self.MODEL_STRUCTURE,
                                                  suffix="dense")
        else:
            p_dense = LinearPruningModulePatcher(patcher_context, args_dense)

        if (not hasattr(self.sparse_args, "attention_output_with_dense")
                or self.sparse_args.attention_output_with_dense):
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)
        assert ((patcher.stats["patched"] % 72) == 0)

        return patcher
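
The closing assertion generalizes the count from Example 1: assuming all six linear sub-modules in every transformer layer are patched, as the module_patchers dict suggests, patcher.stats["patched"] equals 6 * num_hidden_layers, so the modulo-72 check passes exactly when the layer count is a multiple of 12 (and equals 72 for bert-base-uncased). A quick sanity sketch of that arithmetic:

    # Sanity sketch of the closing assertion: six linears patched per layer
    # (query, key, value, att_dense, interm_dense, output_dense), so the
    # modulo-72 check holds whenever the layer count is a multiple of 12.
    for layers_count in (12, 24, 36):
        patched = 6 * layers_count
        assert patched % 72 == 0, layers_count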