Example #1
class QuantizationLayer(Module):
    def __init__(self, in_features, out_features, m, ci):
        super(QuantizationLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.codebookSize = m
        self.c = ci
        # There are two groups of parameters, {a_i, b_i}, as described in the paper.
        self.weight = Parameter(torch.Tensor(self.codebookSize - 1, 2))
        self.reset_parameters()

    def reset_parameters(self):
        std = 1. / math.sqrt(self.weight.numel())
        self.weight.data.uniform_(-std, std)

    def forward(self, x):
        # Accumulator for the sum of scaled, shifted tanh terms below, created
        # with the same dtype and device as the input.
        ret = torch.zeros_like(x)

        for kk in range(0, self.codebookSize - 1):
            # noinspection PyUnresolvedReferences
            temp_val = torch.add(x, self.weight[kk, 1])
            # noinspection PyUnresolvedReferences
            temp_val = torch.mul(temp_val, self.c)
            # noinspection PyUnresolvedReferences
            temp_val = torch.tanh(temp_val)
            # noinspection PyUnresolvedReferences
            temp_val = torch.mul(temp_val, self.weight[kk, 0])
            # noinspection PyUnresolvedReferences
            ret = torch.add(ret, temp_val)
        return ret
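A minimal usage sketch (not part of the original snippet): the forward pass above computes sum_k a_k * tanh(c * (x + b_k)) with a_k = weight[k, 0] and b_k = weight[k, 1], i.e. a smooth, differentiable approximation of an m-level quantizer.

import torch

layer = QuantizationLayer(in_features=8, out_features=8, m=4, ci=10.0)
x = torch.randn(2, 8)
y = layer(x)       # element-wise soft quantization, same shape as x
print(y.shape)     # torch.Size([2, 8])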
Example #2
    class LargeParamModel(Module):
        def __init__(self):
            super().__init__()
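            # NOTE: `param_sz` is not defined in this snippet; it is assumed to
            # come from the enclosing (test) scope.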
            self.param = Parameter(torch.zeros((param_sz, ), dtype=torch.float32))

            # only do weight initialization on root rank to
            # make sure we are broadcasting correctly from rank 0
            if dist.get_rank() == 0:
                partition_sz = math.ceil(self.param.numel() / dist.get_world_size())
                offset = 0
                for rank in range(dist.get_world_size()):
                    with torch.no_grad():
                        self.param[offset:offset + partition_sz].fill_(rank)
                    offset += partition_sz

        def forward(self, x: Tensor) -> Tensor:
            return x * self.param
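A brief sketch of the broadcast the comment above refers to (not part of the original snippet; it assumes an already-initialized torch.distributed process group and an instance model of LargeParamModel):

import torch.distributed as dist

# After this call every rank holds rank 0's partition-filled parameter, so the
# rank-0-only initialization above is what all ranks end up computing with.
dist.broadcast(model.param.data, src=0)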
Example #3
    def _init_param_attributes(self, p: Parameter) -> None:
        """
        We manage several attributes on each Parameter instance. The first two
        are set by :func:`_shard_parameters`:
            ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param). Currently the only way
                `_is_sharded = False` is if world_size = 1.
            ``_orig_size``: the size of the original Parameter (before sharding)
        A few attributes are set here:
            ``_local_shard``: a single shard of the parameter. This is needed to
                recover the shard after rebuilding full parameter in forward
                and backward.
            ``_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. It is initialized with the
                appropriate size and then has its storage freed. This will be
                resized in place and only materialized (via all-gather) as needed.
        Another attribute is set by :func:`_register_post_backward_hooks`:
            ``_shard_bwd_hook``: it holds the parameter's AccumulateGrad object
                and the registered post hook handle.
        """
        assert hasattr(p, "_is_sharded") and hasattr(
            p, "_orig_size"
        ), "Parameters should have been sharded during construction."
        if hasattr(p, "_local_shard"):
            return

        # A single shard of the parameters.
        p._local_shard = p.data  # type: ignore[attr-defined]

        # We also maintain a full-sized parameter of type self.compute_dtype.
        # We resize the storage to size 0 at init (here) and only materialize
        # as needed. The storage may contain padding elements so that it is
        # evenly divisible by world_size, although these padding elements will
        # be removed before the relevant computation.
        if p._is_sharded:  # type: ignore[attr-defined]
            p._full_param_padded = torch.zeros(  # type: ignore[attr-defined]
                p.numel() * self.world_size,
                device=self.compute_device,
                dtype=self.compute_dtype,
            )
            _free_storage(p._full_param_padded)  # type: ignore[attr-defined]
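_free_storage is not defined in this snippet. A rough sketch of the idea it implements, freeing a buffer's backing storage so the tensor can later be re-materialized in place (hypothetical helper names, not the library's exact code):

import torch

def free_storage_sketch(t: torch.Tensor) -> None:
    # Shrinking the backing storage to zero elements releases the memory while
    # keeping the tensor object (and its recorded size) alive for later reuse.
    t.storage().resize_(0)

def alloc_storage_sketch(t: torch.Tensor, size: torch.Size) -> None:
    # Re-materialize the buffer in place, e.g. right before an all-gather fills it.
    t.storage().resize_(size.numel())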
Example #4
class HashedEmbeddingBag(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        compression: float = 1. / 64.,
        mode: str = "sum",
        _weight: Optional[torch.Tensor] = None) -> None:
        super(HashedEmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        weight_size = int(num_embeddings * embedding_dim * compression)
        print("Inside HashedEmbeddingBag: ", num_embeddings, embedding_dim, compression, weight_size)
        if _weight is None:
            low = -math.sqrt(1 / self.num_embeddings)
            high = math.sqrt(1 / self.num_embeddings)
            self.hashed_weight = Parameter(torch.rand(weight_size) * (high - low) + low)
            #self.reset_parameters()
            print("Inside HashedEmbeddingBag (after reset): ", num_embeddings, embedding_dim, compression, weight_size, self.hashed_weight.shape)
        else:
            #assert len(_weight.shape) == 1 and _weight.shape[0] == weight_size, \
            #    'Shape of weight does not match num_embeddings and embedding_dim'
            self.hashed_weight = Parameter(_weight)
            self.weight_size = self.hashed_weight.numel()
        self.mode = mode
    """
    def reset_parameters(self) -> None:
        # init.normal_(self.weight)
        W = np.random.uniform(
                low=-np.sqrt(1 / self.num_embeddings), high=np.sqrt(1 / self.num_embeddings), size=(self.hashed_weight.shape[0], )
            ).astype(np.float32)
        self.hashed_weight.data = torch.tensor(W, requires_grad=True)
    """
    def forward(self, indices: torch.Tensor, offsets: Optional[torch.Tensor] = None) -> torch.Tensor:
        return HashedEmbeddingBagFunction.apply(
            self.hashed_weight,
            indices,
            offsets,
            self.mode,
            self.embedding_dim
        )
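A construction-only sketch (not part of the original snippet): the hashed weight buffer holds int(num_embeddings * embedding_dim * compression) values, so a 1000 x 16 table at 1/64 compression keeps only 250 trainable parameters. Calling forward() additionally requires the HashedEmbeddingBagFunction referenced above, which is not shown here.

import torch

bag = HashedEmbeddingBag(num_embeddings=1000, embedding_dim=16, compression=1.0 / 64.0)
print(bag.hashed_weight.numel())  # 250 == int(1000 * 16 / 64)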
Example #5
class InnerProductKernel(Kernel):

	def __init__(self, ndim, input_map=lambda x: x, diagonal=True):
		super(InnerProductKernel, self).__init__(input_map)
		self.diag = diagonal
		if diagonal:
			self.sigma_sqrt = Parameter(torch.FloatTensor(ndim))
		else:
			self.sigma_sqrt = Parameter(torch.FloatTensor(ndim, ndim))

	def reset_parameters(self):
		super(InnerProductKernel, self).reset_parameters()
		self.sigma_sqrt.data.normal_()

	def out_of_bounds(self, vec=None):
		if vec is None:
			return super(InnerProductKernel, self).out_of_bounds(self.log_amp)
		else:
			return not super(InnerProductKernel, self).out_of_bounds(vec[:1])

	def n_params(self):
		return super(InnerProductKernel, self).n_params() + self.sigma_sqrt.numel()

	def param_to_vec(self):
		return torch.cat([self.log_amp.data, self.sigma_sqrt.data])

	def vec_to_param(self, vec):
		self.log_amp.data = vec[0:1]
		self.sigma_sqrt.data = vec[1:]

	def prior(self, vec):
		return super(InnerProductKernel, self).prior(vec[:1]) + smp.normal(vec[1:])

	def forward(self, input1, input2=None):
		stabilizer = 0
		if input2 is None:
			input2 = input1
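			# Small diagonal jitter, scaled by exp(log_amp), keeps the Gram matrix positive definite.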
			stabilizer = Variable(torch.diag(input1.data.new(input1.size(0)).fill_(1e-6 * math.exp(self.log_amp.data[0]))))
		gram_mat = inner_product.InnerProductKernel.apply(self.input_map(input1), self.input_map(input2), self.log_amp, self.sigma_sqrt if self.diag else self.sigma_sqrt.view(int(self.sigma_sqrt.numel() ** 0.5), -1))
		return gram_mat + stabilizer

	def __repr__(self):
		return self.__class__.__name__ + ' (diag=' + ('True' if self.sigma_sqrt.dim()==1 else 'False') + ')'
Example #6
class SparseTensor(nn.Module):
    def __init__(self,
                 tensor_size,
                 initial_sparsity,
                 sub_kernel_granularity=4):
        super(SparseTensor, self).__init__()
        self.s_tensor = Parameter(torch.Tensor(torch.Size(tensor_size)))
        self.initial_sparsity = initial_sparsity
        self.sub_kernel_granularity = sub_kernel_granularity

        assert self.s_tensor.dim() == 2 or self.s_tensor.dim() == 4, \
            "can only do 2D or 4D sparse tensors"

        trailing_dimensions = [1] * (4 - sub_kernel_granularity)
        self.register_buffer(
            'mask', torch.Tensor(*(tensor_size[:sub_kernel_granularity])))

        self.normalize_coeff = np.prod(
            tensor_size[sub_kernel_granularity:]).item()

        self.conv_tensor = self.s_tensor.dim() != 2

        self.mask.zero_()
        flat_mask = self.mask.view(-1)
        indices = np.arange(flat_mask.size(0))
        np.random.shuffle(indices)
        flat_mask[indices[:int((1 - initial_sparsity) * flat_mask.size(0) +
                               0.1)]] = 1

        self.grown_indices = None
        self.init_parameters()
        self.reinitialize_unused()

        self.tensor_sign = torch.sign(self.s_tensor.data.view(-1))

    def reinitialize_unused(self, reinitialize_unused_to_zero=True):
        unused_positions = (self.mask < 0.5)
        if reinitialize_unused_to_zero:
            self.s_tensor.data[unused_positions] = torch.zeros(
                self.s_tensor.data[unused_positions].size()).to(
                    self.s_tensor.device)
        else:
            if self.conv_tensor:
                n = self.s_tensor.size(0) * self.s_tensor.size(
                    2) * self.s_tensor.size(3)
                self.s_tensor.data[unused_positions] = torch.zeros(
                    self.s_tensor.data[unused_positions].size()).normal_(
                        0, math.sqrt(2. / n)).to(self.s_tensor.device)
            else:
                stdv = 1. / math.sqrt(self.s_tensor.size(1))
                self.s_tensor.data[unused_positions] = torch.zeros(
                    self.s_tensor.data[unused_positions].size()).normal_(
                        0, stdv).to(self.s_tensor.device)

    def init_parameters(self):
        stdv = 1 / math.sqrt(np.prod(self.s_tensor.size()[1:]))

        self.s_tensor.data.uniform_(-stdv, stdv)

    def prune_sign_change(self,
                          reinitialize_unused_to_zero=True,
                          enable_print=False):
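        # Drop active weights whose sign flipped since the previous call (i.e. the
        # weight crossed zero), then reinitialize the freed positions.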
        W_flat = self.s_tensor.data.view(-1)

        new_tensor_sign = torch.sign(W_flat)
        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        sign_change_indices = mask_indices[(
            (new_tensor_sign[mask_indices] *
             self.tensor_sign[mask_indices].to(new_tensor_sign.device)) <
            -0.5).nonzero().view(-1)]

        mask_flat[sign_change_indices] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        cutoff = sign_change_indices.numel()

        if enable_print:
            print('pruned {}  connections'.format(cutoff))
        if self.grown_indices is not None and enable_print:
            overlap = np.intersect1d(sign_change_indices.cpu().numpy(),
                                     self.grown_indices.cpu().numpy())
            print('pruned {} ({} %) just grown weights'.format(
                overlap.size, overlap.size * 100.0 / self.grown_indices.size(0)
                if self.grown_indices.size(0) > 0 else 0.0))

        self.tensor_sign = new_tensor_sign
        return sign_change_indices

    def prune_small_connections(self,
                                prune_fraction,
                                reinitialize_unused_to_zero=True):
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            W_flat = self.s_tensor.abs().sum(
                list(np.arange(self.sub_kernel_granularity,
                               4))).view(-1) / self.normalize_coeff
        else:
            W_flat = self.s_tensor.data.view(-1)

        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]

        sorted_W_indices = torch.sort(torch.abs(W_masked))[1]

        cutoff = int(prune_fraction * W_masked.numel()) + 1

        mask_flat[mask_indices[sorted_W_indices[:cutoff]]] = 0
        self.reinitialize_unused(reinitialize_unused_to_zero)

        #        print('pruned {}  connections'.format(cutoff))
        #        if self.grown_indices is not None:
        #            overlap = np.intersect1d(mask_indices[sorted_W_indices[:cutoff]].cpu().numpy(),self.grown_indices.cpu().numpy())
        #print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0)))

        return mask_indices[sorted_W_indices[:cutoff]]

    def prune_threshold(self, threshold, reinitialize_unused_to_zero=True):
        if self.conv_tensor and self.sub_kernel_granularity < 4:
            W_flat = self.s_tensor.abs().sum(
                list(np.arange(self.sub_kernel_granularity,
                               4))).view(-1) / self.normalize_coeff
        else:
            W_flat = self.s_tensor.data.view(-1)

        mask_flat = self.mask.view(-1)

        mask_indices = torch.nonzero(mask_flat > 0.5).view(-1)

        W_masked = W_flat[mask_indices]

        prune_indices = (W_masked.abs() < threshold).nonzero().view(-1)

        if mask_indices.size(0) == prune_indices.size(0):
            print('removing all. keeping one')
            prune_indices = prune_indices[1:]

        mask_flat[mask_indices[prune_indices]] = 0

        #       if mask_indices.numel() > 0 :
        #           print('pruned {}/{}({:.2f})  connections'.format(prune_indices.numel(),mask_indices.numel(),prune_indices.numel()/mask_indices.numel()))

        #        if self.grown_indices is not None and self.grown_indices.size(0) != 0 :
        #            overlap = np.intersect1d(mask_indices[prune_indices].cpu().numpy(),self.grown_indices.cpu().numpy())
        #            print('pruned {} ({} %) just grown weights'.format(overlap.size,overlap.size * 100.0 / self.grown_indices.size(0)))

        self.reinitialize_unused(reinitialize_unused_to_zero)

        return mask_indices[prune_indices]

    def grow_random(self,
                    grow_fraction,
                    pruned_indices=None,
                    enable_print=False,
                    n_to_add=None):
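        # Activate a random subset of currently inactive positions (excluding the
        # just-pruned ones when `pruned_indices` is given) and remember them in
        # `grown_indices`.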
        mask_flat = self.mask.view(-1)
        mask_zero_indices = torch.nonzero(mask_flat < 0.5).view(-1)
        if pruned_indices is not None:
            cutoff = pruned_indices.size(0)
            mask_zero_indices = torch.Tensor(
                np.setdiff1d(mask_zero_indices.cpu().numpy(),
                             pruned_indices.cpu().numpy())).long().to(
                                 mask_zero_indices.device)
        else:
            cutoff = int(grow_fraction * mask_zero_indices.size(0))

        if n_to_add is not None:
            cutoff = n_to_add

        if mask_zero_indices.numel() < cutoff:
            print('******no place to grow {} connections, growing {} instead'.
                  format(cutoff, mask_zero_indices.numel()))
            cutoff = mask_zero_indices.numel()

        if enable_print:
            print('grown {}  connections'.format(cutoff))

        self.grown_indices = mask_zero_indices[torch.randperm(
            mask_zero_indices.numel())][:cutoff]
        mask_flat[self.grown_indices] = 1

        return cutoff

    def get_sparsity(self):
        active_elements = self.mask.sum() * np.prod(
            self.s_tensor.size()[self.sub_kernel_granularity:]).item()
        return (active_elements, 1 - active_elements / self.s_tensor.numel())

    def forward(self):
        if self.conv_tensor:
            return self.mask.view(
                *(self.mask.size() + (1, ) *
                  (4 - self.sub_kernel_granularity))) * self.s_tensor
        else:
            return self.mask * self.s_tensor

    def extra_repr(self):
        return 'full tensor size : {} , sparsity mask : {} , sub kernel granularity : {}'.format(
            self.s_tensor.size(), self.get_sparsity(),
            self.sub_kernel_granularity)
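A usage sketch of the prune/grow cycle implemented above (not part of the original snippet; it assumes the imports the class itself relies on, i.e. torch, numpy and math, are available). With the default sub_kernel_granularity=4 every individual weight of this 2D tensor is a prunable unit:

sparse = SparseTensor([64, 32], initial_sparsity=0.9)
print(sparse.get_sparsity())   # (number of active weights, achieved sparsity close to 0.9)

dropped = sparse.prune_small_connections(prune_fraction=0.2)      # drop the 20% smallest active weights
sparse.grow_random(grow_fraction=0.0, n_to_add=dropped.numel())   # regrow the same number at random positions
dense_view = sparse()          # masked tensor: zeros at inactive positions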
Example #7
    def _init_param_attributes(self, p: Parameter) -> None:
        """
        We manage several attributes on each Parameter instance. The first two
        are set by :func:`_shard_parameters`:
            ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param). Currently the only way
                `_is_sharded = False` is if world_size = 1.
            ``_orig_size``: the size of the original Parameter (before sharding)
        A few attributes are set here:
            ``_local_shard``: a single shard of the parameter. This is needed to
                recover the shard after rebuilding full parameter in forward
                and backward.
            ``_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. It is initialized with the
                appropriate size and then has its storage freed. This will be
                resized in place and only materialized (via all-gather) as needed.
        Another attribute is set by :func:`_register_post_backward_hooks`:
            ``_shard_bwd_hook``: it holds the parameter's AccumulateGrad object
                and the registered post hook handle.
        """
        assert hasattr(p, "_is_sharded") and hasattr(
            p, "_orig_size"
        ), "Parameters should have been sharded during construction."
        if hasattr(p, "_local_shard"):
            # If CPU offloading, p._local_shard should have been placed on CPU
            # during its first lazy construction.
            if self.cpu_offload.offload_params:
                assert p._local_shard.device == torch.device(  # type: ignore[attr-defined]
                    "cpu"
                ), (
                    "Expected p._local_shard to be on CPU, "  # type: ignore[attr-defined]
                    f"but it's on {p._local_shard.device}"  # type: ignore[attr-defined]
                )
            return

        # A single shard of the parameters. Also makes p._local_shard to be on
        # CPU if we are CPU offloading, since p.data would be on CPU during
        # init.
        if self.cpu_offload.offload_params:
            assert p.device == torch.device(
                "cpu"
            ), "Expected param to be on CPU when cpu_offloading is enabled."
        p._local_shard = p.data  # type: ignore[attr-defined]
        # If CPU offloading, pin the memory to enable faster CPU -> GPU device
        # transfer.
        if self.cpu_offload.offload_params:
            assert p._local_shard.device == torch.device(
                "cpu")  # type: ignore[attr-defined]
            p._local_shard = p._local_shard.pin_memory()  # type: ignore[attr-defined]
            # When offloading parameters, also move the grad shard to CPU during
            # backward pass. In this case, it's important to pre-allocate the
            # CPU grad shard in pinned memory so that we can do a non-blocking
            # transfer.
            p._cpu_grad = torch.zeros_like(  # type: ignore[attr-defined]
                p, device=torch.device("cpu")).pin_memory()

        # We also maintain a full-sized parameter of type self.compute_dtype.
        # We resize the storage to size 0 at init (here) and only materialize
        # as needed. The storage may contain padding elements so that it is
        # evenly divisible by world_size, although these padding elements will
        # be removed before the relevant computation.
        if p._is_sharded:  # type: ignore[attr-defined]
            p._full_param_padded = torch.zeros(  # type: ignore[attr-defined]
                p.numel() * self.world_size,
                device=self.compute_device,
                dtype=self.compute_dtype,
            )
            _free_storage(p._full_param_padded)  # type: ignore[attr-defined]
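The pre-allocated, pinned _cpu_grad exists so the gradient shard can be copied from GPU to CPU without blocking; a rough sketch of that copy (hypothetical variable names, not the library's actual hook code):

# Assumes `gpu_grad_shard` is the reduced gradient shard on the GPU and `p._cpu_grad`
# is the pinned CPU buffer allocated above; a pinned destination lets the
# device-to-host copy be issued asynchronously on the current CUDA stream.
p._cpu_grad.copy_(gpu_grad_shard, non_blocking=True)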