Example #1
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        # Convert to a set and remove already pruned heads
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
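
The mask-to-index step above is the subtle part: zeroing a head's row in a (heads x head_size) mask, flattening it, and keeping the surviving flat indices. A minimal standalone sketch, using hypothetical sizes (4 heads of size 2, pruning head 1) rather than any real model config:

import torch

# Hypothetical sizes for illustration only: 4 heads, 2 dims per head.
num_attention_heads, attention_head_size = 4, 2
heads_to_prune = {1}

mask = torch.ones(num_attention_heads, attention_head_size)
for head in heads_to_prune:
    mask[head] = 0  # zero the whole row belonging to the pruned head
mask = mask.view(-1).contiguous().eq(1)       # flat boolean keep-mask
index = torch.arange(len(mask))[mask].long()  # flat indices to keep
print(index)  # tensor([0, 1, 4, 5, 6, 7]): slots 2 and 3 (head 1) are dropped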
Example #2
    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, attention_head_size, self.pruned_heads
        )
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)
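
Examples #2, #3, and #5 delegate the bookkeeping to find_pruneable_heads_and_indices. A sketch of what such a helper does, reconstructed from the inline logic of Examples #1 and #4 (not necessarily the library's exact source):

import torch
from typing import List, Set, Tuple

def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    mask = torch.ones(n_heads, head_size)
    heads = set(heads) - already_pruned_heads  # ignore heads pruned earlier
    for head in heads:
        # Shift the index by the number of already pruned heads before it
        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
        mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)
    index = torch.arange(len(mask))[mask].long()
    return heads, index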
Example #3
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
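
This variant lives on BERT-style attention blocks, and in practice it is reached through the model-level prune_heads API rather than called directly. A minimal usage sketch (model name and the layer/head indices are arbitrary):

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
# Keys are layer indices, values are lists of head indices to prune:
# drop heads 0 and 2 in layer 0, and head 1 in layer 2.
model.prune_heads({0: [0, 2], 2: [1]})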
Example #4
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, self.d_kv)
        # Convert to a set and remove already pruned heads
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.d_kv * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)
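
All five examples lean on prune_linear_layer, which rebuilds an nn.Linear keeping only the rows (dim=0) or columns (dim=1) selected by index. A sketch of such a helper under the semantics these examples assume; the bias handling matters here because this T5-style attention (q/k/v/o) uses bias-free projections:

import torch
from torch import nn

def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    b = None
    if layer.bias is not None:
        # Pruning columns (dim=1) leaves the bias untouched; pruning rows (dim=0) shrinks it.
        b = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    with torch.no_grad():
        new_layer.weight.copy_(W.contiguous())
        if b is not None:
            new_layer.bias.copy_(b.contiguous())
    return new_layer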
Example #5
    def prune_heads(self, heads: List[int]):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        # Convert to a set and remove already pruned heads
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)
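
A quick way to see the effect of any of these implementations is to compare projection shapes before and after pruning. A sketch using DistilBERT, whose attention module matches Example #5; the attribute path and printed shapes are assumptions based on the standard DistilBertModel layout (dim=768, 12 heads of size 64):

from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
attn = model.transformer.layer[0].attention  # assumed module path
print(attn.n_heads, attn.q_lin.weight.shape)  # 12, torch.Size([768, 768])
attn.prune_heads([0, 1])
# Two heads removed: output features shrink by 2 * 64, input features stay 768.
print(attn.n_heads, attn.q_lin.weight.shape)  # 10, torch.Size([640, 768])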