def test_patch_module_tied_attention(self):
    config = BertConfig.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnswering(config)

    parameters = LinearPruningParameters(
        method="topK",
        submethod="default",
        ampere_method="annealing",
        block_rows=32,
        block_cols=32,
    )
    context = PatcherContext()

    p_attention = JointPruningModulePatcher(context, parameters, "attention")
    p_dense = LinearPruningModulePatcher(context, parameters)

    module_patchers = dict(
        query=p_attention,
        key=p_attention,
        value=p_attention,
        att_dense=p_dense,
        interm_dense=p_dense,
        output_dense=p_dense,
    )

    patcher = BertLinearModelPatcher(module_patchers)
    patcher.patch(model)
    self.assertEqual(patcher.stats["patched"], 72)

    key_sizes = {k: len(v) for k, v in context.context_modules.items()}
    self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 48})
def test_patch_module_ampere(self):
    config = BertConfig.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnswering(config)

    parameters = LinearPruningArgs(
        method="topK",
        submethod="default",
        ampere_method="annealing",
        block_rows=32,
        block_cols=32,
        min_elements=0.005,
    )
    context = PatcherContext()

    p = LinearPruningModulePatcher(context, parameters, self.MODEL_STRUCTURE)

    module_patchers = dict(query=p, key=p, value=p, att_dense=p, interm_dense=p, output_dense=p)

    patcher = LinearModelPatcher(module_patchers, self.MODEL_STRUCTURE)
    patcher.patch(model)
    self.assertEqual(patcher.stats["patched"], 72)

    key_sizes = {k: len(v) for k, v in context.context_modules.items()}
    self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 72})
def __init__(self, sparse_args, device, cache_dir, logit_names, teacher_constructor):
    # logit_names is ["start_logits", "end_logits"] for QA, ["logits"] for GLUE-style classification, etc.
    # teacher_constructor is AutoModelForQuestionAnswering for QA, AutoModelForSequenceClassification for GLUE, etc.
    self.sparse_args = sparse_args
    self.patcher_context = PatcherContext()
    self.teacher_constructor = teacher_constructor
    self.teacher = self.create_teacher(device, cache_dir)
    self.logit_names = logit_names
def __init__(self, sparse_args, device, cache_dir, model_name_or_path, logit_names, teacher_constructor):
    # logit_names is ["start_logits", "end_logits"] for QA, ["logits"] for GLUE-style classification, etc.
    # teacher_constructor is AutoModelForQuestionAnswering for QA, AutoModelForSequenceClassification for GLUE, etc.
    # (see the illustrative sketch after this constructor)
    self.sparse_args = sparse_args
    self.patcher_context = PatcherContext()
    self.teacher_constructor = teacher_constructor
    self.device = device
    self.cache_dir = cache_dir
    self.teacher = None
    self.layer_head_mask = self.create_head_rewind_info(device, cache_dir)
    self.logit_names = logit_names
    self.model_name_or_path = model_name_or_path
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    self.model_structure = struct_from_config(config.__class__)
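# For illustration only: the comments above describe a calling convention, not code in this
# repository. A hedged sketch of how the logit_names / teacher_constructor pairing might be
# expressed per task family -- the TASK_SETTINGS dict below is hypothetical; only the
# transformers Auto* classes are real.
from transformers import AutoModelForQuestionAnswering, AutoModelForSequenceClassification

TASK_SETTINGS = {
    "qa": dict(
        logit_names=["start_logits", "end_logits"],
        teacher_constructor=AutoModelForQuestionAnswering,
    ),
    "glue": dict(
        logit_names=["logits"],
        teacher_constructor=AutoModelForSequenceClassification,
    ),
}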
def test_patch_tiedattention_line_pruning(self):
    config = BertConfig.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnswering(config)

    parameters_attention = LinearPruningArgs(
        method="topK",
        submethod="default",
        ampere_method="annealing",
        block_rows=32,
        block_cols=32,
        min_elements=0.005,
    )
    parameters_dense = LinearPruningArgs(
        method="topK",
        submethod="1d",
        ampere_method="annealing",
        block_rows=32,
        block_cols=32,
        min_elements=0.005,
    )
    context = PatcherContext()

    p_attention = JointPruningModulePatcher(context, parameters_attention, self.MODEL_STRUCTURE, suffix=".attention")
    p_dense = ChannelPruningModulePatcher(context, parameters_dense, self.MODEL_STRUCTURE, suffix="dense")

    module_patchers = dict(
        query=p_attention,
        key=p_attention,
        value=p_attention,
        att_dense=p_dense,
        interm_dense=p_dense,
        output_dense=p_dense,
    )

    patcher = LinearModelPatcher(module_patchers, self.MODEL_STRUCTURE)
    patcher.patch(model)
    self.assertEqual(patcher.stats["patched"], 72)

    key_sizes = {k: len(v) for k, v in context.context_modules.items()}
    for k, v in key_sizes.items():
        print(k, v)
    for k, v in context.context_modules.items():
        print(k, v)
    self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 12, "mask_1d": 48})