def sparsify_model(self):
    """
    Sparsify all linear layers in the encoder as well as the embedding layers
    (just the word embeddings by default, or all embedding layers when
    `sparsify_all_embeddings` is set in the config).
    """
    encoder = self.encoder
    sparsity = self.config.sparsity
    device = self.device

    # Use `getattr` here for backwards compatibility with configs that lack this param.
    sparsify_all_embeddings = getattr(self.config, "sparsify_all_embeddings", False)

    def get_sparsity(name):
        # Either a single value for all layers, or a dict keyed by layer name.
        if isinstance(sparsity, dict):
            if name in sparsity:
                return sparsity[name]
            else:
                raise KeyError(f"Layer {name} not included in sparsity dict.")
        else:
            return sparsity

    # Perform model surgery by replacing the linear layers with `SparseWeights`.
    linear_modules = filter_modules(encoder, include_modules=[torch.nn.Linear])
    for name, module in linear_modules.items():
        layer_sparsity = get_sparsity("bert.encoder." + name)
        sparse_module = SparseWeights(
            module,
            sparsity=layer_sparsity,
            allow_extremes=True,  # allows the model to start fully dense
        )
        set_module_attr(self.encoder, name, sparse_module.to(device))

    # Replace the embedding layers in a similar fashion.
    if sparsify_all_embeddings:
        embeddings = [
            "word_embeddings",
            "position_embeddings",
            "token_type_embeddings",
        ]
    else:
        embeddings = ["word_embeddings"]

    for embedding_name in embeddings:
        dense_module = getattr(self.embeddings, embedding_name)
        layer_sparsity = get_sparsity(f"bert.embeddings.{embedding_name}")
        sparse_module = SparseEmbeddings(dense_module, sparsity=layer_sparsity)
        setattr(self.embeddings, embedding_name, sparse_module.to(device))
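# When `config.sparsity` is a dict, `get_sparsity` above looks each layer up by its
# fully qualified name and raises a KeyError for any missing entry, so the dict has
# to cover every sparsified layer. Below is a minimal sketch of such a config,
# assuming a HuggingFace-style BERT whose encoder submodules carry names like
# "layer.0.attention.self.query"; the exact keys depend on what `filter_modules`
# returns for your model, and the sparsity values are purely illustrative.
example_sparsity_config = {
    # Encoder linear layers are keyed as "bert.encoder." + module name.
    "bert.encoder.layer.0.attention.self.query": 0.8,
    "bert.encoder.layer.0.attention.self.key": 0.8,
    "bert.encoder.layer.0.attention.self.value": 0.8,
    "bert.encoder.layer.0.attention.output.dense": 0.8,
    "bert.encoder.layer.0.intermediate.dense": 0.9,
    "bert.encoder.layer.0.output.dense": 0.9,
    # ... repeat for the remaining encoder layers ...
    # Embedding layers are keyed under "bert.embeddings.".
    "bert.embeddings.word_embeddings": 0.5,
}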
def sparsify_model(self):
    """
    Sparsify all linear layers in the encoder.
    """
    encoder = self.encoder
    sparsity = self.config.sparsity
    device = self.device

    # Perform model surgery by replacing the linear layers with `SparseWeights`.
    linear_modules = filter_modules(encoder, include_modules=[torch.nn.Linear])
    for name, module in linear_modules.items():
        sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
        set_module_attr(self.encoder, name, sparse_module)
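# For readers unfamiliar with the `SparseWeights` wrapper used above, the sketch below
# shows the general idea behind this kind of model surgery: wrap a dense Linear and
# apply a fixed binary mask so a chosen fraction of its weights stays at zero. This is
# a hypothetical stand-in for illustration, not the actual `SparseWeights` implementation.
import torch


class MaskedLinearSketch(torch.nn.Module):
    """Illustrative only: a dense Linear with a fixed random sparsity mask."""

    def __init__(self, linear, sparsity):
        super().__init__()
        self.linear = linear
        # Roughly `sparsity` fraction of the weights are forced to zero.
        mask = (torch.rand_like(linear.weight) >= sparsity).float()
        self.register_buffer("mask", mask)

    def forward(self, x):
        return torch.nn.functional.linear(
            x, self.linear.weight * self.mask, self.linear.bias)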
def resize_model_buffers(model, state_dict):
    """
    Resize the model's buffers by initializing zero tensors that match the
    shapes of the corresponding tensors in the state_dict.
    """
    for name, init_buffer in list(model.named_buffers()):
        if name not in state_dict:
            continue

        saved_buffer = state_dict[name]
        new_buffer = torch.zeros(
            saved_buffer.shape,
            dtype=init_buffer.dtype,
            layout=init_buffer.layout,
            device=init_buffer.device,
        )
        set_module_attr(model, name, new_buffer)
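# Typical use, sketched under the assumption that the checkpoint was saved from a model
# whose buffers (e.g. sparsity masks) no longer match the shapes of a freshly constructed
# model. The checkpoint path and `model` are placeholders.
import torch

state_dict = torch.load("checkpoint.pt", map_location="cpu")  # placeholder path
resize_model_buffers(model, state_dict)  # reshape buffers to match the checkpoint
model.load_state_dict(state_dict)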
def sparsify_model(self):
    """
    Sparsify all linear layers in the encoder as well as the word embedding layer.
    """
    encoder = self.encoder
    sparsity = self.config.sparsity
    device = self.device

    # Perform model surgery by replacing the linear layers with `SparseWeights`.
    linear_modules = filter_modules(encoder, include_modules=[torch.nn.Linear])
    for name, module in linear_modules.items():
        sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
        set_module_attr(self.encoder, name, sparse_module)

    # Replace the embedding layer in a similar fashion.
    dense_embeddings = self.embeddings.word_embeddings
    sparse_embeddings = SparseEmbeddings(dense_embeddings, sparsity=sparsity)
    self.embeddings.word_embeddings = sparse_embeddings.to(device)
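# Sparse wrappers like `SparseWeights` and `SparseEmbeddings` typically keep the masked
# weights at zero by re-applying their masks after each optimizer update. Assuming they
# follow the nupic.torch convention of a `rezero_weights` hook (an assumption, not shown
# in the code above), a training loop would look roughly like this; the data loader,
# loss computation, and optimizer are placeholders.
from nupic.torch.modules import rezero_weights  # assumption: nupic.torch-style wrappers

for batch in train_loader:                 # placeholder data loader
    optimizer.zero_grad()
    loss = compute_loss(model, batch)      # placeholder loss computation
    loss.backward()
    optimizer.step()
    model.apply(rezero_weights)            # re-zero masked weights touched by the update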
def resize_position_embeddings(model, new_seq_length):
    """
    Resize the model's position embedding matrices if the max position
    embedding size doesn't match the new sequence length. (The size of the
    position embedding equals the size of the attention window.)

    :param new_seq_length: Tokenizer sequence length.
    """
    position_embeddings = filter_modules(
        model, include_patterns=[".*position_embeddings.*"])

    for module_name, module in position_embeddings.items():
        original_embed_data = module.weight.data
        max_position_embeddings, embed_hidden_size = original_embed_data.size()

        if max_position_embeddings != new_seq_length:
            # Create a new embedding of the desired length and copy over the first
            # `new_seq_length` rows of the original weights. This assumes the new
            # length is no larger than the original.
            new_embed = torch.nn.Embedding(new_seq_length, embed_hidden_size)
            new_embed.weight.data[:, :] = original_embed_data[:new_seq_length, :]
            set_module_attr(model, module_name, new_embed)

    return model
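# Minimal usage sketch, assuming a HuggingFace-style model and tokenizer; the names and
# the resulting sequence length are illustrative, not fixed by the code above. If the
# surrounding config also tracks `max_position_embeddings`, keep it in sync with the
# resized matrices.
new_seq_length = tokenizer.model_max_length  # e.g. 128; illustrative
model = resize_position_embeddings(model, new_seq_length)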