def test_patch_module_tied_attention(self):
    """Patch a BERT QA model with tied (joint) attention pruning and check
    the patched-module count and the per-key context-module sizes."""
    config = BertConfig.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnswering(config)
    pruning_params = LinearPruningParameters(
        method="topK",
        submethod="default",
        ampere_method="annealing",
        block_rows=32,
        block_cols=32,
    )
    patcher_context = PatcherContext()
    # Attention projections share one joint patcher; dense layers get their own.
    attention_patcher = JointPruningModulePatcher(patcher_context, pruning_params, "attention")
    dense_patcher = LinearPruningModulePatcher(patcher_context, pruning_params)
    patchers_by_role = {
        "query": attention_patcher,
        "key": attention_patcher,
        "value": attention_patcher,
        "att_dense": dense_patcher,
        "interm_dense": dense_patcher,
        "output_dense": dense_patcher,
    }
    model_patcher = BertLinearModelPatcher(patchers_by_role)
    model_patcher.patch(model)
    self.assertEqual(model_patcher.stats["patched"], 72)
    # Joint attention patching shares masks, so "mask" entries are fewer
    # than "ampere_mask" entries.
    key_sizes = {key: len(modules) for key, modules in patcher_context.context_modules.items()}
    self.assertEqual(key_sizes, {"ampere_mask": 72, "mask": 48})
def patch_model(self, model, trial=None):
    """Patch *model* in place for sparse training according to ``self.sparse_args``.

    Installs pruning patchers on the attention projections (query/key/value),
    the attention output dense layer, and the intermediate/output dense layers,
    then optionally patches LayerNorm and GeLU modules.

    Args:
        model: the transformer model to patch (must expose
            ``model.config.num_hidden_layers``).
        trial: unused here; kept for interface compatibility with other
            patch_model implementations that accept an optuna-style trial.

    Returns:
        The :class:`BertLinearModelPatcher` used for the linear layers.

    Raises:
        AssertionError: if the number of patched modules does not match the
            count expected from the configuration.
    """
    layers_count = model.config.num_hidden_layers
    sparse_args = self.sparse_args

    attention_pruning_method_parts = self.parse_pruning_method(sparse_args.attention_pruning_method)

    # Optional sparse_args fields: fall back to defaults when absent
    # (getattr with a default is equivalent to the hasattr/else pattern).
    bias_mask = getattr(sparse_args, "bias_mask", False)
    linear_min_parameters = getattr(sparse_args, "linear_min_parameters", 0.005)

    patcher_context = self.patcher_context

    # Ampere pruning alone is enough to require patching a layer family.
    attention_enabled = (
        attention_pruning_method_parts[0] != "disabled"
        or sparse_args.ampere_pruning_method != "disabled"
    )

    if attention_enabled:
        args_attention = LinearPruningArgs(
            method=attention_pruning_method_parts[0],
            submethod=attention_pruning_method_parts[1],
            ampere_method=sparse_args.ampere_pruning_method,
            block_rows=sparse_args.attention_block_rows,
            block_cols=sparse_args.attention_block_cols,
            bias_mask=bias_mask,
            min_elements=linear_min_parameters,
        )
        # Transposed variant (rows/cols swapped) for the attention output
        # dense layer, whose weight is the transpose of the projections'.
        args_attention_t = LinearPruningArgs(
            method=attention_pruning_method_parts[0],
            submethod=attention_pruning_method_parts[1],
            ampere_method=sparse_args.ampere_pruning_method,
            block_rows=sparse_args.attention_block_cols,
            block_cols=sparse_args.attention_block_rows,
            bias_mask=bias_mask,
            min_elements=linear_min_parameters,
        )
        if args_attention.submethod == "joint":
            p_attention = JointPruningModulePatcher(patcher_context, args_attention, suffix=".attention")
            p_attention_t = JointPruningModulePatcher(patcher_context, args_attention_t, suffix=".attention")
        else:
            p_attention = LinearPruningModulePatcher(patcher_context, args_attention)
            p_attention_t = LinearPruningModulePatcher(patcher_context, args_attention_t)
    else:
        p_attention = None
        p_attention_t = None

    dense_pruning_method_parts = self.parse_pruning_method(sparse_args.dense_pruning_method)
    dense_enabled = (
        dense_pruning_method_parts[0] != "disabled"
        or sparse_args.ampere_pruning_method != "disabled"
    )

    if dense_enabled:
        args_dense = LinearPruningArgs(
            method=dense_pruning_method_parts[0],
            submethod=dense_pruning_method_parts[1],
            ampere_method=sparse_args.ampere_pruning_method,
            block_rows=sparse_args.dense_block_rows,
            block_cols=sparse_args.dense_block_cols,
            bias_mask=bias_mask,
            min_elements=linear_min_parameters,
        )
        if args_dense.submethod.startswith("1d"):
            p_dense = ChannelPruningModulePatcher(
                patcher_context, args_dense, self.MODEL_STRUCTURE, suffix="dense"
            )
        else:
            p_dense = LinearPruningModulePatcher(patcher_context, args_dense)
    else:
        p_dense = None

    # By default the attention output dense layer is pruned like the other
    # dense layers; otherwise it uses the transposed attention patcher.
    if getattr(sparse_args, "attention_output_with_dense", True):
        p_att_dense = p_dense
    else:
        p_att_dense = p_attention_t

    module_patchers = dict(
        query=p_attention,
        key=p_attention,
        value=p_attention,
        att_dense=p_att_dense,
        interm_dense=p_dense,
        output_dense=p_dense,
    )

    layer_norm_patch = getattr(sparse_args, "layer_norm_patch", False)
    gelu_patch = getattr(sparse_args, "gelu_patch", False)

    patcher = BertLinearModelPatcher(module_patchers)
    patcher.patch(model)

    # Sanity-check: 4 attention linears + 2 dense linears per layer.
    patched_count = 0
    if attention_enabled:
        patched_count += 4 * layers_count
    if dense_enabled:
        patched_count += 2 * layers_count
    assert patcher.stats["patched"] == patched_count

    if layer_norm_patch:
        def schedule_callback():
            mix = self.patcher_context.get_context_data("layernorm_to_nonorm_mix")
            delta = self.patcher_context.get_context_data("layernorm_to_nonorm_delta")
            return dict(mix=mix, delta=delta)

        layer_norm_patcher = Layer2NoNormPatcher(schedule_callback=schedule_callback)
        layer_norm_patcher.patch(model)
        # Two LayerNorms per layer plus the embedding LayerNorm.
        layer_norm_patched_count = 2 * layers_count + 1
        assert layer_norm_patcher.stats["patched"] == layer_norm_patched_count

    if gelu_patch:
        def schedule_callback():
            mix = self.patcher_context.get_context_data("gelu_to_relu_mix")
            return dict(mix=mix)

        gelu_patcher = GeLU2ReLUModelPatcher(schedule_callback=schedule_callback)
        gelu_patcher.patch(model)
        gelu_patcher_count = layers_count
        assert gelu_patcher.stats["patched"] == gelu_patcher_count

    return patcher
def test_base(self):
    """Smoke test: an empty patcher map can still enumerate patchable layers."""
    bert_config = BertConfig.from_pretrained("bert-base-uncased")
    qa_model = BertForQuestionAnswering(bert_config)
    empty_patcher = BertLinearModelPatcher({})
    layers = empty_patcher.get_patchable_layers(qa_model)
def patch_model(self, model, trial):
    """Patch *model* in place with linear-pruning patchers built from
    ``self.sparse_args``.

    Args:
        model: the transformer model to patch.
        trial: optuna-style trial; this implementation accepts no trial
            hyper-parameters, so it must be ``None`` or empty.

    Returns:
        The :class:`BertLinearModelPatcher` used for patching.

    Raises:
        AssertionError: if *trial* carries parameters, or if the number of
            patched modules is not a multiple of 72 (6 linears x 12 layers
            for a BERT-base-shaped model).
    """
    assert trial is None or len(trial.params) == 0
    sparse_args = self.sparse_args

    attention_pruning_method_parts = self.parse_pruning_method(sparse_args.attention_pruning_method)

    # Optional field: default to False when sparse_args lacks it
    # (getattr with a default is equivalent to the hasattr/else pattern).
    bias_mask = getattr(sparse_args, "bias_mask", False)

    args_attention = LinearPruningArgs(
        method=attention_pruning_method_parts[0],
        submethod=attention_pruning_method_parts[1],
        ampere_method=sparse_args.ampere_pruning_method,
        block_rows=sparse_args.attention_block_rows,
        block_cols=sparse_args.attention_block_cols,
        bias_mask=bias_mask,
    )
    # Transposed variant (rows/cols swapped) for the attention output dense
    # layer, whose weight is the transpose of the projections'.
    args_attention_t = LinearPruningArgs(
        method=attention_pruning_method_parts[0],
        submethod=attention_pruning_method_parts[1],
        ampere_method=sparse_args.ampere_pruning_method,
        block_rows=sparse_args.attention_block_cols,
        block_cols=sparse_args.attention_block_rows,
        bias_mask=bias_mask,
    )

    dense_pruning_method_parts = self.parse_pruning_method(sparse_args.dense_pruning_method)
    args_dense = LinearPruningArgs(
        method=dense_pruning_method_parts[0],
        submethod=dense_pruning_method_parts[1],
        ampere_method=sparse_args.ampere_pruning_method,
        block_rows=sparse_args.dense_block_rows,
        block_cols=sparse_args.dense_block_cols,
        bias_mask=bias_mask,
    )

    patcher_context = self.patcher_context

    if args_attention.submethod == "joint":
        p_attention = JointPruningModulePatcher(patcher_context, args_attention, suffix=".attention")
        p_attention_t = JointPruningModulePatcher(patcher_context, args_attention_t, suffix=".attention")
    else:
        p_attention = LinearPruningModulePatcher(patcher_context, args_attention)
        p_attention_t = LinearPruningModulePatcher(patcher_context, args_attention_t)

    if args_dense.submethod.startswith("1d"):
        p_dense = ChannelPruningModulePatcher(
            patcher_context, args_dense, self.MODEL_STRUCTURE, suffix="dense"
        )
    else:
        p_dense = LinearPruningModulePatcher(patcher_context, args_dense)

    # By default the attention output dense layer is pruned like the other
    # dense layers; otherwise it uses the transposed attention patcher.
    if getattr(sparse_args, "attention_output_with_dense", True):
        p_att_dense = p_dense
    else:
        p_att_dense = p_attention_t

    module_patchers = dict(
        query=p_attention,
        key=p_attention,
        value=p_attention,
        att_dense=p_att_dense,
        interm_dense=p_dense,
        output_dense=p_dense,
    )

    patcher = BertLinearModelPatcher(module_patchers)
    patcher.patch(model)
    # 6 patched linears per layer: 72 per 12-layer encoder stack.
    assert (patcher.stats["patched"] % 72) == 0
    return patcher