Beispiel #1
0
def plot_gaussian(X, axis=-1, scale=2, **kwargs):
    """
    Plot the mean of a Gaussian node together with a scaled error

    Parameters
    ----------
    X : node
        Node that can be converted to Gaussian moments.
    axis : int
        The index of the time axis.
    scale : scalar
        Multiplier applied to the standard deviation before it is handed
        over as the error.
    **kwargs
        Forwarded unchanged to ``_timeseries_mean_and_error``.

    Returns
    -------
    The return value of ``_timeseries_mean_and_error``.
    """
    node = X._convert(GaussianMoments)
    moments = node.get_moments()
    mean = moments[0]

    # Presumably the element-wise second moment, i.e. the diagonal of the
    # full second-moment array -- confirm against misc.get_diag.
    second_moment = misc.get_diag(moments[1], ndim=len(node.dims[0]))

    # Var[x] = <x^2> - <x>^2
    variance = second_moment - mean**2

    return _timeseries_mean_and_error(
        mean,
        scale * np.sqrt(variance),
        axis=axis,
        **kwargs
    )
Beispiel #2
0
def plot_gaussian(X, axis=-1, scale=2, **kwargs):
    """
    Plot the mean of a Gaussian node together with a scaled error

    Parameters
    ----------
    X : node
        Node whose moments can be ensured to be Gaussian moments.
    axis : int
        The index of the time axis.
    scale : scalar
        Multiplier applied to the standard deviation before it is handed
        over as the error.
    **kwargs
        Forwarded unchanged to ``_timeseries_mean_and_error``.

    Returns
    -------
    The return value of ``_timeseries_mean_and_error``.
    """
    node = X._ensure_moments(X, GaussianMoments, ndim=0)
    moments = node.get_moments()
    mean = moments[0]

    # Presumably the element-wise second moment, i.e. the diagonal of the
    # full second-moment array -- confirm against misc.get_diag.
    second_moment = misc.get_diag(moments[1], ndim=len(node.dims[0]))

    # Var[x] = <x^2> - <x>^2
    variance = second_moment - mean**2

    return _timeseries_mean_and_error(
        mean,
        scale * np.sqrt(variance),
        axis=axis,
        **kwargs
    )
Beispiel #3
0
def gaussian_hinton(X, rows=None, cols=None, scale=1, fig=None):
    """
    Plot the Hinton diagram of a Gaussian node

    Parameters
    ----------
    X : node
        Node that can be converted to Gaussian moments.
    rows : int, optional
        Index of the axis used as the row axis.  Negative values count
        from the end.  If None, it is picked automatically from the last
        two axes that remain after length-1 axes have been removed.
    cols : int, optional
        Index of the axis used as the column axis.  Handled like `rows`.
    scale : scalar
        Multiplier of the standard deviation that is passed on as the
        error.  If 0, no error is passed.
    fig : figure, optional
        Figure handed to ``_subplots``.  Defaults to ``plt.gcf()``.

    Raises
    ------
    ValueError
        If `rows` or `cols` is outside the range of the node's axes.
    """

    if fig is None:
        fig = plt.gcf()

    # Get mean and second moment
    X = X._convert(GaussianMoments)
    (x, xx) = X.get_moments()
    ndim = len(X.dims[0])
    shape = X.get_shape(0)
    size = len(shape)

    # Compute standard deviation
    xx = misc.get_diag(xx, ndim=ndim)
    std = np.sqrt(xx - x**2)

    # Force explicit elements when broadcasting
    x = x * np.ones(shape)
    std = std * np.ones(shape)

    # NaN marks "not chosen yet": every comparison with NaN is False, so
    # the range checks and the index bookkeeping below leave it untouched.
    if rows is None:
        rows = np.nan
    if cols is None:
        cols = np.nan

    # Preprocess the axes to 0,...,ndim
    if rows < 0:
        rows += size
    if cols < 0:
        cols += size
    if rows < 0 or rows >= size:
        raise ValueError("Row axis invalid")
    if cols < 0 or cols >= size:
        raise ValueError("Column axis invalid")

    # Remove non-row and non-column axes that have length 1
    squeezed_shape = list(shape)
    for i in reversed(range(len(shape))):
        if shape[i] == 1 and i != rows and i != cols:
            squeezed_shape.pop(i)
            # Axes after the removed one move one position to the left
            if i < cols:
                cols -= 1
            if i < rows:
                rows -= 1
    x = np.reshape(x, squeezed_shape)
    std = np.reshape(std, squeezed_shape)

    # A row axis and a column axis are both needed.  With fewer than two
    # axes left, the automatic selection below would produce the index -1
    # and np.transpose would fail, so pad with leading length-1 axes and
    # shift the explicit indices accordingly.
    if np.ndim(x) < 2:
        cols += 2 - np.ndim(x)
        rows += 2 - np.ndim(x)
        x = np.atleast_2d(x)
        std = np.atleast_2d(std)

    # Fill in the axes that were not given: prefer the last axis, fall
    # back to the second last if the last one is already taken.
    size = np.ndim(x)
    if np.isnan(cols):
        if rows != size - 1:
            cols = size - 1
        else:
            cols = size - 2
    if np.isnan(rows):
        if cols != size - 1:
            rows = size - 1
        else:
            rows = size - 2

    # Put the row and column axes to the end
    axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols]
    x = np.transpose(x, axes=axes)
    std = np.transpose(std, axes=axes)

    # Common scale for all diagrams
    vmax = np.max(np.abs(x) + scale * std)

    if scale == 0:
        _subplots(_hinton, (x, 2), fig=fig, kwargs=dict(vmax=vmax))
    else:

        def plotfunc(z, e, **kwargs):
            return _hinton(z, error=e, **kwargs)

        _subplots(plotfunc, (x, 2), (scale * std, 2),
                  fig=fig,
                  kwargs=dict(vmax=vmax))
Beispiel #4
0
def gaussian_hinton(X, rows=None, cols=None, scale=1, fig=None):
    """
    Draw the Hinton diagram of a Gaussian node

    Parameters
    ----------
    X : node
        Node whose moments can be ensured to be Gaussian moments.
    rows : int, optional
        Index of the axis used as the row axis.  Negative values count
        from the end.  If None, it is picked automatically from the last
        two axes that remain after length-1 axes have been removed.
    cols : int, optional
        Index of the axis used as the column axis.  Handled like `rows`.
    scale : scalar
        Multiplier of the standard deviation that is passed on as the
        error.  If 0, no error is passed.
    fig : figure, optional
        Figure handed to ``_subplots``.  Defaults to ``plt.gcf()``.

    Raises
    ------
    ValueError
        If `rows` or `cols` is outside the range of the node's axes.
    """
    fig = plt.gcf() if fig is None else fig

    # Mean and second moment of the node
    node = X._ensure_moments(X, GaussianMoments, ndim=0)
    (mean, second_moment) = node.get_moments()
    full_shape = node.get_shape(0)
    n_axes = len(full_shape)

    # Standard deviation from the diagonal of the second moment
    diag = misc.get_diag(second_moment, ndim=len(node.dims[0]))
    std = np.sqrt(diag - mean**2)

    # Expand broadcasted axes into explicit elements
    ones = np.ones(full_shape)
    mean = mean * ones
    std = std * ones

    # NaN stands for "not chosen yet": all comparisons with NaN are False,
    # so the checks and index shifts below pass it through unchanged.
    rows = np.nan if rows is None else rows
    cols = np.nan if cols is None else cols

    # Map negative indices to 0,...,n_axes-1 and validate
    if rows < 0:
        rows += n_axes
    if cols < 0:
        cols += n_axes
    if rows >= n_axes or rows < 0:
        raise ValueError("Row axis invalid")
    if cols >= n_axes or cols < 0:
        raise ValueError("Column axis invalid")

    # Drop length-1 axes unless they are the row or the column axis, and
    # shift the row/column indices left by the number of axes dropped
    # before them.
    dropped = [
        i for (i, length) in enumerate(full_shape)
        if length == 1 and i != rows and i != cols
    ]
    kept_shape = [
        length for (i, length) in enumerate(full_shape)
        if i not in dropped
    ]
    cols -= sum(1 for i in dropped if i < cols)
    rows -= sum(1 for i in dropped if i < rows)
    mean = np.reshape(mean, kept_shape)
    std = np.reshape(std, kept_shape)

    # Guarantee at least two axes by prepending length-1 axes
    missing = 2 - np.ndim(mean)
    if missing > 0:
        cols += missing
        rows += missing
        mean = np.atleast_2d(mean)
        std = np.atleast_2d(std)

    # Choose the axes that were left open: the last axis if it is free,
    # otherwise the second last.
    n_axes = np.ndim(mean)
    last = n_axes - 1
    second_last = n_axes - 2
    if np.isnan(cols):
        cols = last if rows != last else second_last
    if np.isnan(rows):
        rows = last if cols != last else second_last

    # Move the row and column axes to the end
    order = [i for i in range(n_axes) if i not in (rows, cols)]
    order.extend((rows, cols))
    mean = np.transpose(mean, axes=order)
    std = np.transpose(std, axes=order)

    error = scale * std
    vmax = np.max(np.abs(mean) + error)

    if scale == 0:
        _subplots(_hinton, (mean, 2), fig=fig, kwargs={'vmax': vmax})
        return

    def hinton_with_error(values, errors, **options):
        return _hinton(values, error=errors, **options)

    _subplots(
        hinton_with_error,
        (mean, 2),
        (error, 2),
        fig=fig,
        kwargs={'vmax': vmax},
    )
Beispiel #5
0
def gaussian_hinton(X, rows=None, cols=None, scale=1):
    """
    Plot the Hinton diagram of a Gaussian node

    The array is arranged as a grid of subplots (first two axes) with one
    diagram per subplot (last two axes), so at most four axes can remain
    after length-1 axes have been removed.

    Parameters
    ----------
    X : node
        Node that can be converted to Gaussian moments.
    rows : int, optional
        Index of the axis used as the row axis.  Negative values count
        from the end.  If None, it is picked automatically from the last
        two axes.
    cols : int, optional
        Index of the axis used as the column axis.  Handled like `rows`.
    scale : scalar
        Multiplier of the standard deviation that is passed on as the
        error.  If 0, no error is passed.

    Raises
    ------
    ValueError
        If `rows` or `cols` is outside the range of the node's axes, or
        if more than four axes remain after removing length-1 axes.
    """

    # Get mean and second moment
    X = X._convert(GaussianMoments)
    (x, xx) = X.get_moments()
    ndim = len(X.dims[0])
    shape = X.get_shape(0)
    size = len(shape)

    # Compute standard deviation
    xx = misc.get_diag(xx, ndim=ndim)
    std = np.sqrt(xx - x**2)

    # Force explicit elements when broadcasting
    x = x * np.ones(shape)
    std = std * np.ones(shape)

    # NaN marks "not chosen yet": every comparison with NaN is False, so
    # the range checks and the index bookkeeping below leave it untouched.
    if rows is None:
        rows = np.nan
    if cols is None:
        cols = np.nan

    # Preprocess the axes to 0,...,ndim
    if rows < 0:
        rows += size
    if cols < 0:
        cols += size
    if rows < 0 or rows >= size:
        raise ValueError("Row axis invalid")
    if cols < 0 or cols >= size:
        raise ValueError("Column axis invalid")

    # Remove non-row and non-column axes that have length 1
    squeezed_shape = list(shape)
    for i in reversed(range(len(shape))):
        if shape[i] == 1 and i != rows and i != cols:
            squeezed_shape.pop(i)
            # Axes after the removed one move one position to the left
            if i < cols:
                cols -= 1
            if i < rows:
                rows -= 1
    x = np.reshape(x, squeezed_shape)
    std = np.reshape(std, squeezed_shape)

    # Reject too many axes before touching the indices: with more than
    # four axes the shift below would be negative and could turn a valid
    # row/column index into an invalid one, so np.transpose would fail
    # with an unrelated message before this error could be reported.
    if np.ndim(x) > 4:
        raise ValueError("Can not plot arrays with over 4 axes")

    # Make explicit four axes.  The index shift assumes misc.atleast_nd
    # prepends the missing axes -- confirm against misc.atleast_nd.
    padding = 4 - np.ndim(x)
    cols = cols + padding
    rows = rows + padding
    x = misc.atleast_nd(x, 4)
    std = misc.atleast_nd(std, 4)

    # Fill in the axes that were not given: prefer the last axis, fall
    # back to the second last if the last one is already taken.
    size = np.ndim(x)
    if np.isnan(cols):
        if rows != size - 1:
            cols = size - 1
        else:
            cols = size - 2
    if np.isnan(rows):
        if cols != size - 1:
            rows = size - 1
        else:
            rows = size - 2

    # Put the row and column axes to the end
    axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols]
    x = np.transpose(x, axes=axes)
    std = np.transpose(std, axes=axes)

    # Defensive: the indexing below relies on exactly four axes
    if np.ndim(x) != 4:
        raise ValueError("Can not plot arrays with over 4 axes")

    # One subplot per element of the two leading axes, all diagrams on a
    # common scale
    (M, N) = np.shape(x)[:2]
    vmax = np.max(np.abs(x) + scale*std)
    for i in range(M):
        for j in range(N):
            # Creates the subplot and makes it the current axes
            plt.subplot(M, N, i*N+j+1)
            if scale == 0:
                _hinton(x[i,j], vmax=vmax)
            else:
                _hinton(x[i,j], vmax=vmax, error=scale*std[i,j])