Ejemplo n.º 1
0
    def compute_mask(self, t, default_mask):
        """Build a mask that zeroes the blocks of ``t`` with the smallest norm.

        A prunable unit is one block of
        ``self.block_size[0] * self.block_size[1]`` elements. Blocks are
        ranked by ``_compute_block_norm(t, self.n, self.block_size)`` and the
        lowest-ranked ones are masked out.

        Args:
            t: tensor to prune.
            default_mask: mask from previous pruning steps; it is combined
                multiplicatively with the new mask.
                NOTE(review): presumably the same shape as ``t`` -- confirm.

        Returns:
            ``default_mask`` itself (not a copy) when no unit has to be
            pruned; otherwise a new mask with the dtype of the block norms,
            expanded back to element resolution and multiplied by
            ``default_mask``.

        Raises:
            Whatever ``prune._validate_pruning_amount`` raises when the number
            of units to prune exceeds the number of units in the tensor.
        """
        # Number of whole blocks (prunable units) contained in t.
        tensor_size = t.nelement() // self.block_size[0] // self.block_size[1]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = prune._compute_nparams_toprune(
            self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        prune._validate_pruning_amount(nparams_toprune, tensor_size)

        if nparams_toprune == 0:
            # Nothing to prune: bail out before any ranking work so that
            # topk is never invoked with k=0.
            return default_mask

        norm = _compute_block_norm(t, self.n, self.block_size)
        # largest=True --> top k; largest=False --> bottom k
        # Prune the smallest k blocks; reshape (rather than view) also
        # accepts a non-contiguous norm tensor.
        topk = torch.topk(norm.reshape(-1), k=nparams_toprune, largest=False)
        # topk has .indices and .values; the indices address the row-major
        # flattening of norm.

        # Contiguous block-level mask so the flat view below is always valid
        # and lines up with the flattened indices.
        mask = torch.ones_like(norm, memory_format=torch.contiguous_format)
        mask.view(-1)[topk.indices] = 0
        # Expand every block entry back to block_size[0] x block_size[1]
        # elements (assumes norm is 2-D -- TODO confirm against
        # _compute_block_norm).
        mask = torch.repeat_interleave(mask, self.block_size[0], dim=0)
        mask = torch.repeat_interleave(mask, self.block_size[1], dim=1)
        # Keep previously pruned entries pruned.
        mask *= default_mask.to(dtype=mask.dtype)

        return mask
Ejemplo n.º 2
0
 def compute_mask(self, t, default_mask):
     """Build a mask that zeroes the entries with the smallest ``|self.score|``.

     Args:
         t: tensor being pruned; only its element count is used here.
         default_mask: mask from previous pruning steps. It is cloned, never
             modified in place.

     Returns:
         A contiguous copy of ``default_mask`` in which the
         ``nparams_toprune`` positions with the lowest absolute score are
         set to 0.

     NOTE(review): assumes ``self.score`` has as many elements as
     ``default_mask`` so the flat indices line up -- confirm with the caller.
     """
     tensor_size = t.nelement()
     # Number of entries to prune, derived from self.prun_ratio by the
     # prune helper.
     nparams_toprune = prune._compute_nparams_toprune(self.prun_ratio, tensor_size)
     # Force a contiguous copy: a plain clone() keeps the source memory
     # format, and view(-1) below fails on e.g. channels-last masks.
     mask = default_mask.clone(memory_format=torch.contiguous_format)
     if nparams_toprune == 0:
         # Nothing to prune; skip the ranking entirely.
         return mask
     # largest=False --> bottom k. reshape (rather than view) also accepts
     # a non-contiguous score tensor.
     topk = torch.topk(
         torch.abs(self.score).reshape(-1), k=nparams_toprune, largest=False
     )
     # topk has .indices and .values; the indices address the row-major
     # flattening, which matches the contiguous mask's flat view.
     mask.view(-1)[topk.indices] = 0
     return mask
Ejemplo n.º 3
0
    def compute_mask(self, t, default_mask):
        """Build a mask that zeroes the entries of ``t`` with the smallest magnitude.

        The number of pruned entries is derived from ``self.amount``. A
        strictly positive amount that would round down to zero entries is
        bumped to one entry, so a non-zero request always prunes something.

        Args:
            t: tensor to prune; entries are ranked by absolute value.
            default_mask: mask from previous pruning steps. It is cloned,
                never modified in place.

        Returns:
            A contiguous copy of ``default_mask`` with the selected positions
            set to 0.

        Raises:
            Whatever ``prune._validate_pruning_amount`` raises when the number
            of entries to prune exceeds the number of entries in ``t``.
        """
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int, else amount * tensor_size
        nparams_toprune = prune._compute_nparams_toprune(self.amount, tensor_size)
        # If amount is not 0 but yields less than one unit, round up to 1.
        # Assuming `prune` is torch.nn.utils.prune, the helper already returns
        # an integer count, so the rounded-down case shows up here as 0, not
        # as a fraction between 0 and 1. An empty tensor is left alone so the
        # validation below does not reject it.
        if nparams_toprune == 0 and self.amount > 0 and tensor_size > 0:
            nparams_toprune = 1
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        prune._validate_pruning_amount(nparams_toprune, tensor_size)

        # Contiguous copy so the flat view below is always valid.
        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # skip the ranking when nothing is pruned
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k; reshape (rather than view) also accepts a
            # non-contiguous t.
            topk = torch.topk(torch.abs(t).reshape(-1), k=nparams_toprune, largest=False)
            # topk has .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask