    def _matmul(self, rhs):
        output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
        output_batch_shape = output_shape[:-2]

        is_vector = False
        if rhs.ndimension() == 1:
            rhs = rhs.unsqueeze(1)
            is_vector = True

        # Here we have a root decomposition
        if isinstance(self.left_lazy_tensor, RootLazyTensor):
            left_root = self.left_lazy_tensor.root.evaluate()
            left_res = rhs.unsqueeze(-2) * left_root.unsqueeze(-1)

            rank = left_root.size(-1)
            n = self.size(-1)
            m = rhs.size(-1)
            # Now implement the formula (A . B) v = diag(A D_v B)
            left_res = left_res.view(*output_batch_shape, n, rank * m)
            left_res = self.right_lazy_tensor._matmul(left_res)
            left_res = left_res.view(*output_batch_shape, n, rank, m)
            res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
        # This is the case where we're not doing a root decomposition, because the matrix is too small
        else:
            res = (self.left_lazy_tensor.evaluate() * self.right_lazy_tensor.evaluate()).matmul(rhs)
        res = res.squeeze(-1) if is_vector else res
        return res
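
The comment above refers to the identity (A ∘ B) v = diag(A D_v Bᵀ), where D_v is the diagonal matrix built from v (for a symmetric right factor the transpose can be dropped, which matches the comment's form). A minimal dense sketch of the identity itself, using plain tensors rather than the class above:

import torch

# Hadamard-product matvec via the diagonal of a product of dense matrices.
# This only illustrates the identity; it is not the LazyTensor code path.
n = 4
A = torch.randn(n, n)
B = torch.randn(n, n)
v = torch.randn(n)

direct = (A * B) @ v                                    # (A ∘ B) v
via_diag = torch.diagonal(A @ torch.diag(v) @ B.t())    # diag(A D_v Bᵀ)

assert torch.allclose(direct, via_diag, atol=1e-5)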
Example #2
    def _cholesky_solve(self, rhs, upper: bool = False):
        # TODO: Figure out how to deal with this with TriangularLazyTensor if returned by _cholesky
        output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
        if rhs.shape != output_shape:
            rhs = rhs.expand(*output_shape)

        rhs = self._move_repeat_batches_to_columns(rhs, output_shape)
        res = self.base_lazy_tensor._cholesky_solve(rhs, upper=upper)
        res = self._move_repeat_batches_back(res, output_shape)
        return res
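
For reference, a minimal dense sketch of what a Cholesky solve computes, using plain PyTorch rather than the batch-repeat machinery above: given the lower-triangular factor L of A = L Lᵀ, torch.cholesky_solve returns A⁻¹ B.

import torch

A = torch.randn(5, 5)
A = A @ A.t() + 5 * torch.eye(5)     # make A positive definite
B = torch.randn(5, 2)

L = torch.linalg.cholesky(A)         # lower-triangular Cholesky factor
X = torch.cholesky_solve(B, L)       # solves A X = B using the factor

assert torch.allclose(A @ X, B, atol=1e-4)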
Example #3
    def _quad_form_derivative(self, left_vectors, right_vectors):
        if self.is_square:
            left_output_shape = _matmul_broadcast_shape(
                self.shape, left_vectors.shape)
            if left_output_shape != left_vectors.shape:
                left_vectors = left_vectors.expand(left_output_shape)

            right_output_shape = _matmul_broadcast_shape(
                self.shape, right_vectors.shape)
            if right_output_shape != right_vectors.shape:
                right_vectors = right_vectors.expand(right_output_shape)

            left_vectors = self._move_repeat_batches_to_columns(
                left_vectors, left_output_shape)
            right_vectors = self._move_repeat_batches_to_columns(
                right_vectors, right_output_shape)

            return self.base_lazy_tensor._quad_form_derivative(
                left_vectors, right_vectors)
        else:
            return super()._quad_form_derivative(left_vectors, right_vectors)
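
The quantity handled here is the derivative of a quadratic form: for f = Σₖ lₖᵀ K rₖ, the gradient with respect to K is the outer product L Rᵀ. A small autograd check with dense tensors, not the LazyTensor path above:

import torch

n, t = 4, 3
K = torch.randn(n, n, requires_grad=True)
left = torch.randn(n, t)
right = torch.randn(n, t)

f = (left * (K @ right)).sum()       # Σₖ lₖᵀ K rₖ
f.backward()

# The gradient of the quadratic form w.r.t. K is the outer product of the vectors.
assert torch.allclose(K.grad, left @ right.t(), atol=1e-5)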
Example #4
def _matmul(lazy_tensors, kp_shape, rhs):
    output_shape = _matmul_broadcast_shape(kp_shape, rhs.shape)
    output_batch_shape = output_shape[:-2]

    res = rhs.contiguous().expand(*output_batch_shape, *rhs.shape[-2:])
    num_cols = rhs.size(-1)
    for lazy_tensor in lazy_tensors:
        res = res.view(*output_batch_shape, lazy_tensor.size(-1), -1)
        factor = lazy_tensor._matmul(res)
        factor = factor.view(*output_batch_shape, lazy_tensor.size(-2), -1,
                             num_cols).transpose(-3, -2)
        res = factor.reshape(*output_batch_shape, -1, num_cols)
    return res
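
The loop above applies each Kronecker factor in turn, reshaping between factors so that A ⊗ B is never materialised. A dense sketch of the same trick with hypothetical sizes, checked against an explicit torch.kron product:

import torch

A = torch.randn(3, 3)
B = torch.randn(4, 4)
rhs = torch.randn(12, 2)             # 12 = 3 * 4 rows
num_cols = rhs.size(-1)

# Reference: materialise A ⊗ B and multiply directly.
dense = torch.kron(A, B) @ rhs

# Fast path: apply one factor at a time, mirroring the reshapes in _matmul above.
res = rhs
for factor in (A, B):
    res = factor @ res.reshape(factor.size(-1), -1)
    res = res.reshape(factor.size(-2), -1, num_cols).transpose(0, 1).reshape(-1, num_cols)

assert torch.allclose(dense, res, atol=1e-4)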
Example #5
    def _size(self):
        if settings.debug.on():
            if hasattr(self.kernel, "size"):
                raise RuntimeError(
                    "Kernels must define `num_outputs_per_input` and should not define `size`"
                )

        x1 = self.x1
        x2 = self.x2
        num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2)
        num_rows = x1.size(-2) * num_outputs_per_input
        num_cols = x2.size(-2) * num_outputs_per_input

        # Default case - when we're not using broadcasting
        # We write this case special for efficiency
        if x1.shape[:-2] == x2.shape[:-2] and x1.shape[:-2] == self.kernel.batch_shape:
            expected_size = self.kernel.batch_shape + torch.Size((num_rows, num_cols))

        # When we're using broadcasting
        else:
            expected_size = broadcasting._matmul_broadcast_shape(
                torch.Size([*x1.shape[:-2], num_rows, x1.size(-1)]),
                torch.Size([*x2.shape[:-2], x2.size(-1), num_cols]),
                error_msg=(
                    "x1 and x2 were not broadcastable to a proper kernel shape. "
                    "Got x1.shape = {} and x2.shape = {}".format(str(x1.shape), str(x2.shape))
                ),
            )
            expected_size = (
                broadcasting._mul_broadcast_shape(
                    expected_size[:-2],
                    self.kernel.batch_shape,
                    error_msg=(
                        f"x1 and x2 were not broadcastable with kernel of batch_shape {self.kernel.batch_shape}. "
                        f"Got x1.shape = {x1.shape} and x2.shape = {x2.shape}"
                    ),
                )
                + expected_size[-2:]
            )

        # Handle when the last dim is batch
        if self.last_dim_is_batch:
            expected_size = expected_size[:-2] + x1.shape[-1:] + expected_size[-2:]
        return expected_size
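
For intuition, the shape computed here mirrors what torch.matmul itself produces: batch dimensions broadcast and the matrix dimensions follow the usual (n × k) @ (k × m) → (n × m) rule. A tiny illustration with arbitrary example shapes:

import torch

a = torch.randn(3, 1, 5, 6)
b = torch.randn(4, 6, 2)
print((a @ b).shape)                 # torch.Size([3, 4, 5, 2])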
Example #6
    def _matmul(self, rhs):
        output_device = self.device if self.device is not None else rhs.device
        # make a copy of `rhs` on each device
        rhs_ = []
        for d in self.devices:
            if d != rhs.device:
                rhs_.append(rhs.to(d))
            else:
                rhs_.append(rhs)

        if self.cat_dim == -2:
            res_list = [
                t._matmul(rhs) for t, rhs in zip(self.lazy_tensors, rhs_)
            ]
            # copy result back to output device
            res_list = [x.to(output_device) for x in res_list]
            res = torch.cat(res_list, dim=-2)
        elif self.cat_dim == -1:
            curr_idx = 0
            res_list = []
            index = [slice(None, None, None) for _ in range(rhs.ndimension())]
            for t, size, rhs in zip(self.lazy_tensors, self.cat_dim_sizes, rhs_):
                index[-2] = slice(curr_idx, curr_idx + size, None)
                res_list.append(t._matmul(rhs[tuple(index)]))
                curr_idx += size
            # copy result back to output device and sum
            res_list = [x.to(output_device) for x in res_list]
            res = 0.0
            for x in res_list:
                res = res + x
        else:
            output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
            rhs = rhs.expand(*output_shape[:-2], *rhs.shape[-2:])
            curr_idx = 0
            res_list = []
            for t, size in zip(self.lazy_tensors, self.cat_dim_sizes):
                sub_rhs = rhs.narrow(self.cat_dim, curr_idx, size)
                res_list.append(t._matmul(sub_rhs))
                curr_idx += size
            # copy result back to output device
            res_list = [x.to(output_device) for x in res_list]
            res = torch.cat(res_list, dim=self.cat_dim)

        return res
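
The three branches rely on standard block rules: concatenating along rows stacks the partial results, while concatenating along columns splits the right-hand side and sums the pieces. A dense sketch with hypothetical sizes:

import torch

A = torch.randn(3, 4)
B = torch.randn(2, 4)
x = torch.randn(4, 5)

# cat along dim=-2 (rows): [A; B] @ x == cat(A @ x, B @ x)
assert torch.allclose(torch.cat([A, B], dim=-2) @ x,
                      torch.cat([A @ x, B @ x], dim=-2), atol=1e-5)

C = torch.randn(3, 2)
D = torch.randn(3, 4)
y = torch.randn(6, 5)

# cat along dim=-1 (columns): [C D] @ y == C @ y[:2] + D @ y[2:]
assert torch.allclose(torch.cat([C, D], dim=-1) @ y,
                      C @ y[:2] + D @ y[2:], atol=1e-5)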
Example #7
    def inv_quad_logdet(self,
                        inv_quad_rhs=None,
                        logdet=False,
                        reduce_inv_quad=True):
        if not self.is_square:
            raise RuntimeError(
                "inv_quad_logdet only operates on (batches of) square (positive semi-definite) LazyTensors. "
                "Got a {} of size {}.".format(self.__class__.__name__,
                                              self.size()))

        if inv_quad_rhs is not None:
            if self.dim() != inv_quad_rhs.dim():
                raise RuntimeError(
                    "LazyTensor (size={}) and right-hand-side Tensor (size={}) should have the same number "
                    "of dimensions.".format(self.shape, inv_quad_rhs.shape))
            elif self.batch_shape != inv_quad_rhs.shape[:-2] or self.shape[-1] != inv_quad_rhs.shape[-2]:
                raise RuntimeError(
                    "LazyTensor (size={}) cannot be multiplied with right-hand-side Tensor (size={})."
                    .format(self.shape, inv_quad_rhs.shape))

        if inv_quad_rhs is not None:
            output_shape = _matmul_broadcast_shape(self.shape,
                                                   inv_quad_rhs.shape)
            inv_quad_rhs = self._move_repeat_batches_to_columns(
                inv_quad_rhs, output_shape)

        inv_quad_term, logdet_term = self.base_lazy_tensor.inv_quad_logdet(
            inv_quad_rhs, logdet, reduce_inv_quad=False)

        if inv_quad_term is not None and inv_quad_term.numel():
            inv_quad_term = inv_quad_term.view(*inv_quad_term.shape[:-1], -1,
                                               1, self.batch_repeat.numel())
            output_shape = list(output_shape)
            output_shape[-2] = 1
            inv_quad_term = self._move_repeat_batches_back(
                inv_quad_term, output_shape).squeeze(-2)
            if reduce_inv_quad:
                inv_quad_term = inv_quad_term.sum(-1)

        if logdet_term is not None and logdet_term.numel():
            logdet_term = logdet_term.repeat(*self.batch_repeat)

        return inv_quad_term, logdet_term
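
For reference, a dense sketch of the two quantities inv_quad_logdet returns, the inverse quadratic form Σ vᵀ K⁻¹ v and log det K, computed with plain linear algebra rather than the batch-repeat path above:

import torch

n = 4
K = torch.randn(n, n)
K = K @ K.t() + n * torch.eye(n)     # positive definite
v = torch.randn(n, 2)

inv_quad = (v * torch.linalg.solve(K, v)).sum()   # sum over rhs columns of vᵀ K⁻¹ v
logdet = torch.logdet(K)
print(inv_quad.item(), logdet.item())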
Example #8
    def _matmul(self, rhs):
        output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)

        # only attempt broadcasting if the non-batch dimensions are the same
        if self.is_square:
            if rhs.shape != output_shape:
                rhs = rhs.expand(*output_shape)

            rhs = self._move_repeat_batches_to_columns(rhs, output_shape)
            res = self.base_lazy_tensor._matmul(rhs)
            res = self._move_repeat_batches_back(res, output_shape)
            return res
        else:
            # otherwise, we will rely on base tensor broadcasting
            res = self.base_lazy_tensor._matmul(rhs)
            if res.shape != output_shape:
                res = res.expand(*output_shape)

            return res
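
The helper names suggest the core idea: a matmul against a batch of right-hand sides that all share the same base matrix can be folded into a single matmul by moving the batch entries into extra columns. A dense sketch of that folding, using plain tensors rather than the class internals:

import torch

n, m, b = 4, 2, 3
K = torch.randn(n, n)
rhs = torch.randn(b, n, m)

batched = K @ rhs                    # broadcasted batch matmul, shape (b, n, m)

# Fold the batch dimension into columns, do one matmul, then unfold.
folded = K @ rhs.permute(1, 0, 2).reshape(n, b * m)
folded = folded.reshape(n, b, m).permute(1, 0, 2)

assert torch.allclose(batched, folded, atol=1e-5)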
def toeplitz_matmul(toeplitz_column, toeplitz_row, tensor):
    """
    Performs multiplication T * M where the matrix T is Toeplitz.
    Args:
        - toeplitz_column (vector n or b x n) - First column of the Toeplitz matrix T.
        - toeplitz_row (vector n or b x n) - First row of the Toeplitz matrix T.
        - tensor (matrix n x p or b x n x p) - Matrix or vector to multiply the Toeplitz matrix with.
    Returns:
        - tensor (n x p or b x n x p) - The result of the matrix multiply T * M.
    """
    if toeplitz_column.size() != toeplitz_row.size():
        raise RuntimeError(
            "c and r should have the same length (Toeplitz matrices are necessarily square)."
        )

    toeplitz_shape = torch.Size(
        (*toeplitz_column.shape, toeplitz_row.size(-1)))
    output_shape = broadcasting._matmul_broadcast_shape(
        toeplitz_shape, tensor.shape)
    broadcasted_t_shape = output_shape[:-1] if tensor.dim() > 1 else output_shape

    if tensor.ndimension() == 1:
        tensor = tensor.unsqueeze(-1)
    toeplitz_column = toeplitz_column.expand(*broadcasted_t_shape).reshape(
        -1, toeplitz_column.size(-1))
    toeplitz_row = toeplitz_row.expand(*broadcasted_t_shape).reshape(
        -1, toeplitz_row.size(-1))
    tensor = tensor.expand(*output_shape).reshape(-1, *tensor.shape[-2:])

    if not torch.equal(toeplitz_column[:, 0], toeplitz_row[:, 0]):
        raise RuntimeError(
            "The first column and first row of the Toeplitz matrix should have "
            "the same first element, otherwise the value of T[0,0] is ambiguous. "
            "Got: c[0]={} and r[0]={}".format(toeplitz_column[0],
                                              toeplitz_row[0]))

    if type(toeplitz_column) != type(toeplitz_row) or type(toeplitz_column) != type(tensor):
        raise RuntimeError("The types of all inputs to ToeplitzMV must match.")

    batch_size, orig_size, num_rhs = tensor.size()
    r_reverse = toeplitz_row[:, 1:].flip(dims=(1, ))

    c_r_rev = torch.zeros(batch_size,
                          orig_size + r_reverse.size(1),
                          dtype=tensor.dtype,
                          device=tensor.device)
    c_r_rev[:, :orig_size] = toeplitz_column
    c_r_rev[:, orig_size:] = r_reverse

    temp_tensor = torch.zeros(batch_size,
                              2 * orig_size - 1,
                              num_rhs,
                              dtype=toeplitz_column.dtype,
                              device=toeplitz_column.device)
    temp_tensor[:, :orig_size, :] = tensor

    fft_M = fft.fft1(temp_tensor.transpose(1, 2).contiguous())
    fft_c = fft.fft1(c_r_rev).unsqueeze(1).expand_as(fft_M)
    fft_product = torch.zeros_like(fft_M)

    fft_product[:, :, :, 0].addcmul_(fft_c[:, :, :, 0], fft_M[:, :, :, 0])
    fft_product[:, :, :, 0].addcmul_(fft_c[:, :, :, 1],
                                     fft_M[:, :, :, 1],
                                     value=-1)
    fft_product[:, :, :, 1].addcmul_(fft_c[:, :, :, 1], fft_M[:, :, :, 0])
    fft_product[:, :, :, 1].addcmul_(fft_c[:, :, :, 0], fft_M[:, :, :, 1])

    output = fft.ifft1(fft_product).transpose(1, 2)
    output = output[:, :orig_size, :]

    output = output.view(*output_shape)
    return output
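
The function above embeds the Toeplitz matrix in a circulant of size 2n − 1 and multiplies via the FFT. A minimal sketch of that circulant-embedding trick, written against the modern torch.fft and torch.meshgrid APIs rather than the fft.fft1 helper used above, and checked against a dense Toeplitz matrix:

import torch

n, p = 5, 3
c = torch.randn(n)                   # first column of T
r = torch.randn(n)                   # first row of T
r[0] = c[0]                          # T[0, 0] must agree
v = torch.randn(n, p)

# Dense reference: T[i, j] = c[i - j] if i >= j else r[j - i]
i, j = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
T = torch.where(i >= j, c[i - j], r[j - i])

# Embed T in a circulant whose first column is [c, reverse(r[1:])] and
# multiply via the circular-convolution theorem.
c_emb = torch.cat([c, r[1:].flip(0)])
v_pad = torch.cat([v, v.new_zeros(n - 1, p)])
prod = torch.fft.ifft(torch.fft.fft(c_emb).unsqueeze(-1) * torch.fft.fft(v_pad, dim=0), dim=0)
fast = prod.real[:n]

assert torch.allclose(T @ v, fast, atol=1e-4)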
    def _size(self):
        return _matmul_broadcast_shape(self.left_lazy_tensor.shape, self.right_lazy_tensor.shape)