Example #1
def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    if sparse.ndimension() > 2:
        batch_size, n_rows, n_cols = sparse.size()
        batch_assignment = sparse._indices()[0]
        # Offset each batch's row/column indices by its batch index so the batch
        # of sparse matrices becomes one block-diagonal 2D sparse matrix
        # (legacy in-place add: indices[i] += scale * batch_assignment)
        indices = sparse._indices()[1:].clone()
        indices[0].add_(n_rows, batch_assignment)
        indices[1].add_(n_cols, batch_assignment)
        sparse_2d = sparse.new(
            indices, sparse._values(),
            torch.Size((batch_size * n_rows, batch_size * n_cols)))

        if dense.size(0) == 1:
            dense = dense.repeat(batch_size, 1, 1)
        dense_2d = dense.contiguous().view(batch_size * n_cols, -1)
        res = torch.dsmm(sparse_2d, dense_2d)
        res = res.view(batch_size, n_rows, -1)
        return res
    elif dense.ndimension() == 3:
        batch_size, _, n_cols = dense.size()
        res = torch.dsmm(
            sparse,
            dense.transpose(0, 1).contiguous().view(-1, batch_size * n_cols))
        res = res.view(-1, batch_size, n_cols)
        res = res.transpose(0, 1).contiguous()
        return res
    else:
        return torch.dsmm(sparse, dense)
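A minimal usage sketch for the bdsmm above, assuming an older PyTorch (circa 0.4) where torch.dsmm and the legacy Tensor.new(indices, values, size) constructor still exist; on current releases torch.sparse.mm plays the same role:

import torch

# Hypothetical shapes: a batch of 2 sparse (3 x 3) matrices times 2 dense (3 x 4) matrices
i = torch.tensor([[0, 0, 1, 1],   # batch index
                  [0, 2, 1, 2],   # row index
                  [1, 0, 2, 2]])  # column index
v = torch.tensor([1.0, 2.0, 3.0, 4.0])
sparse = torch.sparse_coo_tensor(i, v, (2, 3, 3)).coalesce()
dense = torch.randn(2, 3, 4)

res = bdsmm(sparse, dense)                      # (2, 3, 4)
expected = torch.bmm(sparse.to_dense(), dense)  # dense reference
assert torch.allclose(res, expected)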
Example #2
        def closure(left_vectors, right_vectors):
            if left_vectors.ndimension() == 1:
                left_factor = left_vectors.unsqueeze(0)
                right_factor = right_vectors.unsqueeze(0)
            else:
                left_factor = left_vectors
                right_factor = right_vectors
            if len(args) == 1:
                columns, = args
                return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),
            elif len(args) == 3:
                columns, W_left, W_right = args
                left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
                right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()

                res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
                return res, None, None
            elif len(args) == 4:
                columns, W_left, W_right, added_diag, = args
                diag_grad = columns.new(len(added_diag)).zero_()
                diag_grad[0] = (left_factor * right_factor).sum()

                left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
                right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()

                res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
                return res, None, None, diag_grad
Example #3

def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    # Make the batch sparse matrix into a block-diagonal matrix
    if sparse.ndimension() > 2:
        # Expand the tensors to account for broadcasting
        output_shape = _matmul_broadcast_shape(sparse.shape, dense.shape)
        expanded_sparse_shape = output_shape[:-2] + sparse.shape[-2:]
        unsqueezed_sparse_shape = [1 for _ in range(len(output_shape) - sparse.dim())] + list(sparse.shape)
        repeat_sizes = tuple(
            output_size // sparse_size
            for output_size, sparse_size in zip(expanded_sparse_shape, unsqueezed_sparse_shape)
        )
        sparse = sparse_repeat(sparse, *repeat_sizes)
        dense = dense.expand(*output_shape[:-2], dense.size(-2), dense.size(-1))

        # Figure out how much need to be added to the row/column indices to create
        # a block-diagonal matrix
        *batch_shape, num_rows, num_cols = sparse.shape
        batch_size = torch.Size(batch_shape).numel()
        batch_multiplication_factor = torch.tensor(
            [torch.Size(batch_shape[i + 1 :]).numel() for i in range(len(batch_shape))],
            dtype=torch.long,
            device=sparse.device,
        )
        if batch_multiplication_factor.is_cuda:
            batch_assignment = (sparse._indices()[:-2].float().t() @ batch_multiplication_factor.float()).long()
        else:
            batch_assignment = sparse._indices()[:-2].t() @ batch_multiplication_factor

        # Create block-diagonal sparse tensor
        indices = sparse._indices()[-2:].clone()
        indices[0].add_(batch_assignment, alpha=num_rows)
        indices[1].add_(batch_assignment, alpha=num_cols)
        sparse_2d = torch.sparse_coo_tensor(
            indices,
            sparse._values(),
            torch.Size((batch_size * num_rows, batch_size * num_cols)),
            dtype=sparse._values().dtype,
            device=sparse._values().device,
        )

        dense_2d = dense.reshape(batch_size * num_cols, -1)
        res = torch.dsmm(sparse_2d, dense_2d)
        res = res.view(*batch_shape, num_rows, -1)
        return res

    elif dense.dim() > 2:
        *batch_shape, num_rows, num_cols = dense.size()
        batch_size = torch.Size(batch_shape).numel()
        dense = dense.view(batch_size, num_rows, num_cols)
        res = torch.dsmm(sparse, dense.transpose(0, 1).reshape(-1, batch_size * num_cols))
        res = res.view(-1, batch_size, num_cols)
        res = res.transpose(0, 1).reshape(*batch_shape, -1, num_cols)
        return res

    else:
        return torch.dsmm(sparse, dense)
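Both versions of bdsmm rely on the same identity: a batch of matrix products equals a single 2D product against a block-diagonal matrix. A standalone dense sketch of that identity (torch.block_diag assumed available, PyTorch 1.8+):

import torch

A = torch.randn(3, 4, 5)  # batch of 3 (4 x 5) matrices
B = torch.randn(3, 5, 2)

batched = torch.bmm(A, B)  # (3, 4, 2)
# Stack the A blocks on a diagonal, flatten B, multiply once, then re-batch
blocked = (torch.block_diag(*A) @ B.reshape(3 * 5, 2)).view(3, 4, 2)
assert torch.allclose(batched, blocked, atol=1e-5)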
Example #4
def interpolated_sym_toeplitz_matmul(toeplitz_column,
                                     vector,
                                     W_left=None,
                                     W_right=None,
                                     noise_diag=None):
    """
    Given an interpolated symmetric Toeplitz matrix W_left*T*W_right, plus possibly an additional
    diagonal component s*I, compute its product with a vector or matrix.

    Args:
        - toeplitz_column (vector matrix) - First column of the symmetric Toeplitz matrix T
        - W_left (sparse matrix nxm) - Left interpolation matrix
        - W_right (sparse matrix pxm) - Right interpolation matrix
        - vector (matrix pxk) - Vector (k=1) or matrix (k>1) to multiply WTW with
        - noise_diag (vector p) - If not None, add (s*I)vector to WTW at the end.

    Returns:
        - matrix nxk
    """
    noise_term = None
    ndim = vector.ndimension()

    if ndim == 1:
        vector = vector.unsqueeze(1)

    if noise_diag is not None:
        noise_term = noise_diag.unsqueeze(-1).expand_as(vector) * vector

    if W_left is not None:
        # Get W_{r}^{T}vector
        if W_left.ndimension() == 3:  # Batch mode:
            Wt_times_v = bdsmm(W_right.transpose(1, 2), vector)
            # Get (TW_{r}^{T})vector
            TWt_v = sym_toeplitz_matmul(toeplitz_column, Wt_times_v)
            # Get (W_{l}TW_{r}^{T})vector
            WTWt_v = bdsmm(W_left, TWt_v)
        else:
            Wt_times_v = torch.dsmm(W_right.t(), vector)
            # Get (TW_{r}^{T})vector
            TWt_v = sym_toeplitz_matmul(toeplitz_column, Wt_times_v)
            # Get (W_{l}TW_{r}^{T})vector
            WTWt_v = torch.dsmm(W_left, TWt_v)
    else:
        WTWt_v = sym_toeplitz_matmul(toeplitz_column, vector)

    if noise_term is not None:
        # Get (W_{l}TW_{r}^{T} + \sigma^{2}I)vector
        WTWt_v = WTWt_v + noise_term

    if ndim == 1:
        WTWt_v = WTWt_v.squeeze(1)

    return WTWt_v
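Written out densely for a small case, the function above computes W_left @ T @ W_right.T @ v (plus s * v). A sketch of just that algebra, with made-up dense stand-ins for the sparse and Toeplitz structures:

import torch

n, m, k = 5, 6, 2
W_left = torch.randn(n, m)
W_right = torch.randn(n, m)
T = torch.randn(m, m)
T = T + T.t()              # symmetric, standing in for a sym-Toeplitz matrix
s = torch.rand(n)          # noise_diag
v = torch.randn(n, k)

# Factored product, as in interpolated_sym_toeplitz_matmul
result = W_left @ (T @ (W_right.t() @ v)) + s.unsqueeze(-1) * v
# Explicit reference
reference = (W_left @ T @ W_right.t() + torch.diag(s)) @ v
assert torch.allclose(result, reference, atol=1e-4)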
Example #5
def interpolated_sym_toeplitz_mul(toeplitz_column,
                                  vector,
                                  W_left=None,
                                  W_right=None,
                                  noise_diag=None):
    """
    Given an interpolated symmetric Toeplitz matrix W_left*T*W_right, plus possibly an additional
    diagonal component s*I, compute its product with a vector or matrix.

    Args:
        - toeplitz_column (vector matrix) - First column of the symmetric Toeplitz matrix T
        - W_left (sparse matrix nxm) - Left interpolation matrix
        - W_right (sparse matrix pxm) - Right interpolation matrix
        - vector (matrix pxk) - Vector (k=1) or matrix (k>1) to multiply WTW with
        - noise_diag (vector p) - If not None, add (s*I)vector to WTW at the end.

    Returns:
        - matrix nxk - The result of multiplying (WTW + sI)vector if noise_diag exists, or (WTW)vector otherwise.
    """
    noise_term = None
    if vector.ndimension() == 1:
        if noise_diag is not None:
            noise_term = noise_diag.expand_as(vector) * vector
        vector = vector.unsqueeze(1)
        mul_func = utils.toeplitz.toeplitz_mv
    else:
        if noise_diag is not None:
            noise_term = noise_diag.unsqueeze(1).expand_as(vector) * vector
        mul_func = utils.toeplitz.toeplitz_mm

    if W_left is not None:
        # Get W_{r}^{T}vector
        Wt_times_v = torch.dsmm(W_right.t(), vector)
        # Get (TW_{r}^{T})vector
        TWt_v = mul_func(toeplitz_column, toeplitz_column,
                         Wt_times_v.squeeze())

        if TWt_v.ndimension() == 1:
            TWt_v.unsqueeze_(1)

        # Get (W_{l}TW_{r}^{T})vector
        WTWt_v = torch.dsmm(W_left, TWt_v).squeeze()
    else:
        WTWt_v = mul_func(toeplitz_column, toeplitz_column, vector)

    if noise_term is not None:
        # Get (W_{l}TW_{r}^{T} + \sigma^{2}I)vector
        WTWt_v = WTWt_v + noise_term

    return WTWt_v
Example #6
        def test_shape(di, dj, dk):
            x = self._gen_sparse(2, 20, [di, dj])[0]
            y = self.randn(dj, dk)

            res = torch.dsmm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(res, expected)
Example #7
        def test_shape(di, dj, dk):
            x = self._gen_sparse(2, 20, [di, dj])[0]
            y = self.randn(dj, dk)

            res = torch.dsmm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(res, expected)
Example #8

    def backward(self, grad_output):
        grad_output_value = grad_output.squeeze()[0]
        toeplitz_column, labels, noise_diag = self.saved_tensors

        mat_inv_y = self.mat_inv_y
        mv_closure = self.mv_closure

        mat_grad = None
        y_grad = None
        noise_grad = None

        if self.needs_input_grad[0]:
            y_mat_inv_W_left = torch.dsmm(self.W_left.t(),
                                          mat_inv_y.unsqueeze(1)).t()
            W_right_mat_inv_y = torch.dsmm(self.W_right.t(),
                                           mat_inv_y.unsqueeze(1))
            quad_form_part = sym_toeplitz_derivative_quadratic_form(
                y_mat_inv_W_left.squeeze(), W_right_mat_inv_y.squeeze())
            log_det_part = torch.zeros(len(toeplitz_column))
            sample_matrix = torch.sign(
                torch.randn(len(labels), self.num_samples))

            left_vectors = torch.dsmm(
                self.W_left.t(),
                LinearCG().solve(mv_closure, sample_matrix)).t()
            right_vectors = torch.dsmm(self.W_right.t(), sample_matrix).t()

            for left_vector, right_vector in zip(left_vectors, right_vectors):
                log_det_part += sym_toeplitz_derivative_quadratic_form(
                    left_vector, right_vector)

            log_det_part.div_(self.num_samples)

            mat_grad = quad_form_part - log_det_part
            mat_grad.mul_(0.5 * grad_output_value)

        if self.needs_input_grad[1]:
            # Need gradient with respect to labels
            y_grad = mat_inv_y.mul_(-grad_output_value)

        if self.needs_input_grad[2]:
            quad_form_part = mat_inv_y.dot(mat_inv_y)
            noise_grad = toeplitz_column.new().resize_(1).fill_(
                quad_form_part - self.tr_inv)
            noise_grad.mul_(0.5 * grad_output_value)

        return mat_grad, y_grad, noise_grad
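The sample_matrix of random signs above drives a Hutchinson-style stochastic trace estimator: for Rademacher vectors z, E[z^T A z] = tr(A). A standalone sketch of that identity (matrix and sample count are illustrative):

import torch

torch.manual_seed(0)
A = torch.randn(50, 50)
A = A @ A.t()  # PSD, so the trace is large relative to the estimator noise
num_samples = 2000

z = torch.sign(torch.randn(50, num_samples))  # Rademacher probe vectors
estimate = (z * (A @ z)).sum(dim=0).mean()    # average of z_i^T A z_i
relative_error = (estimate - torch.trace(A)).abs() / torch.trace(A)
assert relative_error < 0.1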
Example #9
        def test_shape(di, dj, dk):
            x = self._gen_sparse(2, 20, [di, dj], is_cuda)[0]
            y = torch.randn(dj, dk)
            if is_cuda:
                y = y.cuda()

            res = torch.dsmm(x, y)
            expected = torch.mm(x.to_dense(), y)
            self.assertEqual(res, expected)
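The same check outside the test harness, as a self-contained sketch (torch.sparse.mm is used here since torch.dsmm is the older alias these tests exercise):

import torch

i = torch.tensor([[0, 1, 2], [2, 0, 1]])
v = torch.tensor([1.0, 2.0, 3.0])
x = torch.sparse_coo_tensor(i, v, (3, 3))
y = torch.randn(3, 4)

res = torch.sparse.mm(x, y)          # torch.dsmm(x, y) on older PyTorch
expected = torch.mm(x.to_dense(), y)
assert torch.allclose(res, expected)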
Example #10
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    """
    Given an interpolated matrix interp_left * (T_1 \otimes ... \otimes T_d) * interp_right, plus
    possibly an additional diagonal component s*I, compute its product with a vector or matrix,
    where each T_i is a symmetric Toeplitz matrix.

    Args:
        - toeplitz_columns (d x m matrix) - the columns of the d Toeplitz matrices T_i,
          each of length n_i
        - interp_left (sparse matrix nxm) - Left interpolation matrix
        - interp_right (sparse matrix pxm) - Right interpolation matrix
        - tensor (matrix p x k) - Vector (k=1) or matrix (k>1) to multiply WKW with
        - noise_diag (tensor p) - If not None, add (s*I)tensor to WKW at the end.

    Returns:
        - tensor
    """
    output_dims = tensor.ndimension()
    noise_term = None

    if output_dims == 1:
        tensor = tensor.unsqueeze(1)

    if noise_diag is not None:
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    if interp_left is not None:
        # Get interp_{r}^{T} tensor
        interp_right_tensor = torch.dsmm(interp_right.t(), tensor)
        # Get (T interp_{r}^{T}) tensor
        rhs = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, interp_right_tensor)

        # Get (interp_{l} T interp_{r}^{T})tensor
        output = torch.dsmm(interp_left, rhs)
    else:
        output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, tensor)

    if noise_term is not None:
        # Get (interp_{l} T interp_{r}^{T} + \sigma^{2}I)tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output
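kronecker_product_toeplitz_matmul can avoid materializing T_1 ⊗ ... ⊗ T_d because of the identity (A ⊗ B) vec(X) = vec(B X A^T), with vec taken column-major. A standalone dense check of that identity (torch.kron assumed available, PyTorch 1.8+):

import torch

A = torch.randn(3, 3)
B = torch.randn(4, 4)
X = torch.randn(4, 3)

def vec(M):
    # Column-major vec: stack the columns on top of each other
    return M.t().reshape(-1)

lhs = torch.kron(A, B) @ vec(X)
rhs = vec(B @ X @ A.t())
assert torch.allclose(lhs, rhs, atol=1e-5)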
Example #11
        def closure(left_vectors, right_vectors):
            if left_vectors.ndimension() == 1:
                left_factor = left_vectors.unsqueeze(0)
                right_factor = right_vectors.unsqueeze(0)
            else:
                left_factor = left_vectors
                right_factor = right_vectors
            if len(args) == 1:
                toeplitz_column, = args
                return sym_toeplitz_derivative_quadratic_form(
                    left_factor, right_factor),
            elif len(args) == 3:
                toeplitz_column, W_left, W_right = args
                if W_left.ndimension() == 3:
                    left_factor = bdsmm(W_left.transpose(1, 2),
                                        left_factor.transpose(1, 2)).transpose(
                                            1, 2)
                    right_factor = bdsmm(W_right.transpose(1, 2),
                                         right_factor.transpose(1,
                                                                2)).transpose(
                                                                    1, 2)
                else:
                    left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
                    right_factor = torch.dsmm(W_right.t(),
                                              right_factor.t()).t()
                return sym_toeplitz_derivative_quadratic_form(
                    left_factor, right_factor), None, None
            elif len(args) == 4:
                toeplitz_column, W_left, W_right, added_diag, = args

                diag_grad = left_vectors.new(len(added_diag)).zero_()
                diag_grad[0] = (left_vectors * right_vectors).sum()
                if W_left.ndimension() == 3:
                    left_factor = bdsmm(W_left.transpose(1, 2),
                                        left_factor.transpose(1, 2)).transpose(1, 2)
                    right_factor = bdsmm(W_right.transpose(1, 2),
                                         right_factor.transpose(1, 2)).transpose(1, 2)
                else:
                    left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
                    right_factor = torch.dsmm(W_right.t(),
                                              right_factor.t()).t()
                return sym_toeplitz_derivative_quadratic_form(
                    left_factor, right_factor), None, None, diag_grad
Example #12
def test_interpolation():
    x = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)
    J, C = Interpolation().interpolate(grid, x)
    W = utils.toeplitz.index_coef_to_sparse(J, C, len(grid))
    test_func_grid = grid.pow(2)
    test_func_x = x.pow(2)

    interp_func_x = torch.dsmm(W, test_func_grid.unsqueeze(1)).squeeze()

    assert all(
        torch.abs(interp_func_x - test_func_x) / (test_func_x + 1e-10) < 1e-5)
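A hand-rolled version of the same idea without the Interpolation helper: build the sparse matrix of linear-interpolation weights directly and apply it with a sparse-dense multiply (a sketch; torch.sparse.mm stands in for torch.dsmm):

import torch

grid = torch.linspace(0.0, 1.0, 11)        # uniform grid with spacing 0.1
x = torch.tensor([0.05, 0.32, 0.77])

# Left grid neighbor and linear weight for each point
left = torch.clamp((x * 10).floor().long(), max=9)
frac = x * 10 - left.float()
rows = torch.arange(len(x)).repeat_interleave(2)
cols = torch.stack([left, left + 1], dim=1).reshape(-1)
vals = torch.stack([1 - frac, frac], dim=1).reshape(-1)
W = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (len(x), len(grid)))

f_grid = grid.pow(2).unsqueeze(1)
f_x = torch.sparse.mm(W, f_grid).squeeze(1)  # interpolated values of f at x
assert torch.allclose(f_x, x.pow(2), atol=1e-2)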
Example #13
    def backward(self, dy):
        '''
        backprop with meprop
        if k is invalid (<=0 or > output feature), no top-k selection is applied
        '''
        x, w, b = self.saved_tensors
        dx = dw = db = None

        if self.k > 0 and self.k < w.size(1):  # backprop with top-k selection
            _, indices = dy.abs().topk(self.k)
            if self.sparse:  # using sparse matrix multiplication
                values = dy.gather(-1, indices).view(-1)
                row_indices = torch.arange(
                    0,
                    dy.size()[0]).long().cuda().unsqueeze_(-1).repeat(
                        1, self.k)
                indices = torch.stack([row_indices.view(-1), indices.view(-1)])
                pdy = torch.cuda.sparse.FloatTensor(indices, values, dy.size())
                if self.needs_input_grad[0]:
                    dx = torch.dsmm(pdy, w.t())
                if self.needs_input_grad[1]:
                    dw = torch.dsmm(pdy.t(), x).t()
            else:
                pdy = torch.cuda.FloatTensor(dy.size()).zero_().scatter_(
                    -1, indices, dy.gather(-1, indices))
                if self.needs_input_grad[0]:
                    dx = torch.mm(pdy, w.t())
                if self.needs_input_grad[1]:
                    dw = torch.mm(x.t(), pdy)
        else:  # backprop without top-k selection
            if self.needs_input_grad[0]:
                dx = torch.mm(dy, w.t())
            if self.needs_input_grad[1]:
                dw = torch.mm(x.t(), dy)

        if self.needs_input_grad[2]:
            db = torch.mv(dy.t(), self.add_buffer)

        return dx, dw, db
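A CPU sketch of the meProp top-k trick above, checking the sparse backward against a dense masked reference (hypothetical small sizes; torch.sparse.mm replaces the CUDA sparse tensor and torch.dsmm):

import torch

batch, n_in, n_out, k = 4, 5, 10, 3
x = torch.randn(batch, n_in)
w = torch.randn(n_in, n_out)
dy = torch.randn(batch, n_out)

# Keep only the k largest-magnitude gradient entries in each row of dy
_, idx = dy.abs().topk(k, dim=-1)
rows = torch.arange(batch).unsqueeze(-1).expand(-1, k).reshape(-1)
indices = torch.stack([rows, idx.reshape(-1)])
values = dy.gather(-1, idx).reshape(-1)
pdy = torch.sparse_coo_tensor(indices, values, dy.size())

dx = torch.sparse.mm(pdy, w.t())      # gradient w.r.t. the input
dw = torch.sparse.mm(pdy.t(), x).t()  # gradient w.r.t. the weight

# Dense reference: zero out everything outside the top-k
mask = torch.zeros_like(dy).scatter_(-1, idx, 1.0)
assert torch.allclose(dx, (dy * mask) @ w.t(), atol=1e-5)
assert torch.allclose(dw, ((dy * mask).t() @ x).t(), atol=1e-5)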
Example #14
def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    batch_size, n_rows, n_cols = sparse.size()
    batch_assignment = sparse._indices()[0]
    # Offset indices by batch to form one block-diagonal 2D sparse matrix
    # (legacy in-place add: indices[i] += scale * batch_assignment)
    indices = sparse._indices()[1:].clone()
    indices[0].add_(n_rows, batch_assignment)
    indices[1].add_(n_cols, batch_assignment)
    sparse_2d = sparse.__class__(
        indices, sparse._values(),
        torch.Size((batch_size * n_rows, batch_size * n_cols)))

    dense_2d = dense.contiguous().view(batch_size * n_cols, -1)
    res = torch.dsmm(sparse_2d, dense_2d)
    res = res.view(batch_size, n_rows, -1)
    return res
Example #15
    def forward(self, dense):
        if self.sparse.ndimension() == 3:
            return bdsmm(self.sparse, dense)
        else:
            return torch.dsmm(self.sparse, dense)
Example #16

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.grid_covar_module(x)
        return GaussianRandomVariable(mean_x, covar_x)


prior_observation_model = Model()
pred = prior_observation_model(x)
lazy_toeplitz_var = pred.covar()
T = utils.toeplitz.sym_toeplitz(lazy_toeplitz_var.c.data)
W_left = utils.toeplitz.index_coef_to_sparse(lazy_toeplitz_var.J_left,
                                             lazy_toeplitz_var.C_left,
                                             len(lazy_toeplitz_var.c))
W_right = utils.toeplitz.index_coef_to_sparse(lazy_toeplitz_var.J_right,
                                              lazy_toeplitz_var.C_right,
                                              len(lazy_toeplitz_var.c))
WTW = torch.dsmm(W_right, torch.dsmm(W_left, T).t()) + torch.diag(lazy_toeplitz_var.added_diag.data)


def test_explicit_interpolate_T():
    WT_res = lazy_toeplitz_var.explicit_interpolate_T(lazy_toeplitz_var.J_left, lazy_toeplitz_var.C_left)
    WT_actual = torch.dsmm(W_left, T)
    assert utils.approx_equal(WT_res.data, WT_actual)


def test_evaluate():
    WTW_res = lazy_toeplitz_var.evaluate()
    assert utils.approx_equal(WTW_res, WTW)


def test_diag():
    diag_actual = torch.diag(WTW)
Example #17
    def backward(self, grad_output):
        return torch.dsmm(self.sparse.t(), grad_output)
Example #18
    def forward(self, dense):
        return torch.dsmm(self.sparse, dense)
Example #19

def test_explicit_interpolate_T():
    WT_res = lazy_toeplitz_var.explicit_interpolate_T(lazy_toeplitz_var.J_left, lazy_toeplitz_var.C_left)
    WT_actual = torch.dsmm(W_left, T)
    assert utils.approx_equal(WT_res.data, WT_actual)
Example #20
    def backward(self, grad_output):
        if self.sparse.ndimension() == 3:
            return bdsmm(self.sparse.transpose(1, 2), grad_output)
        else:
            return torch.dsmm(self.sparse.t(), grad_output)
pred = prior_observation_model(x)
lazy_kronecker_product_var = pred.covar()
Ts = torch.zeros(lazy_kronecker_product_var.columns.size()[0],
                 lazy_kronecker_product_var.columns.size()[1],
                 lazy_kronecker_product_var.columns.size()[1])
for i in range(lazy_kronecker_product_var.columns.size()[0]):
    Ts[i] = utils.toeplitz.sym_toeplitz(
        lazy_kronecker_product_var.columns[i].data)
K = kronecker_product(Ts)
W_left = list_of_indices_and_values_to_sparse(
    lazy_kronecker_product_var.J_lefts, lazy_kronecker_product_var.C_lefts,
    lazy_kronecker_product_var.columns)
W_right = list_of_indices_and_values_to_sparse(
    lazy_kronecker_product_var.J_rights, lazy_kronecker_product_var.C_rights,
    lazy_kronecker_product_var.columns)
WKW = torch.dsmm(W_right,
                 torch.dsmm(W_left, K).t()) + torch.diag(
                     lazy_kronecker_product_var.added_diag.data)


def test_evaluate():
    WKW_res = lazy_kronecker_product_var.evaluate()
    assert utils.approx_equal(WKW_res, WKW)


def test_diag():
    diag_actual = torch.diag(WKW)
    diag_res = lazy_kronecker_product_var.diag()
    assert utils.approx_equal(diag_res.data, diag_actual)


def test_get_item_on_interpolated_variable_no_diagonal():