Exemplo n.º 1
0
def deflated_power_iteration(
    operator,
    num_eigenthings=10,
    power_iter_steps=20,
    power_iter_err_threshold=1e-4,
    momentum=0.0,
    use_gpu=True,
    to_numpy=True,
):
    """
    Compute top k eigenvalues by repeatedly subtracting out dyads.

    After each run of power iteration the found (eigenvalue, eigenvector)
    pair is subtracted from the operator, so the next run is applied to the
    deflated operator.

    operator: linear operator that gives us access to matrix vector product;
        this function uses its `.apply(x)` method and its `.size` attribute
    num_eigenthings: number of eigenvalue/eigenvector pairs to compute
    power_iter_steps: number of steps per run of power iteration
    power_iter_err_threshold: early stopping threshold for power iteration
    momentum: forwarded unchanged to power_iteration
    use_gpu: forwarded unchanged to power_iteration
    to_numpy: if True, each eigenvector is converted to a numpy array
        before it is collected; otherwise the CPU tensor is collected.
        Either way the collected list is passed through np.array at the end.
    returns: np.ndarray of top eigenvalues, np.ndarray of top eigenvectors,
        both ordered by descending eigenvalue (row i of the eigenvectors
        belongs to eigenvalue i)
    """
    eigenvals = []
    eigenvecs = []
    current_op = operator
    prev_vec = None

    def _deflate(x, val, vec):
        # Contribution of the dyad val * vec vec^T applied to x.
        return val * vec.dot(x) * vec

    log("beginning deflated power iteration")
    for i in range(num_eigenthings):
        log("computing eigenvalue/vector %d of %d" % (i + 1, num_eigenthings))
        # The previous eigenvector is handed over as the starting vector.
        eigenval, eigenvec = power_iteration(
            current_op,
            power_iter_steps,
            power_iter_err_threshold,
            momentum=momentum,
            use_gpu=use_gpu,
            init_vec=prev_vec,
        )
        log("eigenvalue %d: %.4f" % (i + 1, eigenval))

        # Default arguments freeze the current operator/eigenpair so each
        # deflated operator keeps referring to the values of its own round.
        def _new_op_fn(x, op=current_op, val=eigenval, vec=eigenvec):
            return op.apply(x) - _deflate(x, val, vec)

        current_op = LambdaOperator(_new_op_fn, operator.size)
        prev_vec = eigenvec
        eigenvals.append(eigenval)
        eigenvec = eigenvec.cpu()
        if to_numpy:
            # `.cpu()` may hand back the very same tensor and `.numpy()`
            # shares storage with its tensor, so without a copy the stored
            # array would alias `prev_vec`, which power_iteration keeps
            # using. `.detach()` also allows tensors that require grad.
            eigenvecs.append(eigenvec.detach().clone().numpy())
        else:
            eigenvecs.append(eigenvec)

    eigenvals = np.array(eigenvals)
    eigenvecs = np.array(eigenvecs)

    # sort them in descending order
    sorted_inds = np.argsort(eigenvals)
    eigenvals = eigenvals[sorted_inds][::-1]
    eigenvecs = eigenvecs[sorted_inds][::-1]
    return eigenvals, eigenvecs
def deflated_power_iteration(
    operator: Operator,
    num_eigenthings: int = 10,
    power_iter_steps: int = 20,
    power_iter_err_threshold: float = 1e-4,
    momentum: float = 0.0,
    use_gpu: bool = True,
    fp16: bool = False,
    to_numpy: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Find the leading eigenpairs of `operator` one at a time via deflation.

    Every round runs power_iteration on the current operator, then wraps
    that operator in a new one whose output has the dyad of the freshly
    found eigenpair subtracted.

    operator: linear operator giving access to matrix vector products
        (its `.apply(x)` method and `.size` attribute are used here)
    num_eigenthings: how many eigenvalue/eigenvector pairs to compute
    power_iter_steps: number of steps per run of power iteration
    power_iter_err_threshold: early stopping threshold for power iteration
    momentum, use_gpu: passed through to power_iteration
    fp16: passed through to power_iteration and to utils.maybe_fp16 on the
        result of each operator application
    to_numpy: when True, a detached numpy copy of every eigenvector is
        collected instead of the CPU tensor itself
    returns: np.ndarray of eigenvalues and np.ndarray of eigenvectors, both
        in descending eigenvalue order
    """
    found_vals = []
    found_vecs = []
    deflated_op = operator
    warm_start = None

    utils.log("beginning deflated power iteration")
    for round_no in range(1, num_eigenthings + 1):
        utils.log("computing eigenvalue/vector %d of %d" %
                  (round_no, num_eigenthings))
        top_val, top_vec = power_iteration(
            deflated_op,
            power_iter_steps,
            power_iter_err_threshold,
            momentum=momentum,
            use_gpu=use_gpu,
            fp16=fp16,
            init_vec=warm_start,
        )
        utils.log("eigenvalue %d: %.4f" % (round_no, top_val))

        # Defaults pin this round's operator and eigenpair to the closure.
        def _apply_deflated(x, op=deflated_op, val=top_val, vec=top_vec):
            applied = utils.maybe_fp16(op.apply(x), fp16)
            dyad_term = val * vec.dot(x) * vec
            return applied - dyad_term

        deflated_op = LambdaOperator(_apply_deflated, operator.size)
        # The next round starts from the eigenvector found in this one.
        warm_start = top_vec
        found_vals.append(top_val)

        host_vec = top_vec.cpu()
        # Clone so that power_iteration can continue to use torch.
        found_vecs.append(
            host_vec.detach().clone().numpy() if to_numpy else host_vec
        )

    eigenvals = np.array(found_vals)
    eigenvecs = np.array(found_vecs)

    # argsort is ascending; reverse afterwards to get descending order
    ascending = np.argsort(eigenvals)
    return eigenvals[ascending][::-1], eigenvecs[ascending][::-1]