Example #1
    def test_patch_module_tied_attention(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

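        # block-wise topK pruning on 32x32 blocks, with Ampere structured sparsity introduced by annealing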
        parameters = LinearPruningParameters(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
        )

        context = PatcherContext()

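        # the joint patcher shares one mask across query/key/value in each attention block;
        # the plain patcher gives every dense layer its own mask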
        p_attention = JointPruningModulePatcher(context, parameters, "attention")
        p_dense = LinearPruningModulePatcher(context, parameters)

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)

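        # BERT-base: 12 encoder layers x 6 patched Linear modules = 72 patched layers;
        # query/key/value share one mask per layer, so "mask" holds 12 + 36 = 48 entries, while "ampere_mask" stays per-module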
        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 48})
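
For reference, the 72 patched modules correspond to the six prunable Linear layers (query, key, value, attention output dense, intermediate dense, output dense) in each of BERT-base's 12 encoder layers. A minimal, self-contained sketch using only transformers and torch to confirm that count (the "encoder.layer" substring filter is just an illustrative way to restrict the count to the encoder stack):

import torch.nn as nn
from transformers import BertConfig, BertForQuestionAnswering

config = BertConfig.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering(config)

# count the Linear modules inside the encoder stack: 12 layers x 6 per layer = 72
prunable = [
    name
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and "encoder.layer" in name
]
print(len(prunable))  # 72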
Example #2
    def test_patch_module_ampere(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

        parameters = LinearPruningArgs(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
            min_elements=0.005,
        )

        context = PatcherContext()

        p = LinearPruningModulePatcher(context, parameters, self.MODEL_STRUCTURE)

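        # every module, attention and dense alike, gets the same independent patcher here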
        module_patchers = dict(query=p, key=p, value=p, att_dense=p, interm_dense=p, output_dense=p)

        patcher = LinearModelPatcher(module_patchers, self.MODEL_STRUCTURE)
        patcher.patch(model)

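        # with independent patching, "mask" and "ampere_mask" each end up with one entry per patched module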
        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 72})
Example #3
    def __init__(self, sparse_args, device, cache_dir, logit_names, teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for QA, ["logits"] for GLUE, etc.
        # teacher_constructor is AutoModelForQuestionAnswering for QA, AutoModelForSequenceClassification for GLUE, etc.
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.teacher = self.create_teacher(device, cache_dir)
        self.logit_names = logit_names
Example #4
    def __init__(self, sparse_args, device, cache_dir, model_name_or_path, logit_names, teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for QA, ["logits"] for GLUE, etc.
        # teacher_constructor is AutoModelForQuestionAnswering for QA, AutoModelForSequenceClassification for GLUE, etc.
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.device = device
        self.cache_dir = cache_dir
        self.teacher = None  # unlike Example #3, the teacher is not built in the constructor
        self.layer_head_mask = self.create_head_rewind_info(device, cache_dir)
        self.logit_names = logit_names
        self.model_name_or_path = model_name_or_path
        config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model_structure = struct_from_config(config.__class__)
Example #5
    def test_patch_tiedattention_line_pruning(self):
        config = BertConfig.from_pretrained("bert-base-uncased")
        model = BertForQuestionAnswering(config)

        parameters_attention = LinearPruningArgs(
            method="topK",
            submethod="default",
            ampere_method="annealing",
            block_rows=32,
            block_cols=32,
            min_elements=0.005,
        )

        parameters_dense = LinearPruningArgs(
            method="topK", submethod="1d", ampere_method="annealing", block_rows=32, block_cols=32, min_elements=0.005
        )

        context = PatcherContext()

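        # query/key/value again share a joint mask per attention block; the dense layers use 1d (channel) pruning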
        p_attention = JointPruningModulePatcher(context, parameters_attention, self.MODEL_STRUCTURE, suffix=".attention")
        p_dense = ChannelPruningModulePatcher(context, parameters_dense, self.MODEL_STRUCTURE, suffix="dense")

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = LinearModelPatcher(module_patchers, self.MODEL_STRUCTURE)
        patcher.patch(model)

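        # still 72 patched modules: the joint attention masks collapse to 12 "mask" entries,
        # the dense channel masks are registered under "mask_1d" (48), and "ampere_mask" stays per-module (72)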
        self.assertEqual(patcher.stats["patched"], 72)
        key_sizes = {k: len(v) for k, v in context.context_modules.items()}

        for k, v in key_sizes.items():
            print(k, v)

        for k, v in context.context_modules.items():
            print(k, v)
        self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 12, "mask_1d": 48})