class ModelPatchingCoordinator:
    """Coordinate movement pruning of a BERT model.

    Owns the shared ``PatcherContext`` that carries the scheduled
    hyper-parameters (``threshold``, ``ampere_temperature``, ``regu_lambda``),
    an optional distillation teacher, and the methods that patch the model's
    linear layers, compute the sparsity regularization, combine the
    distillation loss and finally compile the patched model.
    """

    MODEL_STRUCTURE = BertStructure

    def __init__(self, sparse_args, device, cache_dir, logit_names, teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
        # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.teacher = self.create_teacher(device, cache_dir)
        self.logit_names = logit_names

    def parse_pruning_method(self, method):
        """Split a ``"method[:submethod]"`` string.

        Returns ``[method, submethod]`` when a submethod is given, otherwise
        ``(method, "default")``.

        Raises:
            RuntimeError: if ``method`` contains more than one ``":"``.
        """
        parts = method.split(":")
        if len(parts) == 2:
            return parts
        elif len(parts) == 1:
            return parts[0], "default"
        else:
            raise RuntimeError("Could not parse pruning method")

    def log(self):
        """Return the current patcher-context values as a plain dict."""
        return {k: v for k, v in self.patcher_context.enumerate_context_data()}

    def create_teacher(self, device, cache_dir):
        """Load the distillation teacher on ``device``, or return None.

        A teacher is only created when ``sparse_args.distil_teacher_name_or_path``
        is set; ``from_tf`` is enabled for paths containing ``.ckpt``.
        """
        sparse_args = self.sparse_args
        if sparse_args.distil_teacher_name_or_path is None:
            return None

        assert sparse_args.distil_alpha_ce > 0.0
        assert sparse_args.distil_alpha_ce + sparse_args.distil_alpha_teacher > 0.0

        model_config = AutoConfig.from_pretrained(sparse_args.distil_teacher_name_or_path, cache_dir=cache_dir)
        teacher = self.teacher_constructor.from_pretrained(
            sparse_args.distil_teacher_name_or_path,
            from_tf=bool(".ckpt" in sparse_args.distil_teacher_name_or_path),
            config=model_config,
            cache_dir=cache_dir,
        )
        teacher.to(device)
        return teacher

    def schedule_threshold(
        self,
        step: int = -1,
        total_step: int = -1,
        warmup_steps: int = -1,
        training: bool = False,
    ):
        """Compute the scheduled hyper-parameters and store them in the context.

        While training, ``threshold`` and ``ampere_temperature`` keep their
        initial values for the first ``initial_warmup * warmup_steps`` steps,
        take their final values during the last ``final_warmup * warmup_steps``
        steps, and follow a cubic interpolation in between. Outside training
        the final values are used. ``regu_lambda`` is
        ``regularization_final_lambda * threshold / final_threshold``.
        """
        sparse_args = self.sparse_args
        initial_threshold = sparse_args.initial_threshold
        final_threshold = sparse_args.final_threshold
        initial_warmup = sparse_args.initial_warmup
        final_warmup = sparse_args.final_warmup
        final_lambda = sparse_args.regularization_final_lambda
        initial_ampere_temperature = sparse_args.initial_ampere_temperature
        final_ampere_temperature = sparse_args.final_ampere_temperature

        if training:
            if step <= initial_warmup * warmup_steps:
                threshold = initial_threshold
                ampere_temperature = initial_ampere_temperature
            elif step > (total_step - final_warmup * warmup_steps):
                threshold = final_threshold
                ampere_temperature = final_ampere_temperature
            else:
                spars_warmup_steps = initial_warmup * warmup_steps
                spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
                mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
                threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff**3)
                ampere_temperature = final_ampere_temperature + (
                    initial_ampere_temperature - final_ampere_temperature
                ) * (mul_coeff**3)
        else:
            threshold = final_threshold
            ampere_temperature = final_ampere_temperature

        regu_lambda = final_lambda * threshold / final_threshold

        context_data = dict(
            threshold=threshold,
            ampere_temperature=ampere_temperature,
            regu_lambda=regu_lambda,
        )
        self.patcher_context.set_context_data_dict(context_data)

    def regularization_loss(self, model: nn.Module):
        """Return ``(regu_loss, lambda, info)`` describing the model's sparsity.

        ``info`` maps ``"attention"``/``"dense"``/``"total"`` to dicts holding
        ``nnz_perc`` (fraction of kept weights) and, for the per-group entries,
        ``regu_loss``. With a regularization mode other than ``"l1"``/``"l0"``,
        plain ``nn.Linear`` layers are scanned for exact zeros and no
        regularization is applied (lambdas are 0).
        """
        mode = self.sparse_args.regularization
        info = {}

        regul_modes = ["l1", "l0"]
        if mode in regul_modes:
            threshold = self.patcher_context.get_context_data("threshold")

        attention_output_with_dense = getattr(self.sparse_args, "attention_output_with_dense", True)

        for name, module in model.named_modules():
            if mode not in regul_modes:
                if not isinstance(module, nn.Linear):
                    continue
                weight = module.weight
                module_regu = 0
                module_nnz = (weight != 0).sum()
                numel = weight.numel()
            elif isinstance(module, PatcherContextModule):
                param = module.mask_scores
                numel = param.numel()
                if mode == "l1":
                    module_regu = torch.norm(torch.sigmoid(param), p=1) / numel
                    module_nnz = (torch.sigmoid(param) > threshold).sum().item()
                elif mode == "l0":
                    # l0 is currently disabled: the assert always fires and the
                    # lines below it are never reached.
                    assert False
                    module_regu = torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)).sum() / numel
                    module_nnz = (torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)) > threshold).sum().item()
                else:
                    assert False
            else:
                continue

            # TEMPORARY : use model info to perform this dispatch
            if attention_output_with_dense:
                # Only query/key/value count as "attention"; the attention
                # output projection is grouped with the dense layers.
                is_attention = any(ln in name for ln in ("key", "query", "value"))
            else:
                is_attention = "attention" in name
            key = "attention" if is_attention else "dense"

            if key not in info:
                info[key] = defaultdict(float)

            key_info = info[key]
            key_info["regu"] += module_regu
            key_info["nnz"] += float(module_nnz)
            key_info["numel"] += numel
            key_info["nummod"] += 1

        if mode not in regul_modes:
            lamb = 0
            lambdas = dict(attention=0, dense=0)
        else:
            lamb = self.patcher_context.get_context_data("regu_lambda")
            lambdas = dict(
                attention=self.sparse_args.attention_lambda * 0.5,
                dense=self.sparse_args.dense_lambda * 0.5,
            )

        info["total"] = defaultdict(float)

        for key, value in info.items():
            if key == "total":
                continue
            for k, v in value.items():
                if k in ["numel", "nnz"]:
                    info["total"][k] += v

        for key, value in info.items():
            if value["numel"] != 0:
                value["nnz_perc"] = value["nnz"] / value["numel"]
            else:
                # Nothing was counted (e.g. no patched modules): report dense.
                value["nnz_perc"] = 1.0
            del value["nnz"]
            del value["numel"]
            if key == "total":
                continue
            value["regu_loss"] = value["regu"] * lambdas[key] / value["nummod"]
            info["total"]["regu_loss"] += value["regu_loss"]
            del value["regu"]
            del value["nummod"]

        return info["total"]["regu_loss"], lamb, info

    def distil_loss_combine(self, ce_loss, model_inputs, model_outputs):
        """Blend ``ce_loss`` with a temperature-scaled KL distillation loss.

        Returns ``(loss, distil_loss)``. Without a teacher this is
        ``(ce_loss, 0.0)``, so callers can always unpack two values.
        """
        sparse_args = self.sparse_args
        teacher = self.teacher

        if teacher is None:
            return ce_loss, 0.0

        temperature = sparse_args.distil_temperature

        with torch.no_grad():
            teacher_outputs = teacher(
                input_ids=model_inputs["input_ids"],
                token_type_ids=model_inputs["token_type_ids"],
                attention_mask=model_inputs["attention_mask"],
            )

        loss_logits = 0
        for logit_name in self.logit_names:
            logits_stu = model_outputs[logit_name]
            logits_tea = teacher_outputs[logit_name]

            loss_logits_part = nn_functional.kl_div(
                input=nn_functional.log_softmax(logits_stu / temperature, dim=-1),
                target=nn_functional.softmax(logits_tea / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature**2)

            loss_logits += loss_logits_part

        loss_logits /= len(self.logit_names)

        loss = sparse_args.distil_alpha_teacher * loss_logits + sparse_args.distil_alpha_ce * ce_loss

        return loss, loss_logits

    def create_optimizer_groups(self, model, args, sparse_args):
        """Split trainable parameters into mask-score / no-decay / decay groups."""
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]

        mask_params = []
        no_decay_params = []
        decay_params = []

        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "mask_score" in n:
                mask_params.append(p)
            elif any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_grouped_parameters = [
            {
                "params": mask_params,
                "lr": sparse_args.mask_scores_learning_rate,
            },
            {
                "params": no_decay_params,
                "lr": args.learning_rate,
                "weight_decay": 0.0,
            },
            {
                "params": decay_params,
                "lr": args.learning_rate,
                "weight_decay": args.weight_decay,
            },
        ]
        return optimizer_grouped_parameters

    def compile_model(self, model):
        """Use the final (non-training) schedule and compile masked linears in place."""
        self.schedule_threshold()
        compiler = MaskedLinearModelCompiler()
        compiler.patch(model)

    def patch_model(self, model, trial):
        """Replace the BERT linear layers of ``model`` with pruning-aware modules.

        Query/key/value use the attention pruning arguments, and the intermediate
        and output dense layers use the dense ones. The attention output
        projection uses the dense patcher unless
        ``sparse_args.attention_output_with_dense`` is set to a false value. In
        that case it uses the attention arguments with transposed block shape.

        Returns the ``BertLinearModelPatcher`` that was applied.
        """
        assert trial is None or len(trial.params) == 0
        sparse_args = self.sparse_args
        patcher_context = self.patcher_context
        bias_mask = getattr(sparse_args, "bias_mask", False)

        def linear_args(method_parts, block_rows, block_cols):
            # Shared LinearPruningArgs construction for every layer family.
            return LinearPruningArgs(
                method=method_parts[0],
                submethod=method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=block_rows,
                block_cols=block_cols,
                bias_mask=bias_mask,
            )

        attention_pruning_method_parts = self.parse_pruning_method(sparse_args.attention_pruning_method)
        args_attention = linear_args(
            attention_pruning_method_parts, sparse_args.attention_block_rows, sparse_args.attention_block_cols
        )
        # Transposed block shape, used for the attention output projection.
        args_attention_t = linear_args(
            attention_pruning_method_parts, sparse_args.attention_block_cols, sparse_args.attention_block_rows
        )

        dense_pruning_method_parts = self.parse_pruning_method(sparse_args.dense_pruning_method)
        args_dense = linear_args(dense_pruning_method_parts, sparse_args.dense_block_rows, sparse_args.dense_block_cols)

        if args_attention.submethod == "joint":
            p_attention = JointPruningModulePatcher(patcher_context, args_attention, suffix=".attention")
            p_attention_t = JointPruningModulePatcher(patcher_context, args_attention_t, suffix=".attention")
        else:
            p_attention = LinearPruningModulePatcher(patcher_context, args_attention)
            p_attention_t = LinearPruningModulePatcher(patcher_context, args_attention_t)

        if args_dense.submethod.startswith("1d"):
            p_dense = ChannelPruningModulePatcher(patcher_context, args_dense, self.MODEL_STRUCTURE, suffix="dense")
        else:
            p_dense = LinearPruningModulePatcher(patcher_context, args_dense)

        if getattr(sparse_args, "attention_output_with_dense", True):
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)

        # Six linear layers are patched per transformer layer (the six entries
        # of module_patchers). The count used to be hard-coded as 72, which
        # only fits 12-layer models.
        layers_count = getattr(getattr(model, "config", None), "num_hidden_layers", 12)
        assert (patcher.stats["patched"] % (6 * layers_count)) == 0

        return patcher
class ModelPatchingCoordinator:
    """Coordinate movement pruning of a BERT model, with optional LayerNorm/GeLU patches.

    Owns the shared ``PatcherContext`` that carries the scheduled
    hyper-parameters (threshold, temperature, lambda, and the LayerNorm->NoNorm
    and GeLU->ReLU mixing factors), an optional distillation teacher, and the
    methods that patch, regularize and finally compile the model.
    """

    MODEL_STRUCTURE = BertStructure

    def __init__(self, sparse_args, device, cache_dir, logit_names, teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
        # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.teacher = self.create_teacher(device, cache_dir)
        self.logit_names = logit_names

    def parse_pruning_method(self, method):
        """Split a ``"method[:submethod]"`` string.

        Returns ``[method, submethod]`` when a submethod is given, otherwise
        ``(method, "default")``.

        Raises:
            RuntimeError: if ``method`` contains more than one ``":"``.
        """
        parts = method.split(":")
        if len(parts) == 2:
            return parts
        elif len(parts) == 1:
            return parts[0], "default"
        else:
            raise RuntimeError("Could not parse pruning method")

    def log(self):
        """Return the current patcher-context values as a plain dict."""
        return {k: v for k, v in self.patcher_context.enumerate_context_data()}

    def create_teacher(self, device, cache_dir):
        """Load the distillation teacher on ``device``, or return None.

        A teacher is only created when ``sparse_args.distil_teacher_name_or_path``
        is set; ``from_tf`` is enabled for paths containing ``.ckpt``.
        """
        sparse_args = self.sparse_args
        if sparse_args.distil_teacher_name_or_path is None:
            return None

        assert sparse_args.distil_alpha_ce > 0.0
        assert sparse_args.distil_alpha_ce + sparse_args.distil_alpha_teacher > 0.0

        model_config = AutoConfig.from_pretrained(sparse_args.distil_teacher_name_or_path, cache_dir=cache_dir)
        teacher = self.teacher_constructor.from_pretrained(
            sparse_args.distil_teacher_name_or_path,
            from_tf=bool(".ckpt" in sparse_args.distil_teacher_name_or_path),
            config=model_config,
            cache_dir=cache_dir,
        )
        teacher.to(device)
        return teacher

    @staticmethod
    def _decaying_mix(step, patch_steps):
        """Return a factor going linearly from 1.0 at step 0 to 0.0 at ``patch_steps``."""
        if step < patch_steps:
            return 1.0 - (step / patch_steps)
        return 0.0

    def schedule_threshold(
        self,
        step: int = -1,
        total_step: int = -1,
        warmup_steps: int = -1,
        training: bool = False,
    ):
        """Compute the scheduled hyper-parameters and store them in the context.

        While training, ``threshold`` and ``ampere_temperature`` keep their
        initial values for the first ``initial_warmup * warmup_steps`` steps,
        take their final values during the last ``final_warmup * warmup_steps``
        steps, and follow a cubic interpolation in between. Outside training
        the final values are used. The LayerNorm/GeLU mixing factors decay
        linearly to 0 over ``layer_norm_patch_steps`` / ``gelu_patch_steps``
        while training and are 0 otherwise.
        """
        sparse_args = self.sparse_args
        initial_threshold = sparse_args.initial_threshold
        final_threshold = sparse_args.final_threshold
        initial_warmup = sparse_args.initial_warmup
        final_warmup = sparse_args.final_warmup
        final_lambda = sparse_args.regularization_final_lambda
        initial_ampere_temperature = sparse_args.initial_ampere_temperature
        final_ampere_temperature = sparse_args.final_ampere_temperature

        if training:
            if step <= initial_warmup * warmup_steps:
                threshold = initial_threshold
                ampere_temperature = initial_ampere_temperature
            elif step > (total_step - final_warmup * warmup_steps):
                threshold = final_threshold
                ampere_temperature = final_ampere_temperature
            else:
                spars_warmup_steps = initial_warmup * warmup_steps
                spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
                mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
                threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff**3)
                ampere_temperature = final_ampere_temperature + (
                    initial_ampere_temperature - final_ampere_temperature
                ) * (mul_coeff**3)
        else:
            threshold = final_threshold
            ampere_temperature = final_ampere_temperature

        regu_lambda = final_lambda * threshold / final_threshold

        context_data = dict(threshold=threshold, regu_lambda=regu_lambda, ampere_temperature=ampere_temperature)

        def interp(a, b, interpf):
            return a * interpf + (1.0 - interpf) * b

        if getattr(sparse_args, "layer_norm_patch", False):
            if training:
                interpf = self._decaying_mix(step, sparse_args.layer_norm_patch_steps)
                delta = interp(sparse_args.layer_norm_patch_start_delta, 1.0, interpf)
                context_data["layernorm_to_nonorm_delta"] = delta
                context_data["layernorm_to_nonorm_mix"] = interpf
            else:
                context_data["layernorm_to_nonorm_delta"] = 1.0
                context_data["layernorm_to_nonorm_mix"] = 0.0

        if getattr(sparse_args, "gelu_patch", False):
            if training:
                context_data["gelu_to_relu_mix"] = self._decaying_mix(step, sparse_args.gelu_patch_steps)
            else:
                context_data["gelu_to_relu_mix"] = 0.0

        self.patcher_context.set_context_data_dict(context_data)

    def regularization_loss(self, model: nn.Module):
        """Return ``(regu_loss, lambda, info)`` describing the model's sparsity.

        ``info`` maps ``"attention"``/``"dense"``/``"total"`` to dicts holding
        ``nnz_perc`` and, for per-group entries with at least one regularized
        module, ``regu_loss``. Groups with no counted elements report
        ``nnz_perc = 1.0``.
        """
        mode = self.sparse_args.regularization
        info = {}

        regul_modes = ["l1", "l0"]
        attention_output_with_dense = getattr(self.sparse_args, "attention_output_with_dense", True)

        for name, module in model.named_modules():
            module_regu = 0
            module_nnz_info = {"nnz": 0, "numel": 0}
            nummod = 1
            if mode not in regul_modes:
                if not isinstance(module, nn.Linear):
                    continue
                weight = module.weight
                module_nnz_info["nnz"] = (weight != 0).sum()
                module_nnz_info["numel"] = weight.numel()
            elif isinstance(module, GenericLinearPruningContextModule):
                module_regu = module.regularization(mode)
            elif isinstance(module, MaskedLinear):
                # Masked linears contribute sparsity statistics, not regularization.
                module_nnz_info = module.get_sparsity_info()
                nummod = 0
            else:
                continue

            # TEMPORARY : use model info to perform this dispatch
            if attention_output_with_dense:
                is_attention = any(ln in name for ln in ("key", "query", "value"))
            else:
                is_attention = "attention" in name
            key = "attention" if is_attention else "dense"

            if key not in info:
                info[key] = defaultdict(float)

            key_info = info[key]
            key_info["regu"] += module_regu
            key_info["nummod"] += nummod

            for k, v in module_nnz_info.items():
                key_info[k] += float(v)

        if mode not in regul_modes:
            lamb = 0
            lambdas = dict(attention=0, dense=0)
        else:
            lamb = self.patcher_context.get_context_data("regu_lambda")
            lambdas = dict(
                attention=self.sparse_args.attention_lambda * 0.5,
                dense=self.sparse_args.dense_lambda * 0.5,
            )

        info["total"] = defaultdict(float)

        for key, value in info.items():
            if key == "total":
                continue
            for k, v in value.items():
                if k == "numel" or "nnz" in k:
                    info["total"][k] += v

        for key, value in info.items():
            if value["numel"] != 0:
                # No patching (no pruning) -> no information on nnz -> dense linear layers
                value["nnz_perc"] = value["nnz"] / value["numel"]
            else:
                value["nnz_perc"] = 1.0
            for k in "nnz", "numel":
                if k in value:
                    del value[k]
            if key == "total":
                continue
            if value["nummod"] != 0:
                value["regu_loss"] = value["regu"] * lambdas[key] / value["nummod"]
                info["total"]["regu_loss"] += value["regu_loss"]
            for k in "regu", "nummod":
                if k in value:
                    del value[k]

        return info["total"]["regu_loss"], lamb, info

    def distil_loss_combine(self, ce_loss, model_inputs, model_outputs):
        """Blend ``ce_loss`` with a temperature-scaled KL distillation loss.

        Returns ``(loss, distil_loss)``; without a teacher ``(ce_loss, 0.0)``.
        """
        sparse_args = self.sparse_args
        teacher = self.teacher

        if teacher is None:
            return ce_loss, 0.0

        temperature = sparse_args.distil_temperature

        with torch.no_grad():
            teacher_outputs = teacher(
                input_ids=model_inputs["input_ids"],
                token_type_ids=model_inputs["token_type_ids"],
                attention_mask=model_inputs["attention_mask"],
            )

        loss_logits = 0
        for logit_name in self.logit_names:
            logits_stu = model_outputs[logit_name]
            logits_tea = teacher_outputs[logit_name]

            loss_logits_part = nn_functional.kl_div(
                input=nn_functional.log_softmax(logits_stu / temperature, dim=-1),
                target=nn_functional.softmax(logits_tea / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature**2)

            loss_logits += loss_logits_part

        loss_logits /= len(self.logit_names)

        loss = sparse_args.distil_alpha_teacher * loss_logits + sparse_args.distil_alpha_ce * ce_loss

        return loss, loss_logits

    def create_optimizer_groups(self, model, args, sparse_args):
        """Split trainable parameters into mask-score / no-decay / decay groups."""
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight", "NoNorm.weight", "LayerNorm.bias", "NoNorm.bias"]

        mask_params = []
        no_decay_params = []
        decay_params = []

        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "mask_score" in n:
                mask_params.append(p)
            elif any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_grouped_parameters = [
            {
                "params": mask_params,
                "lr": sparse_args.mask_scores_learning_rate,
            },
            {
                "params": no_decay_params,
                "lr": args.learning_rate,
                "weight_decay": 0.0,
            },
            {
                "params": decay_params,
                "lr": args.learning_rate,
                "weight_decay": args.weight_decay,
            },
        ]
        return optimizer_grouped_parameters

    def patch_model(self, model, trial=None):
        """Patch linear layers (and optionally LayerNorm / GeLU) of ``model``.

        A layer family is patched when its pruning method is not
        ``"disabled"`` or when Ampere pruning is enabled. The number of patched
        modules is checked against ``model.config.num_hidden_layers``: four
        attention linears and two dense linears per layer. ``trial`` is
        accepted for interface compatibility and unused.

        Returns the ``BertLinearModelPatcher`` that was applied.
        """
        layers_count = model.config.num_hidden_layers
        sparse_args = self.sparse_args
        patcher_context = self.patcher_context

        bias_mask = getattr(sparse_args, "bias_mask", False)
        linear_min_parameters = getattr(sparse_args, "linear_min_parameters", 0.005)
        ampere_enabled = sparse_args.ampere_pruning_method != "disabled"

        def linear_args(method_parts, block_rows, block_cols):
            # Shared LinearPruningArgs construction for every layer family.
            return LinearPruningArgs(
                method=method_parts[0],
                submethod=method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=block_rows,
                block_cols=block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )

        attention_pruning_method_parts = self.parse_pruning_method(sparse_args.attention_pruning_method)
        attention_enabled = attention_pruning_method_parts[0] != "disabled" or ampere_enabled
        if attention_enabled:
            args_attention = linear_args(
                attention_pruning_method_parts, sparse_args.attention_block_rows, sparse_args.attention_block_cols
            )
            # Transposed block shape, used for the attention output projection.
            args_attention_t = linear_args(
                attention_pruning_method_parts, sparse_args.attention_block_cols, sparse_args.attention_block_rows
            )
            if args_attention.submethod == "joint":
                p_attention = JointPruningModulePatcher(patcher_context, args_attention, suffix=".attention")
                p_attention_t = JointPruningModulePatcher(patcher_context, args_attention_t, suffix=".attention")
            else:
                p_attention = LinearPruningModulePatcher(patcher_context, args_attention)
                p_attention_t = LinearPruningModulePatcher(patcher_context, args_attention_t)
        else:
            p_attention = None
            p_attention_t = None

        dense_pruning_method_parts = self.parse_pruning_method(sparse_args.dense_pruning_method)
        dense_enabled = dense_pruning_method_parts[0] != "disabled" or ampere_enabled
        if dense_enabled:
            args_dense = linear_args(
                dense_pruning_method_parts, sparse_args.dense_block_rows, sparse_args.dense_block_cols
            )
            if args_dense.submethod.startswith("1d"):
                p_dense = ChannelPruningModulePatcher(
                    patcher_context, args_dense, self.MODEL_STRUCTURE, suffix="dense"
                )
            else:
                p_dense = LinearPruningModulePatcher(patcher_context, args_dense)
        else:
            p_dense = None

        if getattr(sparse_args, "attention_output_with_dense", True):
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        layer_norm_patch = getattr(sparse_args, "layer_norm_patch", False)
        gelu_patch = getattr(sparse_args, "gelu_patch", False)

        patcher = BertLinearModelPatcher(module_patchers)
        patcher.patch(model)

        patched_count = 0
        if attention_enabled:
            patched_count += 4 * layers_count
        if dense_enabled:
            patched_count += 2 * layers_count

        assert patcher.stats["patched"] == patched_count

        if layer_norm_patch:

            def layer_norm_schedule_callback():
                mix = self.patcher_context.get_context_data("layernorm_to_nonorm_mix")
                delta = self.patcher_context.get_context_data("layernorm_to_nonorm_delta")
                return dict(mix=mix, delta=delta)

            layer_norm_patcher = Layer2NoNormPatcher(schedule_callback=layer_norm_schedule_callback)
            layer_norm_patcher.patch(model)
            layer_norm_patched_count = 2 * layers_count + 1
            assert layer_norm_patcher.stats["patched"] == layer_norm_patched_count

        if gelu_patch:

            def gelu_schedule_callback():
                mix = self.patcher_context.get_context_data("gelu_to_relu_mix")
                return dict(mix=mix)

            gelu_patcher = GeLU2ReLUModelPatcher(schedule_callback=gelu_schedule_callback)
            gelu_patcher.patch(model)
            gelu_patcher_count = layers_count
            assert gelu_patcher.stats["patched"] == gelu_patcher_count

        return patcher

    def compile_model(self, model):
        """Compile masked linears, apply the NoNorm/ReLU patches and prune heads.

        Returns ``(removed_heads, total_heads)`` from ``BertHeadsPruner``.
        """
        self.schedule_threshold()
        compiler = MaskedLinearModelCompiler()
        compiler.patch(model)

        if getattr(self.sparse_args, "layer_norm_patch", False):
            nnc = NoNormCompiler()
            nnc.patch(model)
            model.config.layer_norm_type = "no_norm"

        if getattr(self.sparse_args, "gelu_patch", False):
            model.config.hidden_act = "relu"

        pruner = BertHeadsPruner(model)
        removed_heads, total_heads = pruner.run()
        return removed_heads, total_heads
class ModelPatchingCoordinator:
    """Coordinate movement pruning of a (possibly encoder-decoder) transformer.

    The model layout is resolved from the config of ``model_name_or_path``
    via ``struct_from_config``. The coordinator owns the shared
    ``PatcherContext`` holding the scheduled hyper-parameters, a lazily created
    distillation teacher, and an optional head mask taken from a rewind
    checkpoint.
    """

    def __init__(self, sparse_args, device, cache_dir, model_name_or_path, logit_names, teacher_constructor):
        # logit_names is ["start_logits", "end_logits"] for qa, ["logits"] for glue etc
        # teacher model is AutoModelForQuestionAnswering for qa, AutoModelForSequenceClassification for glue etc
        self.sparse_args = sparse_args
        self.patcher_context = PatcherContext()
        self.teacher_constructor = teacher_constructor
        self.device = device
        self.cache_dir = cache_dir
        self.teacher = None  # created lazily by create_teacher()
        self.layer_head_mask = self.create_head_rewind_info(device, cache_dir)
        self.logit_names = logit_names
        self.model_name_or_path = model_name_or_path
        config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model_structure = struct_from_config(config.__class__)

    def parse_pruning_method(self, method):
        """Split a ``"method[:submethod]"`` string.

        Returns ``[method, submethod]`` when a submethod is given, otherwise
        ``(method, "default")``.

        Raises:
            RuntimeError: if ``method`` contains more than one ``":"``.
        """
        parts = method.split(":")
        if len(parts) == 2:
            return parts
        elif len(parts) == 1:
            return parts[0], "default"
        else:
            raise RuntimeError("Could not parse pruning method")

    def log(self):
        """Return the current patcher-context values as a plain dict."""
        return {k: v for k, v in self.patcher_context.enumerate_context_data()}

    def create_teacher(self):
        """Return the distillation teacher, loading it on first use.

        Returns None when ``sparse_args.distil_teacher_name_or_path`` is None.
        """
        if self.teacher is not None:
            return self.teacher

        device = self.device
        cache_dir = self.cache_dir
        sparse_args = self.sparse_args

        if sparse_args.distil_teacher_name_or_path is not None:
            assert sparse_args.distil_alpha_ce > 0.0
            assert sparse_args.distil_alpha_ce + sparse_args.distil_alpha_teacher > 0.0

            model_config = AutoConfig.from_pretrained(sparse_args.distil_teacher_name_or_path, cache_dir=cache_dir)

            teacher = self.teacher_constructor.from_pretrained(
                sparse_args.distil_teacher_name_or_path,
                from_tf=bool(".ckpt" in sparse_args.distil_teacher_name_or_path),
                config=model_config,
                cache_dir=cache_dir,
            )
            teacher.to(device)
            self.teacher = teacher

        return self.teacher

    def create_head_rewind_info(self, device, cache_dir):
        """Build a head mask from ``sparse_args.rewind_model_name_or_path``, or return None."""
        rewind_model_name_or_path = getattr(self.sparse_args, "rewind_model_name_or_path", None)
        if rewind_model_name_or_path is None:
            return None
        rewind_config = AutoConfig.from_pretrained(rewind_model_name_or_path, cache_dir=cache_dir)
        return head_mask(rewind_config, device)

    def schedule_threshold(
        self,
        step: int = -1,
        total_step: int = -1,
        warmup_steps: int = -1,
        training: bool = False,
        compile: bool = False,  # shadows the builtin; name kept for keyword callers
    ):
        """Compute the scheduled hyper-parameters and store them in the context.

        The schedule is applied when training, or when evaluating with
        ``sparse_args.eval_with_current_patch_params`` set. Otherwise the final
        values are used. When ``compile`` is set and the schedule is active,
        ``step``, ``total_step`` and ``warmup_steps`` are read from the
        checkpoint at ``model_name_or_path``. ``progress`` in the context is
        ``1 - mul_coeff``.
        """
        sparse_args = self.sparse_args
        initial_threshold = sparse_args.initial_threshold
        final_threshold = sparse_args.final_threshold
        initial_warmup = sparse_args.initial_warmup
        final_warmup = sparse_args.final_warmup
        final_lambda = sparse_args.regularization_final_lambda
        initial_ampere_temperature = sparse_args.initial_ampere_temperature
        final_ampere_temperature = sparse_args.final_ampere_temperature

        if not training:
            step -= 1

        eval_with_current_patch_params = getattr(sparse_args, "eval_with_current_patch_params", False)
        use_scheduler = training or eval_with_current_patch_params

        if compile and use_scheduler:
            base_path = Path(self.model_name_or_path)
            # NOTE(review): torch.load unpickles arbitrary objects; only compile
            # checkpoints from a trusted source.
            training_args = torch.load(str(base_path / "training_args.bin"))
            warmup_steps = training_args.warmup_steps
            with (base_path / "trainer_state.json").open() as f:
                trainer_state = json.load(f)
            step = trainer_state["global_step"]
            total_step = trainer_state["max_steps"]

        if use_scheduler:
            if step <= initial_warmup * warmup_steps:
                mul_coeff = 1.0
                threshold = initial_threshold
                ampere_temperature = initial_ampere_temperature
            elif step > (total_step - final_warmup * warmup_steps):
                mul_coeff = 0.0
                threshold = final_threshold
                ampere_temperature = final_ampere_temperature
            else:
                spars_warmup_steps = initial_warmup * warmup_steps
                spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
                mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
                threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
                ampere_temperature = final_ampere_temperature + (
                    initial_ampere_temperature - final_ampere_temperature
                ) * (mul_coeff ** 3)
        else:
            mul_coeff = 0.0
            threshold = final_threshold
            ampere_temperature = final_ampere_temperature

        regu_lambda = final_lambda * threshold / final_threshold

        context_data = dict(
            threshold=threshold,
            regu_lambda=regu_lambda,
            ampere_temperature=ampere_temperature,
            progress=1.0 - mul_coeff,
        )

        def interp(a, b, interpf):
            return a * interpf + (1.0 - interpf) * b

        if getattr(sparse_args, "layer_norm_patch", False):
            if use_scheduler:
                interpf = 0.0
                layer_norm_patch_steps = sparse_args.layer_norm_patch_steps
                if step < layer_norm_patch_steps:
                    interpf = 1.0 - (step / layer_norm_patch_steps)
                delta = interp(sparse_args.layer_norm_patch_start_delta, 1.0, interpf)
                context_data["layernorm_to_nonorm_delta"] = delta
                context_data["layernorm_to_nonorm_mix"] = interpf
            else:
                context_data["layernorm_to_nonorm_delta"] = 1.0
                context_data["layernorm_to_nonorm_mix"] = 0.0

        if getattr(sparse_args, "gelu_patch", False):
            if use_scheduler:
                interpf = 0.0
                gelu_patch_steps = sparse_args.gelu_patch_steps
                if step < gelu_patch_steps:
                    interpf = 1.0 - (step / gelu_patch_steps)
                context_data["gelu_to_relu_mix"] = interpf
            else:
                context_data["gelu_to_relu_mix"] = 0.0

        self.patcher_context.set_context_data_dict(context_data)

    def _group_lambda(self, key):
        """Return the un-normalized regularization lambda for a sparsity group key.

        Decoder groups use ``decoder_attention_lambda`` / ``decoder_dense_lambda``
        when set, and fall back to the encoder lambda otherwise. ``sparse_args``
        is not modified.
        """
        sparse_args = self.sparse_args
        if key.endswith("attention"):
            base_name, decoder_name = "attention_lambda", "decoder_attention_lambda"
        else:
            base_name, decoder_name = "dense_lambda", "decoder_dense_lambda"
        if key.startswith("decoder"):
            decoder_lambda = getattr(sparse_args, decoder_name, None)
            if decoder_lambda is not None:
                return decoder_lambda
        return getattr(sparse_args, base_name)

    def regularization_loss(self, model: nn.Module):
        """Return ``(regu_loss, lambda, info)`` describing the model's sparsity.

        Modules are grouped by ``model_structure`` into ``[decoder_]attention`` /
        ``[decoder_]dense`` keys. Each group's lambda is divided by the number
        of groups. ``info`` also holds a ``"total"`` entry with the overall
        ``nnz_perc`` and summed ``regu_loss``.
        """
        mode = self.sparse_args.regularization
        info = {}

        regul_modes = ["l1", "l0"]
        exclude_att_dense = getattr(self.sparse_args, "attention_output_with_dense", True)

        for name, module in model.named_modules():
            module_regu = 0
            module_nnz_info = {"nnz": 0, "numel": 0}
            nummod = 1
            if mode not in regul_modes:
                if not isinstance(module, nn.Linear):
                    continue
                weight = module.weight
                module_nnz_info["nnz"] = (weight != 0).sum()
                module_nnz_info["numel"] = weight.numel()
            elif isinstance(module, GenericLinearPruningContextModule):
                module_regu = module.regularization(mode)
            elif isinstance(module, MaskedLinear):
                # Masked linears contribute sparsity statistics, not regularization.
                module_nnz_info = module.get_sparsity_info()
                nummod = 0
            elif hasattr(module, "regularization"):
                module_regu = module.regularization()
                if hasattr(module, "get_sparsity_info"):
                    module_nnz_info = module.get_sparsity_info()
            else:
                continue

            key = "decoder_" if self.model_structure.is_decoder(name) else ""
            is_attention = self.model_structure.is_attention(name, exclude_att_dense=exclude_att_dense)
            key += "attention" if is_attention else "dense"

            if key not in info:
                info[key] = defaultdict(float)

            key_info = info[key]
            key_info["regu"] += module_regu
            key_info["nummod"] += nummod

            for k, v in module_nnz_info.items():
                key_info[k] += float(v)

        if mode not in regul_modes:
            lamb = 0
            lambdas = {k: 0 for k in info}
        else:
            lamb = self.patcher_context.get_context_data("regu_lambda")
            n = len(info)
            lambdas = {k: self._group_lambda(k) / n for k in info}

        info["total"] = defaultdict(float)

        for key, value in info.items():
            if key == "total":
                continue
            for k, v in value.items():
                if k == "numel" or "nnz" in k:
                    info["total"][k] += v

        for key, value in info.items():
            if value["numel"] != 0:
                # No patching (no pruning) -> no information on nnz -> dense linear layers
                value["nnz_perc"] = value["nnz"] / value["numel"]
            else:
                value["nnz_perc"] = 1.0
            for k in "nnz", "numel":
                if k in value:
                    del value[k]
            if key == "total":
                continue
            if value["nummod"] != 0:
                value["regu_loss"] = value["regu"] * lambdas[key] / value["nummod"]
                info["total"]["regu_loss"] += value["regu_loss"]
            for k in "regu", "nummod":
                if k in value:
                    del value[k]

        return info["total"]["regu_loss"], lamb, info

    def distil_loss_combine(self, ce_loss, model_inputs, model_outputs):
        """Blend ``ce_loss`` with a temperature-scaled KL distillation loss.

        The teacher receives every model input except ``labels``. Tensor inputs
        are passed as detached copies and non-tensor inputs unchanged. Returns
        ``(loss, distil_loss)``; without a teacher ``(ce_loss, 0.0)``.
        """
        sparse_args = self.sparse_args
        teacher = self.create_teacher()

        if teacher is None:
            return ce_loss, 0.0

        temperature = sparse_args.distil_temperature

        teacher_inputs = {
            k: v.detach().clone() if torch.is_tensor(v) else v
            for k, v in model_inputs.items()
            if k != "labels"
        }

        with torch.no_grad():
            teacher_outputs = teacher(**teacher_inputs)

        loss_logits = 0
        for logit_name in self.logit_names:
            logits_stu = model_outputs[logit_name]
            logits_tea = teacher_outputs[logit_name].detach().clone()

            loss_logits_part = nn_functional.kl_div(
                input=nn_functional.log_softmax(logits_stu / temperature, dim=-1),
                target=nn_functional.softmax(logits_tea / temperature, dim=-1),
                reduction="batchmean",
            ) * (temperature ** 2)

            loss_logits = loss_logits + loss_logits_part

        loss_logits = loss_logits / len(self.logit_names)

        loss = sparse_args.distil_alpha_teacher * loss_logits + sparse_args.distil_alpha_ce * ce_loss

        return loss, loss_logits

    def create_optimizer_groups(self, model, args, sparse_args):
        """Split trainable parameters into mask-score / no-decay / decay groups."""
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = [
            "bias",
            "LayerNorm.weight",
            "NoNorm.weight",
            "layer_norm.weight",
            "layernorm_embedding.weight",
            "final_layer_norm.weight",
        ]

        mask_params = []
        no_decay_params = []
        decay_params = []

        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "mask_score" in n:
                mask_params.append(p)
            elif any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)

        optimizer_grouped_parameters = [
            {
                "params": mask_params,
                "lr": sparse_args.mask_scores_learning_rate,
            },
            {
                "params": no_decay_params,
                "lr": args.learning_rate,
                "weight_decay": 0.0,
            },
            {
                "params": decay_params,
                "lr": args.learning_rate,
                "weight_decay": args.weight_decay,
            },
        ]
        return optimizer_grouped_parameters

    def patch_model(self, model, trial=None):
        """Patch the linear layers (and optionally LayerNorm / GeLU) of ``model``.

        A layer family is patched when its pruning method is not
        ``"disabled"`` or when Ampere pruning is enabled. Encoder-decoder
        attention reuses the attention patchers. Patch statistics are stored
        in ``self.stats``. ``trial`` is accepted for interface compatibility
        and unused.

        Returns the ``LinearModelPatcher`` that was applied.
        """
        sparse_args = self.sparse_args
        patcher_context = self.patcher_context
        model_structure = self.model_structure
        device = model.device

        bias_mask = getattr(sparse_args, "bias_mask", False)
        linear_min_parameters = getattr(sparse_args, "linear_min_parameters", 0.005)
        ampere_enabled = sparse_args.ampere_pruning_method != "disabled"

        def linear_args(method_parts, block_rows, block_cols):
            # Shared LinearPruningArgs construction for every layer family.
            return LinearPruningArgs(
                method=method_parts[0],
                submethod=method_parts[1],
                ampere_method=sparse_args.ampere_pruning_method,
                block_rows=block_rows,
                block_cols=block_cols,
                bias_mask=bias_mask,
                min_elements=linear_min_parameters,
            )

        attention_pruning_method_parts = self.parse_pruning_method(sparse_args.attention_pruning_method)
        if attention_pruning_method_parts[0] != "disabled" or ampere_enabled:
            args_attention = linear_args(
                attention_pruning_method_parts, sparse_args.attention_block_rows, sparse_args.attention_block_cols
            )
            # Transposed block shape, used for the attention output projection.
            args_attention_t = linear_args(
                attention_pruning_method_parts, sparse_args.attention_block_cols, sparse_args.attention_block_rows
            )
            if args_attention.submethod == "joint":
                p_attention = JointPruningModulePatcher(
                    patcher_context, args_attention, model_structure=model_structure, suffix=".attention"
                )
                p_attention_t = JointPruningModulePatcher(
                    patcher_context, args_attention_t, model_structure=model_structure, suffix=".attention"
                )
            else:
                p_attention = LinearPruningModulePatcher(
                    patcher_context,
                    args_attention,
                    model_structure=model_structure,
                    row_additive_mask=self.layer_head_mask,
                )
                p_attention_t = LinearPruningModulePatcher(
                    patcher_context,
                    args_attention_t,
                    model_structure=model_structure,
                    col_additive_mask=self.layer_head_mask,
                )
        else:
            p_attention = None
            p_attention_t = None

        dense_pruning_method_parts = self.parse_pruning_method(sparse_args.dense_pruning_method)
        if dense_pruning_method_parts[0] != "disabled" or ampere_enabled:
            args_dense = linear_args(
                dense_pruning_method_parts, sparse_args.dense_block_rows, sparse_args.dense_block_cols
            )
            if args_dense.submethod.startswith("1d"):
                p_dense = ChannelPruningModulePatcher(
                    patcher_context, args_dense, model_structure=model_structure, suffix="dense"
                )
            else:
                p_dense = LinearPruningModulePatcher(patcher_context, args_dense, model_structure=model_structure)
        else:
            p_dense = None

        if getattr(sparse_args, "attention_output_with_dense", True):
            p_att_dense = p_dense
        else:
            p_att_dense = p_attention_t

        module_patchers = dict(
            query=p_attention,
            key=p_attention,
            value=p_attention,
            att_dense=p_att_dense,
            encoder_decoder_query=p_attention,
            encoder_decoder_key=p_attention,
            encoder_decoder_value=p_attention,
            encoder_decoder_att_dense=p_att_dense,
            interm_dense=p_dense,
            output_dense=p_dense,
        )

        layer_norm_patch = getattr(sparse_args, "layer_norm_patch", False)
        gelu_patch = getattr(sparse_args, "gelu_patch", False)

        patcher = LinearModelPatcher(module_patchers, model_structure=model_structure)
        patcher.patch(model)
        model = model.to(device)  # TODO: change this by making sure the mask_scores are located at the right place.

        self.stats = {}
        self.stats["main"] = patcher.stats

        if layer_norm_patch:

            def layer_norm_schedule_callback():
                mix = self.patcher_context.get_context_data("layernorm_to_nonorm_mix")
                delta = self.patcher_context.get_context_data("layernorm_to_nonorm_delta")
                return dict(mix=mix, delta=delta)

            layer_norm_patcher = Layer2NoNormPatcher(schedule_callback=layer_norm_schedule_callback)
            layer_norm_patcher.patch(model)
            self.stats["layer_norm"] = layer_norm_patcher.stats

        if gelu_patch:

            def gelu_schedule_callback():
                mix = self.patcher_context.get_context_data("gelu_to_relu_mix")
                return dict(mix=mix)

            gelu_patcher = GeLU2ReLUModelPatcher(schedule_callback=gelu_schedule_callback)
            gelu_patcher.patch(model)
            self.stats["gelu"] = gelu_patcher.stats

        return patcher

    def compile_model(self, model):
        """Compile masked linears, apply the NoNorm/ReLU patches and prune heads.

        Returns ``(removed_heads, total_heads)`` from ``BertHeadsPruner``.
        """
        self.schedule_threshold(compile=True)
        compiler = MaskedLinearModelCompiler()
        compiler.patch(model)

        if getattr(self.sparse_args, "layer_norm_patch", False):
            nnc = NoNormCompiler()
            nnc.patch(model)
            model.config.layer_norm_type = "no_norm"

        if getattr(self.sparse_args, "gelu_patch", False):
            model.config.hidden_act = "relu"

        pruner = BertHeadsPruner(model)
        removed_heads, total_heads = pruner.run()
        return removed_heads, total_heads