Example #1
    def sparsify_model(self):
        """
        Sparsify all linear layers in the encoder, plus the word embedding
        layer (or all embedding layers when `sparsify_all_embeddings` is set).
        """

        encoder = self.encoder
        sparsity = self.config.sparsity
        device = self.device

        # Use `getattr` here for backwards compatibility with configs that lack this param.
        sparsify_all_embeddings = getattr(self.config,
                                          "sparsify_all_embeddings", False)

        def get_sparsity(name):
            if isinstance(sparsity, dict):
                if name in sparsity:
                    return sparsity[name]
                else:
                    raise KeyError(
                        f"Layer {name} not included in sparsity dict.")
            else:
                return sparsity

        # Perform model surgery by replacing the linear layers with `SparseWeights`.
        linear_modules = filter_modules(encoder,
                                        include_modules=[torch.nn.Linear])
        for name, module in linear_modules.items():
            layer_sparsity = get_sparsity("bert.encoder." + name)
            sparse_module = SparseWeights(
                module,
                sparsity=layer_sparsity,
                allow_extremes=True  # this allows the model to start fully dense
            )
            set_module_attr(self.encoder, name, sparse_module.to(device))

        # Replace the embedding layers in a similar fashion.
        if sparsify_all_embeddings:
            embeddings = [
                "word_embeddings", "position_embeddings",
                "token_type_embeddings"
            ]
        else:
            embeddings = ["word_embeddings"]

        for embedding_name in embeddings:
            dense_module = getattr(self.embeddings, embedding_name)
            layer_sparsity = get_sparsity(f"bert.embeddings.{embedding_name}")
            sparse_module = SparseEmbeddings(dense_module,
                                             sparsity=layer_sparsity)
            setattr(self.embeddings, embedding_name, sparse_module.to(device))
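As `get_sparsity` above shows, the `sparsity` config value can be either a single float applied everywhere or a dict keyed by fully qualified layer names. A minimal sketch of both shapes (the layer names below are illustrative, not exhaustive):

# Hypothetical sparsity configs consumed by get_sparsity above.
uniform_sparsity = 0.8  # every layer is sparsified to 80%

per_layer_sparsity = {
    # The dict must cover every sparsified layer; otherwise
    # get_sparsity raises a KeyError.
    "bert.encoder.layer.0.attention.self.query": 0.9,
    "bert.encoder.layer.0.intermediate.dense": 0.7,
    "bert.embeddings.word_embeddings": 0.5,
}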
Example #2
    def sparsify_model(self):
        """
        Sparsify all linear layers in the encoder.
        """

        encoder = self.encoder
        sparsity = self.config.sparsity
        device = self.device

        # Perform model surgery by replacing the linear layers with `SparseWeights`.
        linear_modules = filter_modules(encoder,
                                        include_modules=[torch.nn.Linear])
        for name, module in linear_modules.items():
            sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
            set_module_attr(self.encoder, name, sparse_module)
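`filter_modules` and `set_module_attr` are helpers from the surrounding codebase and are not shown in these examples. A plausible minimal sketch of the dotted-name replacement that `set_module_attr` performs (an assumption, not the project's actual implementation):

def set_module_attr(model, name, new_module):
    # Walk a dotted module path such as "layer.0.attention.self.query"
    # down to the parent module, then swap in the replacement.
    *parent_names, attr = name.split(".")
    parent = model
    for part in parent_names:
        parent = getattr(parent, part)
    setattr(parent, attr, new_module)

This also works for children of a `torch.nn.ModuleList`, since `getattr` on a module resolves numeric names like "0" through its registered submodules.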
Example #3
def resize_model_buffers(model, state_dict):
    """
    Resize the model's buffers by replacing each one with a zero tensor
    matching the shape of the corresponding entry in the state_dict.
    """

    for name, init_buffer in list(model.named_buffers()):

        if name not in state_dict:
            continue

        saved_buffer = state_dict[name]
        new_buffer = torch.zeros(
            saved_buffer.shape,
            dtype=init_buffer.dtype,
            layout=init_buffer.layout,
            device=init_buffer.device,
        )

        set_module_attr(model, name, new_buffer)
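A hypothetical call site: run this before `load_state_dict` so that buffers saved with different shapes (for instance, sparsity masks) load without size-mismatch errors. The checkpoint path is illustrative:

state_dict = torch.load("checkpoint.pt", map_location="cpu")
resize_model_buffers(model, state_dict)
model.load_state_dict(state_dict)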
Example #4
    def sparsify_model(self):
        """
        Sparsify all linear layers in the encoder as well as the word embedding layer.
        """

        encoder = self.encoder
        sparsity = self.config.sparsity
        device = self.device

        # Perform model surgery by replacing the linear layers with `SparseWeights`.
        linear_modules = filter_modules(encoder,
                                        include_modules=[torch.nn.Linear])
        for name, module in linear_modules.items():
            sparse_module = SparseWeights(module, sparsity=sparsity).to(device)
            set_module_attr(self.encoder, name, sparse_module)

        # Replace the embedding layer in a similar fashion.
        dense_embeddings = self.embeddings.word_embeddings
        sparse_embeddings = SparseEmbeddings(dense_embeddings,
                                             sparsity=sparsity)
        self.embeddings.word_embeddings = sparse_embeddings
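`SparseWeights` and `SparseEmbeddings` are imported from the project's sparsity library and are not defined in these examples. As a rough illustration only (not the library's actual implementation), a static-mask wrapper in this spirit could look like:

import torch


class StaticSparseLinear(torch.nn.Module):
    # Illustrative stand-in for SparseWeights: a fixed random mask zeroes
    # a `sparsity` fraction of the wrapped layer's weights on every forward.
    def __init__(self, linear, sparsity):
        super().__init__()
        self.linear = linear
        mask = (torch.rand_like(linear.weight) >= sparsity).float()
        self.register_buffer("zero_mask", mask)

    def forward(self, x):
        return torch.nn.functional.linear(
            x, self.linear.weight * self.zero_mask, self.linear.bias)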
Example #5
def resize_position_embeddings(model, new_seq_length):
    """
    Resize the model's position-embedding matrices when their current size
    (the max position embeddings, which equals the attention window) doesn't
    match the new sequence length.

    :param new_seq_length: Tokenizer sequence length.
    """

    position_embeddings = filter_modules(
        model, include_patterns=[".*position_embeddings.*"])
    for module_name, module in position_embeddings.items():
        original_embed_data = module.weight.data
        max_position_embeddings, embed_hidden_size = original_embed_data.size()
        if max_position_embeddings != new_seq_length:
            new_embed = torch.nn.Embedding(new_seq_length, embed_hidden_size)
            # Copy as many rows as both sizes allow; slicing with
            # new_seq_length alone would fail when growing the embeddings.
            # Any extra rows keep their fresh random initialization.
            num_rows = min(max_position_embeddings, new_seq_length)
            new_embed.weight.data[:num_rows, :] = original_embed_data[:num_rows, :]
            set_module_attr(model, module_name, new_embed)

    return model
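A hypothetical call site, shrinking a BERT-style model's default 512-position embeddings down to a 128-token setup:

model = resize_position_embeddings(model, new_seq_length=128)

Note that the function only swaps the embedding modules; depending on the model, a config field such as `max_position_embeddings` may also need to be updated separately.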