def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    if sparse.ndimension() > 2:
        batch_size, n_rows, n_cols = sparse.size()
        batch_assignment = sparse._indices()[0]
        indices = sparse._indices()[1:].clone()
        indices[0].add_(n_rows, batch_assignment)
        indices[1].add_(n_cols, batch_assignment)
        sparse_2d = sparse.new(
            indices, sparse._values(),
            torch.Size((batch_size * n_rows, batch_size * n_cols)))

        if dense.size(0) == 1:
            dense = dense.repeat(batch_size, 1, 1)
        dense_2d = dense.contiguous().view(batch_size * n_cols, -1)
        res = torch.dsmm(sparse_2d, dense_2d)
        res = res.view(batch_size, n_rows, -1)
        return res

    elif dense.ndimension() == 3:
        batch_size, _, n_cols = dense.size()
        res = torch.dsmm(
            sparse,
            dense.transpose(0, 1).contiguous().view(-1, batch_size * n_cols))
        res = res.view(-1, batch_size, n_cols)
        res = res.transpose(0, 1).contiguous()
        return res

    else:
        return torch.dsmm(sparse, dense)
def closure(left_vectors, right_vectors):
    if left_vectors.ndimension() == 1:
        left_factor = left_vectors.unsqueeze(0)
        right_factor = right_vectors.unsqueeze(0)
    else:
        left_factor = left_vectors
        right_factor = right_vectors

    if len(args) == 1:
        columns, = args
        return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),
    elif len(args) == 3:
        columns, W_left, W_right = args
        left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
        right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
        res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
        return res, None, None
    elif len(args) == 4:
        columns, W_left, W_right, added_diag, = args
        diag_grad = columns.new(len(added_diag)).zero_()
        diag_grad[0] = (left_factor * right_factor).sum()
        left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
        right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
        res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
        return res, None, None, diag_grad
def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    # Make the batch sparse matrix into a block-diagonal matrix
    if sparse.ndimension() > 2:
        # Expand the tensors to account for broadcasting
        output_shape = _matmul_broadcast_shape(sparse.shape, dense.shape)
        expanded_sparse_shape = output_shape[:-2] + sparse.shape[-2:]
        unsqueezed_sparse_shape = [1 for _ in range(len(output_shape) - sparse.dim())] + list(sparse.shape)
        repeat_sizes = tuple(
            output_size // sparse_size
            for output_size, sparse_size in zip(expanded_sparse_shape, unsqueezed_sparse_shape)
        )
        sparse = sparse_repeat(sparse, *repeat_sizes)
        dense = dense.expand(*output_shape[:-2], dense.size(-2), dense.size(-1))

        # Figure out how much needs to be added to the row/column indices
        # to create a block-diagonal matrix
        *batch_shape, num_rows, num_cols = sparse.shape
        batch_size = torch.Size(batch_shape).numel()
        batch_multiplication_factor = torch.tensor(
            [torch.Size(batch_shape[i + 1:]).numel() for i in range(len(batch_shape))],
            dtype=torch.long,
            device=sparse.device,
        )
        if batch_multiplication_factor.is_cuda:
            batch_assignment = (sparse._indices()[:-2].float().t() @ batch_multiplication_factor.float()).long()
        else:
            batch_assignment = sparse._indices()[:-2].t() @ batch_multiplication_factor

        # Create the block-diagonal sparse tensor
        indices = sparse._indices()[-2:].clone()
        indices[0].add_(batch_assignment, alpha=num_rows)
        indices[1].add_(batch_assignment, alpha=num_cols)
        sparse_2d = torch.sparse_coo_tensor(
            indices,
            sparse._values(),
            torch.Size((batch_size * num_rows, batch_size * num_cols)),
            dtype=sparse._values().dtype,
            device=sparse._values().device,
        )

        dense_2d = dense.reshape(batch_size * num_cols, -1)
        res = torch.dsmm(sparse_2d, dense_2d)
        res = res.view(*batch_shape, num_rows, -1)
        return res

    elif dense.dim() > 2:
        *batch_shape, num_rows, num_cols = dense.size()
        batch_size = torch.Size(batch_shape).numel()
        dense = dense.view(batch_size, num_rows, num_cols)
        res = torch.dsmm(sparse, dense.transpose(0, 1).reshape(-1, batch_size * num_cols))
        res = res.view(-1, batch_size, num_cols)
        res = res.transpose(0, 1).reshape(*batch_shape, -1, num_cols)
        return res

    else:
        return torch.dsmm(sparse, dense)
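# The bdsmm implementations above all rest on the same block-diagonal trick:
# offset each batch member's row/column indices so the whole batch becomes one
# large 2D block-diagonal sparse matrix, then perform a single sparse-dense
# matmul. Below is a minimal, self-contained sketch of that trick using only
# public PyTorch ops (torch.sparse_coo_tensor and torch.sparse.mm in place of
# the legacy torch.dsmm); `batch_sparse_dense_mm` is an illustrative name,
# not a library function.
import torch

def batch_sparse_dense_mm(sparse, dense):
    # sparse: (b, n, m) COO tensor; dense: (b, m, k) strided tensor
    batch_size, n_rows, n_cols = sparse.shape
    batch_idx = sparse._indices()[0]
    # Offset row/column indices so batch member i occupies block i
    # of a (b*n, b*m) block-diagonal matrix
    rows = sparse._indices()[1] + batch_idx * n_rows
    cols = sparse._indices()[2] + batch_idx * n_cols
    block_diag = torch.sparse_coo_tensor(
        torch.stack([rows, cols]), sparse._values(),
        (batch_size * n_rows, batch_size * n_cols)).coalesce()
    # Stack the dense batch vertically and do one 2D sparse matmul
    res = torch.sparse.mm(block_diag, dense.reshape(batch_size * n_cols, -1))
    return res.view(batch_size, n_rows, -1)

# Sanity check against a per-batch dense matmul
dense_batch = torch.randn(3, 4, 5)
rhs = torch.randn(3, 5, 2)
assert torch.allclose(batch_sparse_dense_mm(dense_batch.to_sparse(), rhs),
                      dense_batch @ rhs, atol=1e-5)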
def interpolated_sym_toeplitz_matmul(toeplitz_column, vector, W_left=None, W_right=None, noise_diag=None):
    """
    Given an interpolated symmetric Toeplitz matrix W_left*T*W_right, plus possibly an additional
    diagonal component s*I, compute a product with some vector or matrix vector.

    Args:
        - toeplitz_column (vector or matrix) - First column of the symmetric Toeplitz matrix T
        - W_left (sparse matrix nxm) - Left interpolation matrix
        - W_right (sparse matrix pxm) - Right interpolation matrix
        - vector (matrix pxk) - Vector (k=1) or matrix (k>1) to multiply WTW with
        - noise_diag (vector p) - If not None, add (s*I)vector to WTW at the end.

    Returns:
        - matrix nxk
    """
    noise_term = None
    ndim = vector.ndimension()
    if ndim == 1:
        vector = vector.unsqueeze(1)

    if noise_diag is not None:
        noise_term = noise_diag.unsqueeze(-1).expand_as(vector) * vector

    if W_left is not None:
        # Get W_{r}^{T}vector
        if W_left.ndimension() == 3:
            # Batch mode:
            Wt_times_v = bdsmm(W_right.transpose(1, 2), vector)
            # Get (TW_{r}^{T})vector
            TWt_v = sym_toeplitz_matmul(toeplitz_column, Wt_times_v)
            # Get (W_{l}TW_{r}^{T})vector
            WTWt_v = bdsmm(W_left, TWt_v)
        else:
            Wt_times_v = torch.dsmm(W_right.t(), vector)
            # Get (TW_{r}^{T})vector
            TWt_v = sym_toeplitz_matmul(toeplitz_column, Wt_times_v)
            # Get (W_{l}TW_{r}^{T})vector
            WTWt_v = torch.dsmm(W_left, TWt_v)
    else:
        WTWt_v = sym_toeplitz_matmul(toeplitz_column, vector)

    if noise_term is not None:
        # Get (W_{l}TW_{r}^{T} + \sigma^{2}I)vector
        WTWt_v = WTWt_v + noise_term

    if ndim == 1:
        WTWt_v = WTWt_v.squeeze(1)

    return WTWt_v
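# sym_toeplitz_matmul, used above, exploits the fact that a symmetric Toeplitz
# matrix embeds in a circulant matrix, so a matvec costs O(n log n) via the FFT
# rather than O(n^2). A standalone sketch of that embedding follows; the helper
# name is illustrative, not the library's implementation.
import torch

def sym_toeplitz_matvec_fft(column, vector):
    n = column.numel()
    # First column of the (2n-2)-dim circulant embedding:
    # [c_0, ..., c_{n-1}, c_{n-2}, ..., c_1]
    circ = torch.cat([column, column[1:-1].flip(0)])
    fft_circ = torch.fft.rfft(circ)
    fft_vec = torch.fft.rfft(torch.cat([vector, vector.new_zeros(n - 2)]))
    # Circulant matvec = circular convolution = pointwise product in Fourier space
    return torch.fft.irfft(fft_circ * fft_vec, n=2 * n - 2)[:n]

# Sanity check against an explicitly constructed symmetric Toeplitz matrix
col = torch.randn(6)
T = torch.stack([col[(torch.arange(6) - i).abs()] for i in range(6)])
v = torch.randn(6)
assert torch.allclose(sym_toeplitz_matvec_fft(col, v), T @ v, atol=1e-4)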
def interpolated_sym_toeplitz_mul(toeplitz_column, vector, W_left=None, W_right=None, noise_diag=None):
    """
    Given an interpolated symmetric Toeplitz matrix W_left*T*W_right, plus possibly an additional
    diagonal component s*I, compute a product with some vector or matrix vector.

    Args:
        - toeplitz_column (vector) - First column of the symmetric Toeplitz matrix T
        - W_left (sparse matrix nxm) - Left interpolation matrix
        - W_right (sparse matrix pxm) - Right interpolation matrix
        - vector (matrix pxk) - Vector (k=1) or matrix (k>1) to multiply WTW with
        - noise_diag (vector p) - If not None, add (s*I)vector to WTW at the end.

    Returns:
        - matrix nxk - The result of multiplying (WTW + sI)vector if noise_diag exists,
                       or (WTW)vector otherwise.
    """
    noise_term = None
    if vector.ndimension() == 1:
        if noise_diag is not None:
            noise_term = noise_diag.expand_as(vector) * vector
        vector = vector.unsqueeze(1)
        mul_func = utils.toeplitz.toeplitz_mv
    else:
        if noise_diag is not None:
            noise_term = noise_diag.unsqueeze(1).expand_as(vector) * vector
        mul_func = utils.toeplitz.toeplitz_mm

    if W_left is not None:
        # Get W_{r}^{T}vector
        Wt_times_v = torch.dsmm(W_right.t(), vector)
        # Get (TW_{r}^{T})vector
        TWt_v = mul_func(toeplitz_column, toeplitz_column, Wt_times_v.squeeze())
        if TWt_v.ndimension() == 1:
            TWt_v.unsqueeze_(1)
        # Get (W_{l}TW_{r}^{T})vector
        WTWt_v = torch.dsmm(W_left, TWt_v).squeeze()
    else:
        WTWt_v = mul_func(toeplitz_column, toeplitz_column, vector)

    if noise_term is not None:
        # Get (W_{l}TW_{r}^{T} + \sigma^{2}I)vector
        WTWt_v = WTWt_v + noise_term

    return WTWt_v
def test_shape(di, dj, dk):
    x = self._gen_sparse(2, 20, [di, dj])[0]
    y = self.randn(dj, dk)

    res = torch.dsmm(x, y)
    expected = torch.mm(self.safeToDense(x), y)
    self.assertEqual(res, expected)
def backward(self, grad_output):
    grad_output_value = grad_output.squeeze()[0]
    toeplitz_column, labels, noise_diag = self.saved_tensors
    mat_inv_y = self.mat_inv_y
    mv_closure = self.mv_closure

    mat_grad = None
    y_grad = None
    noise_grad = None

    if self.needs_input_grad[0]:
        y_mat_inv_W_left = torch.dsmm(self.W_left.t(), mat_inv_y.unsqueeze(1)).t()
        W_right_mat_inv_y = torch.dsmm(self.W_right.t(), mat_inv_y.unsqueeze(1))
        quad_form_part = sym_toeplitz_derivative_quadratic_form(
            y_mat_inv_W_left.squeeze(), W_right_mat_inv_y.squeeze())

        log_det_part = torch.zeros(len(toeplitz_column))
        sample_matrix = torch.sign(torch.randn(len(labels), self.num_samples))

        left_vectors = torch.dsmm(
            self.W_left.t(), LinearCG().solve(mv_closure, sample_matrix)).t()
        right_vectors = torch.dsmm(self.W_right.t(), sample_matrix).t()

        for left_vector, right_vector in zip(left_vectors, right_vectors):
            log_det_part += sym_toeplitz_derivative_quadratic_form(left_vector, right_vector)

        log_det_part.div_(self.num_samples)

        mat_grad = quad_form_part - log_det_part
        mat_grad.mul_(0.5 * grad_output_value)

    if self.needs_input_grad[1]:
        # Need gradient with respect to labels
        y_grad = mat_inv_y.mul_(-grad_output_value)

    if self.needs_input_grad[2]:
        quad_form_part = mat_inv_y.dot(mat_inv_y)
        noise_grad = toeplitz_column.new().resize_(1).fill_(quad_form_part - self.tr_inv)
        noise_grad.mul_(0.5 * grad_output_value)

    return mat_grad, y_grad, noise_grad
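# The sign-vector sampling in the backward pass above is Hutchinson's
# stochastic trace estimator: for random probes z with E[z z^T] = I (here
# Rademacher +/-1 vectors), E[z^T A z] = tr(A), which is what makes the
# log-determinant gradient tractable without forming the matrix explicitly.
# A minimal standalone sketch (illustrative helper name):
import torch

def hutchinson_trace(A, num_samples=2000):
    n = A.size(0)
    z = torch.sign(torch.randn(n, num_samples))  # Rademacher probe vectors
    # Average z_i^T A z_i over all probes
    return (z * (A @ z)).sum() / num_samples

A = torch.randn(50, 50)
# hutchinson_trace(A) approaches torch.trace(A) as num_samples grows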
def test_shape(di, dj, dk):
    x = self._gen_sparse(2, 20, [di, dj], is_cuda)[0]
    y = torch.randn(dj, dk)
    if is_cuda:
        y = y.cuda()

    res = torch.dsmm(x, y)
    expected = torch.mm(x.to_dense(), y)
    self.assertEqual(res, expected)
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    """
    Given an interpolated matrix interp_left * (T_1 \otimes ... \otimes T_d) * interp_right, plus
    possibly an additional diagonal component s*I, compute a product with some tensor or matrix
    tensor, where the T_i are symmetric Toeplitz matrices.

    Args:
        - toeplitz_columns (d x m matrix) - First columns of the d symmetric Toeplitz matrices T_i
        - interp_left (sparse matrix nxm) - Left interpolation matrix
        - interp_right (sparse matrix pxm) - Right interpolation matrix
        - tensor (matrix p x k) - Vector (k=1) or matrix (k>1) to multiply WKW with
        - noise_diag (tensor p) - If not None, add (s*I)tensor to WKW at the end.

    Returns:
        - tensor
    """
    output_dims = tensor.ndimension()
    noise_term = None

    if output_dims == 1:
        tensor = tensor.unsqueeze(1)

    if noise_diag is not None:
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    if interp_left is not None:
        # Get interp_{r}^{T} tensor
        interp_right_tensor = torch.dsmm(interp_right.t(), tensor)
        # Get (T interp_{r}^{T}) tensor
        rhs = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, interp_right_tensor)
        # Get (interp_{l} T interp_{r}^{T}) tensor
        output = torch.dsmm(interp_left, rhs)
    else:
        output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, tensor)

    if noise_term is not None:
        # Get (interp_{l} T interp_{r}^{T} + \sigma^{2}I) tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)

    return output
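# kronecker_product_toeplitz_matmul avoids ever forming T_1 \otimes ... \otimes T_d
# by using the standard Kronecker matvec identity. For two factors and row-major
# vectorization: (A \otimes B) x = vec(A X B^T), where X is x reshaped to a matrix.
# A minimal sketch of that identity (illustrative names, dense factors):
import torch

def kron_matvec(A, B, x):
    # A: (p, p), B: (q, q), x: (p*q,) flattened row-major
    X = x.view(A.size(0), B.size(0))
    return (A @ X @ B.t()).reshape(-1)

A, B = torch.randn(3, 3), torch.randn(4, 4)
x = torch.randn(12)
assert torch.allclose(kron_matvec(A, B, x), torch.kron(A, B) @ x, atol=1e-5)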
def closure(left_vectors, right_vectors):
    if left_vectors.ndimension() == 1:
        left_factor = left_vectors.unsqueeze(0)
        right_factor = right_vectors.unsqueeze(0)
    else:
        left_factor = left_vectors
        right_factor = right_vectors

    if len(args) == 1:
        toeplitz_column, = args
        return sym_toeplitz_derivative_quadratic_form(left_factor, right_factor),
    elif len(args) == 3:
        toeplitz_column, W_left, W_right = args
        if W_left.ndimension() == 3:
            left_factor = bdsmm(W_left.transpose(1, 2), left_factor.transpose(1, 2)).transpose(1, 2)
            right_factor = bdsmm(W_right.transpose(1, 2), right_factor.transpose(1, 2)).transpose(1, 2)
        else:
            left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
            right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
        return sym_toeplitz_derivative_quadratic_form(left_factor, right_factor), None, None
    elif len(args) == 4:
        toeplitz_column, W_left, W_right, added_diag, = args
        diag_grad = left_vectors.new(len(added_diag)).zero_()
        diag_grad[0] = (left_vectors * right_vectors).sum()
        if W_left.ndimension() == 3:
            left_factor = bdsmm(W_left.transpose(1, 2), left_factor.transpose(1, 2)).transpose(1, 2)
            right_factor = bdsmm(W_right.transpose(1, 2), right_factor.transpose(1, 2)).transpose(1, 2)
        else:
            left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
            right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
        return sym_toeplitz_derivative_quadratic_form(left_factor, right_factor), None, None, diag_grad
def test_interpolation():
    x = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)
    J, C = Interpolation().interpolate(grid, x)
    W = utils.toeplitz.index_coef_to_sparse(J, C, len(grid))
    test_func_grid = grid.pow(2)
    test_func_x = x.pow(2)

    interp_func_x = torch.dsmm(W, test_func_grid.unsqueeze(1)).squeeze()

    assert all(torch.abs(interp_func_x - test_func_x) / (test_func_x + 1e-10) < 1e-5)
def backward(self, dy):
    '''
    backprop with meprop

    if k is invalid (<=0 or > output features), no top-k selection is applied
    '''
    x, w, b = self.saved_tensors
    dx = dw = db = None

    if self.k > 0 and self.k < w.size(1):  # backprop with top-k selection
        _, indices = dy.abs().topk(self.k)
        if self.sparse:  # using sparse matrix multiplication
            values = dy.gather(-1, indices).view(-1)
            row_indices = torch.arange(0, dy.size()[0]).long().cuda().unsqueeze_(-1).repeat(1, self.k)
            indices = torch.stack([row_indices.view(-1), indices.view(-1)])
            pdy = torch.cuda.sparse.FloatTensor(indices, values, dy.size())
            if self.needs_input_grad[0]:
                dx = torch.dsmm(pdy, w.t())
            if self.needs_input_grad[1]:
                dw = torch.dsmm(pdy.t(), x).t()
        else:
            pdy = torch.cuda.FloatTensor(dy.size()).zero_().scatter_(-1, indices, dy.gather(-1, indices))
            if self.needs_input_grad[0]:
                dx = torch.mm(pdy, w.t())
            if self.needs_input_grad[1]:
                dw = torch.mm(x.t(), pdy)
    else:  # backprop without top-k selection
        if self.needs_input_grad[0]:
            dx = torch.mm(dy, w.t())
        if self.needs_input_grad[1]:
            dw = torch.mm(x.t(), dy)

    if self.needs_input_grad[2]:
        db = torch.mv(dy.t(), self.add_buffer)

    return dx, dw, db
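# A CPU sketch of the meprop idea in the backward above: keep only the k
# largest-magnitude entries of dy per row, store them as a sparse COO tensor,
# and compute both gradient products with sparse matmuls. torch.sparse.mm
# stands in for the legacy torch.dsmm, and `meprop_grads` is an illustrative
# name, not part of the original code.
import torch

def meprop_grads(dy, x, w, k):
    batch = dy.size(0)
    _, idx = dy.abs().topk(k, dim=-1)
    vals = dy.gather(-1, idx).reshape(-1)
    rows = torch.arange(batch).unsqueeze(-1).expand(-1, k).reshape(-1)
    pdy = torch.sparse_coo_tensor(
        torch.stack([rows, idx.reshape(-1)]), vals, dy.shape).coalesce()
    dx = torch.sparse.mm(pdy, w.t())      # gradient w.r.t. the layer input
    dw = torch.sparse.mm(pdy.t(), x).t()  # gradient w.r.t. the weight
    return dx, dw

dy, x, w = torch.randn(8, 10), torch.randn(8, 5), torch.randn(5, 10)
dx, dw = meprop_grads(dy, x, w, k=3)
assert dx.shape == (8, 5) and dw.shape == (5, 10)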
def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    batch_size, n_rows, n_cols = sparse.size()
    batch_assignment = sparse._indices()[0]
    indices = sparse._indices()[1:].clone()
    indices[0].add_(n_rows, batch_assignment)
    indices[1].add_(n_cols, batch_assignment)
    sparse_2d = sparse.__class__(
        indices, sparse._values(),
        torch.Size((batch_size * n_rows, batch_size * n_cols)))

    dense_2d = dense.contiguous().view(batch_size * n_cols, -1)
    res = torch.dsmm(sparse_2d, dense_2d)
    res = res.view(batch_size, n_rows, -1)
    return res
def forward(self, dense):
    if self.sparse.ndimension() == 3:
        return bdsmm(self.sparse, dense)
    else:
        return torch.dsmm(self.sparse, dense)
        mean_x = self.mean_module(x)
        covar_x = self.grid_covar_module(x)
        return GaussianRandomVariable(mean_x, covar_x)

prior_observation_model = Model()
pred = prior_observation_model(x)
lazy_toeplitz_var = pred.covar()
T = utils.toeplitz.sym_toeplitz(lazy_toeplitz_var.c.data)
W_left = utils.toeplitz.index_coef_to_sparse(lazy_toeplitz_var.J_left,
                                             lazy_toeplitz_var.C_left,
                                             len(lazy_toeplitz_var.c))
W_right = utils.toeplitz.index_coef_to_sparse(lazy_toeplitz_var.J_right,
                                              lazy_toeplitz_var.C_right,
                                              len(lazy_toeplitz_var.c))
WTW = torch.dsmm(W_right, torch.dsmm(W_left, T).t()) + torch.diag(lazy_toeplitz_var.added_diag.data)

def test_explicit_interpolate_T():
    WT_res = lazy_toeplitz_var.explicit_interpolate_T(lazy_toeplitz_var.J_left, lazy_toeplitz_var.C_left)
    WT_actual = torch.dsmm(W_left, T)
    assert utils.approx_equal(WT_res.data, WT_actual)

def test_evaluate():
    WTW_res = lazy_toeplitz_var.evaluate()
    assert utils.approx_equal(WTW_res, WTW)

def test_diag():
    diag_actual = torch.diag(WTW)
def backward(self, grad_output):
    return torch.dsmm(self.sparse.t(), grad_output)
def forward(self, dense):
    return torch.dsmm(self.sparse, dense)
def test_explicit_interpolate_T():
    WT_res = lazy_toeplitz_var.explicit_interpolate_T(lazy_toeplitz_var.J_left, lazy_toeplitz_var.C_left)
    WT_actual = torch.dsmm(W_left, T)
    assert utils.approx_equal(WT_res.data, WT_actual)
def backward(self, grad_output):
    if self.sparse.ndimension() == 3:
        return bdsmm(self.sparse.transpose(1, 2), grad_output)
    else:
        return torch.dsmm(self.sparse.t(), grad_output)
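# The forward/backward pairs above implement the rule that for y = S @ x with
# a fixed sparse S, the gradient w.r.t. x is S^T @ grad_y. A hedged sketch of
# the same pattern as a modern torch.autograd.Function (the class name is
# illustrative, and torch.sparse.mm replaces the legacy torch.dsmm):
import torch

class SparseMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.sparse = sparse  # constant, non-differentiable factor
        return torch.sparse.mm(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for the sparse matrix; the chain rule gives S^T grad_y for x
        return None, torch.sparse.mm(ctx.sparse.t(), grad_output)

S = torch.eye(4).to_sparse()
x = torch.randn(4, 2, requires_grad=True)
SparseMM.apply(S, x).sum().backward()
assert torch.allclose(x.grad, torch.ones(4, 2))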
pred = prior_observation_model(x)
lazy_kronecker_product_var = pred.covar()
Ts = torch.zeros(lazy_kronecker_product_var.columns.size()[0],
                 lazy_kronecker_product_var.columns.size()[1],
                 lazy_kronecker_product_var.columns.size()[1])
for i in range(lazy_kronecker_product_var.columns.size()[0]):
    Ts[i] = utils.toeplitz.sym_toeplitz(lazy_kronecker_product_var.columns[i].data)
K = kronecker_product(Ts)
W_left = list_of_indices_and_values_to_sparse(lazy_kronecker_product_var.J_lefts,
                                              lazy_kronecker_product_var.C_lefts,
                                              lazy_kronecker_product_var.columns)
W_right = list_of_indices_and_values_to_sparse(lazy_kronecker_product_var.J_rights,
                                               lazy_kronecker_product_var.C_rights,
                                               lazy_kronecker_product_var.columns)
WKW = torch.dsmm(W_right, torch.dsmm(W_left, K).t()) + torch.diag(lazy_kronecker_product_var.added_diag.data)

def test_evaluate():
    WKW_res = lazy_kronecker_product_var.evaluate()
    assert utils.approx_equal(WKW_res, WKW)

def test_diag():
    diag_actual = torch.diag(WKW)
    diag_res = lazy_kronecker_product_var.diag()
    assert utils.approx_equal(diag_res.data, diag_actual)

def test_get_item_on_interpolated_variable_no_diagonal():