예제 #1
0
 def prune_heads(self, heads):
     """Prune the given attention heads from this multi-head attention layer.

     Shrinks the q/k/v and output projections to drop the selected heads,
     then updates ``n_heads``, ``dim`` and the ``pruned_heads`` record.

     Args:
         heads: collection of head indices to remove (indices of heads
             already pruned are handled by ``find_pruneable_heads_and_indices``).
     """
     # Guard first: an empty request is a no-op, and checking before the
     # division avoids pointless work (and a ZeroDivisionError if
     # n_heads has already been reduced to 0).
     if not heads:
         return
     attention_head_size = self.dim // self.n_heads
     heads, index = find_pruneable_heads_and_indices(
         heads, self.n_heads, attention_head_size, self.pruned_heads)
     # Prune linear layers: q/k/v lose output rows, out_lin loses input columns.
     self.q_lin = prune_linear_layer(self.q_lin, index)
     self.k_lin = prune_linear_layer(self.k_lin, index)
     self.v_lin = prune_linear_layer(self.v_lin, index)
     self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
     # Update hyper params so forward() matches the smaller projections.
     self.n_heads = self.n_heads - len(heads)
     self.dim = attention_head_size * self.n_heads
     self.pruned_heads = self.pruned_heads.union(heads)
예제 #2
0
    def prune_heads(self, heads):
        """Remove the given attention heads from this attention block.

        Resizes the query/key/value projections of the inner self-attention
        module and the output projection, then keeps the head-count
        bookkeeping and the ``pruned_heads`` record in sync.
        """
        if not heads:
            return

        attn = self.self
        heads, keep_index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Shrink the projection matrices to the surviving head columns.
        attn.query = prune_linear_layer(attn.query, keep_index)
        attn.key = prune_linear_layer(attn.key, keep_index)
        attn.value = prune_linear_layer(attn.value, keep_index)
        self.output.dense = prune_linear_layer(self.output.dense, keep_index, dim=1)

        # Record the pruning and update the hyper-parameters accordingly.
        attn.num_attention_heads -= len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
    def prune_heads(self, heads):
        """Prune attention heads from the fused QKV attention layer.

        ``c_attn`` stores the query, key and value projections side by side,
        so the same per-head column selection is applied at three offsets.
        """
        if not heads:
            return

        head_dim = self.split_size // self.n_head
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, head_dim, self.pruned_heads)
        # One index block per fused projection: Q at 0, K at split_size,
        # V at 2 * split_size.
        index_attn = torch.cat(
            [index, index + self.split_size, index + 2 * self.split_size])

        # Prune conv1d layers (columns of c_attn, rows of c_proj).
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params; head_dim was captured before pruning, so the
        # new split_size is per-head size times the remaining head count.
        self.n_head -= len(heads)
        self.split_size = head_dim * self.n_head
        self.pruned_heads = self.pruned_heads.union(heads)