Пример #1
0
    def _binopt(self, other, op):
        """
        Do the binary operation `op` on two sparse matrices, using
        fast_csr_matrix for the result only when other is also a
        fast_csr_matrix.

        Parameters
        ----------
        other : sparse matrix or array-like
            Right-hand operand. Anything that is not already a
            ``fast_csr_matrix`` is converted with ``csr_matrix(other)``.
        op : str
            Sparsetools operator tag such as ``'_plus_'`` or ``'_ne_'``. It
            is spliced between two copies of ``self.format`` to select the
            kernel, e.g. ``'csr' + '_plus_' + 'csr'`` -> ``csr_plus_csr``.

        Returns
        -------
        fast_csr_matrix or csr_matrix
            A ``fast_csr_matrix`` when ``other`` is a ``fast_csr_matrix``
            and ``op`` is not a comparison; otherwise a ``csr_matrix``.
            Comparison operators (``_ne_``, ``_lt_``, ``_gt_``, ``_le_``,
            ``_ge_``) produce boolean data.
        """
        # Remember the operand's type before any conversion; csr_matrix(...)
        # always yields a plain csr_matrix.
        other_is_fast = isinstance(other, fast_csr_matrix)
        if not other_is_fast:
            other = csr_matrix(other)

        # Kernel name, e.g. csr_plus_csr, csr_minus_csr, etc.
        fn = getattr(_sparsetools, self.format + op + self.format)

        # Upper bound on the result's stored entries; the kernel reports the
        # actual count in indptr[-1].
        maxnnz = self.nnz + other.nnz
        idx_dtype = get_index_dtype((self.indptr, self.indices,
                                     other.indptr, other.indices),
                                    maxval=maxnnz)
        indptr = np.empty(self.indptr.shape, dtype=idx_dtype)
        indices = np.empty(maxnnz, dtype=idx_dtype)

        # Comparison kernels write booleans; arithmetic kernels write the
        # common upcast dtype of both operands.
        is_bool_op = op in ('_ne_', '_lt_', '_gt_', '_le_', '_ge_')
        if is_bool_op:
            data = np.empty(maxnnz, dtype=np.bool_)
        else:
            data = np.empty(maxnnz, dtype=upcast(self.dtype, other.dtype))

        fn(self.shape[0], self.shape[1],
           np.asarray(self.indptr, dtype=idx_dtype),
           np.asarray(self.indices, dtype=idx_dtype),
           self.data,
           np.asarray(other.indptr, dtype=idx_dtype),
           np.asarray(other.indices, dtype=idx_dtype),
           other.data,
           indptr, indices, data)

        actual_nnz = indptr[-1]
        indices = indices[:actual_nnz]
        data = data[:actual_nnz]
        if actual_nnz < maxnnz // 2:
            # Too much waste: copy so the oversized buffers can be freed.
            indices = indices.copy()
            data = data.copy()

        if other_is_fast and not is_bool_op:
            return fast_csr_matrix((data, indices, indptr),
                                   dtype=data.dtype, shape=self.shape)
        return csr_matrix((data, indices, indptr),
                          dtype=data.dtype, shape=self.shape)
Пример #2
0
 def test_upcast(self):
     assert_equal(sputils.upcast('intc'), np.intc)
     assert_equal(sputils.upcast('int32', 'float32'), np.float64)
     assert_equal(sputils.upcast('bool', complex, float), np.complex128)
     assert_equal(sputils.upcast('i', 'd'), np.float64)
Пример #3
0
    def _mul_sparse_matrix(self, other):
        """
        Do the sparse matrix mult returning fast_csr_matrix only
        when other is also fast_csr_matrix.

        Parameters
        ----------
        other : sparse matrix
            Right-hand operand. A ``fast_csr_matrix`` is multiplied with
            ``zcsr_mult``; anything else is converted with ``csr_matrix``
            and multiplied with scipy's ``_sparsetools`` kernels.

        Returns
        -------
        fast_csr_matrix or csr_matrix
            The product, of shape ``(self.shape[0], other.shape[1])``. The
            fast path returns whatever ``zcsr_mult`` returns (per the summary
            above, a fast_csr_matrix); the sparsetools path returns a plain
            ``csr_matrix``.
        """
        M, _ = self.shape
        _, N = other.shape

        if isinstance(other, fast_csr_matrix):
            return zcsr_mult(self, other, sorted=1)

        major_axis = self._swap((M, N))[0]

        other = csr_matrix(other)  # convert to this format
        idx_dtype = get_index_dtype((self.indptr, self.indices,
                                     other.indptr, other.indices),
                                    maxval=M*N)

        # scipy 1.5 renamed the older csr_matmat_pass1 to the much more
        # descriptive csr_matmat_maxnnz, but also changed the call and logic
        # structure of constructing the indices. Only the kernel lookup is
        # used to detect the scipy version, so an AttributeError raised by
        # anything else is not mistaken for "old scipy".
        maxnnz_fn = getattr(_sparsetools, self.format + '_matmat_maxnnz',
                            None)
        if maxnnz_fn is not None:
            # Newer kernel: returns the output nnz; indptr is filled later by
            # the csr_matmat call below.
            nnz = maxnnz_fn(M, N,
                            np.asarray(self.indptr, dtype=idx_dtype),
                            np.asarray(self.indices, dtype=idx_dtype),
                            np.asarray(other.indptr, dtype=idx_dtype),
                            np.asarray(other.indices, dtype=idx_dtype))
            indptr = None
        else:
            # Older kernel: pass1 fills the output indptr and its last entry
            # is the output nnz.
            indptr = np.empty(major_axis + 1, dtype=idx_dtype)
            pass1_fn = getattr(_sparsetools, self.format + '_matmat_pass1')
            pass1_fn(M, N,
                     np.asarray(self.indptr, dtype=idx_dtype),
                     np.asarray(self.indices, dtype=idx_dtype),
                     np.asarray(other.indptr, dtype=idx_dtype),
                     np.asarray(other.indices, dtype=idx_dtype),
                     indptr)
            nnz = indptr[-1]

        # Re-pick the index dtype now that the exact output size is known.
        idx_dtype = get_index_dtype((self.indptr, self.indices,
                                     other.indptr, other.indices),
                                    maxval=nnz)
        if indptr is None:
            indptr = np.empty(major_axis + 1, dtype=idx_dtype)
        else:
            # pass1's indptr was allocated with the M*N-based dtype; cast it
            # so every index array handed to the second kernel shares the
            # same integer dtype as `indices` (values are <= nnz, so they fit).
            indptr = np.asarray(indptr, dtype=idx_dtype)

        indices = np.empty(nnz, dtype=idx_dtype)
        data = np.empty(nnz, dtype=upcast(self.dtype, other.dtype))

        matmat_fn = getattr(_sparsetools, self.format + '_matmat', None)
        if matmat_fn is None:
            # Pre-1.5 name of the same second-pass kernel.
            matmat_fn = getattr(_sparsetools, self.format + '_matmat_pass2')
        matmat_fn(M, N,
                  np.asarray(self.indptr, dtype=idx_dtype),
                  np.asarray(self.indices, dtype=idx_dtype),
                  self.data,
                  np.asarray(other.indptr, dtype=idx_dtype),
                  np.asarray(other.indices, dtype=idx_dtype),
                  other.data,
                  indptr, indices, data)
        return csr_matrix((data, indices, indptr), shape=(M, N))