Example #1
import itertools

import torch
from torch.nn.parameter import Parameter


class SparseFCLayer(torch.nn.Module):
    """
    Sparse matrix multiplication supports only (sparse, dense). That's why the matrix multiplication is reversed here.
    """
    def __init__(self,
                 weights: torch.sparse.FloatTensor,
                 biases=None,
                 activation=None):
        super(SparseFCLayer, self).__init__()
        if not weights.is_sparse:
            raise ValueError("Weights (the left operand) must be sparse")
        if not weights.is_coalesced():
            raise ValueError("Weights (the left operand) must be coalesced")

        # Dimensions are reversed relative to nn.Linear: weights are (n_outputs, n_inputs)
        self.n_inputs = weights.size(1)
        self.n_outputs = weights.size(0)
        self._activation = activation

        self._weights = Parameter(weights)
        if biases is None:
            self._biases = Parameter(torch.zeros(self.n_outputs, 1))
        else:
            self._biases = Parameter(biases)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this
                module
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys,
                 unexpected_keys, error_msgs)

        local_name_params = itertools.chain(self._parameters.items(),
                                            self._buffers.items())
        local_state = {
            k: v.data
            for k, v in local_name_params if v is not None
        }

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]

                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append(
                        'size mismatch for {}: copying a param with shape {} from checkpoint, '
                        'the shape in current model is {}.'.format(
                            key, input_param.shape, param.shape))
                    continue

                if isinstance(input_param, Parameter):
                    # backwards compatibility for serialized parameters
                    input_param = input_param.data
                try:
                    # Copy into the stored Parameter/buffer itself rather than
                    # the detached local view; presumably the default
                    # param.copy_(input_param) path misbehaves for sparse
                    # parameters.
                    with torch.no_grad():
                        if name in self._parameters.keys():
                            self._parameters[name].copy_(input_param)
                        elif name in self._buffers.keys():
                            self._buffers[name].copy_(input_param)
                except Exception:
                    error_msgs.append(
                        'While copying the parameter named "{}" (shape {} in '
                        'the model, {} in the checkpoint), an exception '
                        'occurred.'.format(
                            key, param.size(), input_param.size()))
            elif strict:
                missing_keys.append(key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix):
                    input_name = key[len(prefix):]
                    input_name = input_name.split(
                        '.', 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def _sparse_masked_select_abs(self,
                                  sparse_tensor: torch.sparse.FloatTensor,
                                  thr):
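        # Keep only entries with |value| >= thr; rebuilding the COO tensor
        # (instead of zeroing values in place) actually shrinks the stored nnz.
        # _indices()/_values() are safe here because the tensor is coalesced.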
        indices = sparse_tensor._indices()
        values = sparse_tensor._values()
        prune_mask = torch.abs(values) >= thr
        return torch.sparse_coo_tensor(
            indices=indices.masked_select(prune_mask).reshape(2, -1),
            values=values.masked_select(prune_mask),
            size=[self.n_outputs, self.n_inputs]).coalesce()

    def prune_by_threshold(self, thr):
        self._weights = Parameter(
            self._sparse_masked_select_abs(self._weights, thr))

    def prune_by_rank(self, rank):
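        # The rank-th smallest magnitude becomes the threshold, so roughly
        # the `rank` smallest-magnitude weights are removed (ties survive).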
        weights_val = self._weights._values()
        sorted_abs_weights = torch.sort(torch.abs(weights_val))[0]
        thr = sorted_abs_weights[rank]
        self.prune_by_threshold(thr)

    def prune_by_pct(self, pct):
        prune_idx = int(self._weights._nnz() * pct)
        self.prune_by_rank(prune_idx)

    def forward(self, inputs: torch.Tensor):
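        # inputs are column-major, shaped (n_inputs, batch):
        # ret = biases + weights @ inputs, with the sparse weights on the left.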
        ret = torch.sparse.addmm(self._biases, self._weights, inputs)
        return ret if self._activation is None else self._activation(ret)

    @property
    def weights(self):
        return self._weights

    @property
    def biases(self):
        return self._biases

    @property
    def activation(self):
        return self._activation

    @property
    def n_weights(self):
        return self._weights._nnz()

    def __str__(self):
        return "Sparse layer with size {} and activation {}".format(
            (self.n_outputs, self.n_inputs), self._activation)
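
A minimal usage sketch (not part of the original example; the weight values and sizes are made up) showing the column-major input convention from the docstring, followed by magnitude pruning:

dense = torch.tensor([[0.0, 1.5, 0.0, -0.2],
                      [0.7, 0.0, 0.0, 0.0],
                      [0.0, 0.0, -2.1, 0.3]])
layer = SparseFCLayer(dense.to_sparse().coalesce(), activation=torch.relu)

x = torch.randn(4, 8)      # (n_inputs, batch), NOT (batch, n_inputs)
y = layer(x)               # shape (n_outputs, batch) == (3, 8)

layer.prune_by_pct(0.5)    # drop the smaller half of the nonzero weights
print(layer.n_weights)     # nnz after pruning: 3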
Example #2
import torch
from torch import nn, sparse
from torch.nn.parameter import Parameter

# Note: AddmmFunction is a project-local autograd Function defined elsewhere
# in the source; a hypothetical sketch of it is given after the class.


class SparseLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 weight: sparse.FloatTensor,
                 bias,
                 mask,
                 transpose=False):
        super(SparseLinear, self).__init__()
        if not weight.is_sparse:
            raise ValueError("Weight must be sparse")
        elif weight._nnz() > 0 and not weight.is_coalesced():
            raise ValueError("Weight must be coalesced")

        self.transpose = transpose

        self.in_features = weight.size(1)
        self.out_features = weight.size(0)
        self.mask = mask.clone()

        # Register the sparse weight as a Parameter so the optimizer sees it,
        # but freeze it: gradients are routed through the dense placeholder.
        self.weight = Parameter(weight.data.clone(), requires_grad=False)
        # Create the placeholder directly on the weight's device; moving it
        # after creation would turn it into a non-leaf tensor.
        self.dense_weight_placeholder = Parameter(
            torch.empty(size=self.weight.size(), device=self.weight.device))
        self.dense_weight_placeholder.is_placeholder = True

        # Link the placeholder and the mask to the weight so downstream code
        # can reach them from the parameter alone.
        self.weight.dense = self.dense_weight_placeholder
        self.weight.mask = self.mask
        self.weight.is_sparse_param = True

        if bias is None:
            self.register_parameter('bias', None)
        else:
            assert bias.size() == torch.Size((weight.size(0), 1))
            self.bias = Parameter(bias.data.clone())


    def _sparse_masked_select_abs(self, sparse_tensor: sparse.FloatTensor,
                                  thr):
        indices = sparse_tensor._indices()
        values = sparse_tensor._values()
        prune_mask = torch.abs(values) >= thr
        return torch.sparse_coo_tensor(
            indices=indices.masked_select(prune_mask).reshape(2, -1),
            values=values.masked_select(prune_mask),
            size=[self.out_features, self.in_features]).coalesce()

    def prune_by_threshold(self, thr):
        self.weight = Parameter(
            self._sparse_masked_select_abs(self.weight, thr))

    def prune_by_rank(self, rank):
        weight_val = self.weight._values()
        sorted_abs_weight = torch.sort(torch.abs(weight_val))[0]
        thr = sorted_abs_weight[rank]
        self.prune_by_threshold(thr)

    def prune_by_pct(self, pct):
        if pct == 0:
            return
        prune_idx = int(self.weight._nnz() * pct)
        self.prune_by_rank(prune_idx)

    def move_data(self, device: torch.device):
        self.weight = self.weight.to(device)

    def forward(self, inp: torch.Tensor):
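        # With transpose=True the input arrives row-major (batch, in_features);
        # it is transposed to column-major for the sparse matmul and the
        # result is transposed back.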
        if self.transpose:
            return AddmmFunction.apply(self.bias, self.weight,
                                       self.dense_weight_placeholder,
                                       inp.t()).t()
        else:
            return AddmmFunction.apply(self.bias, self.weight,
                                       self.dense_weight_placeholder, inp)

    @property
    def num_weight(self) -> int:
        return self.weight._nnz()

    def __repr__(self):
        return "SparseLinear(in_features={}, out_features={}, bias={}, transpose = {})".format(
            self.in_features, self.out_features, self.bias is not None,
            self.transpose)

    def __str__(self):
        return self.__repr__()
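
AddmmFunction is referenced above but defined elsewhere in the project. The following is only a hypothetical reconstruction inferred from the call site AddmmFunction.apply(bias, weight, dense_weight_placeholder, inp): forward computes bias + weight @ inp with the sparse weight on the left, and backward attributes the dense weight gradient to the placeholder slot, since the sparse weight itself is frozen with requires_grad=False:

class AddmmFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, bias, weight, placeholder, inp):
        ctx.save_for_backward(weight, inp)
        ctx.has_bias = bias is not None
        if bias is None:
            return torch.sparse.mm(weight, inp)
        return torch.sparse.addmm(bias, weight, inp)

    @staticmethod
    def backward(ctx, grad_output):
        weight, inp = ctx.saved_tensors
        # Bias broadcasts over the batch dimension, so its gradient sums over it.
        grad_bias = grad_output.sum(dim=1, keepdim=True) if ctx.has_bias else None
        # Dense gradient of the loss w.r.t. the weight matrix; it is returned
        # in the placeholder's slot, while the frozen sparse weight gets None.
        grad_dense = grad_output.mm(inp.t())
        grad_inp = torch.sparse.mm(weight.t(), grad_output)
        return grad_bias, None, grad_dense, grad_inp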