def plot_gaussian(X, axis=-1, scale=2, **kwargs):
    """
    Plot Gaussian node as a 1-D function

    The mean of the node is drawn together with an error band whose
    half-width is ``scale`` standard deviations.

    Parameters
    ----------
    X : node
        Node with Gaussian moments.
    axis : int
        The index of the time axis.
    scale : float
        Number of standard deviations used for the error band.
    **kwargs
        Passed on unchanged to ``_timeseries_mean_and_error``.
    """
    X = X._convert(GaussianMoments)
    u_X = X.get_moments()
    # First moment is the mean, second moment holds <x x^T>
    x = u_X[0]
    # Keep only the diagonal of the second moment, i.e. <x_i^2>
    xx = misc.get_diag(u_X[1], ndim=len(X.dims[0]))
    # Var = <x^2> - <x>^2 can come out slightly negative through
    # floating-point cancellation; clamp at zero so sqrt gives 0, not NaN.
    var = np.maximum(xx - x**2, 0)
    std = scale * np.sqrt(var)
    return _timeseries_mean_and_error(x, std, axis=axis, **kwargs)
def plot_gaussian(X, axis=-1, scale=2, **kwargs):
    """
    Plot Gaussian node as a 1-D function

    The mean of the node is drawn together with an error band whose
    half-width is ``scale`` standard deviations.

    Parameters
    ----------
    X : node
        Node with Gaussian moments.
    axis : int
        The index of the time axis.
    scale : float
        Number of standard deviations used for the error band.
    **kwargs
        Passed on unchanged to ``_timeseries_mean_and_error``.
    """
    # NOTE(review): ndim=0 presumably requests scalar Gaussian variables
    # (all axes treated as plates) -- confirm against _ensure_moments.
    X = X._ensure_moments(X, GaussianMoments, ndim=0)
    u_X = X.get_moments()
    # First moment is the mean, second moment holds <x x^T>
    x = u_X[0]
    # Keep only the diagonal of the second moment, i.e. <x_i^2>
    xx = misc.get_diag(u_X[1], ndim=len(X.dims[0]))
    # Var = <x^2> - <x>^2 can come out slightly negative through
    # floating-point cancellation; clamp at zero so sqrt gives 0, not NaN.
    var = np.maximum(xx - x**2, 0)
    std = scale * np.sqrt(var)
    return _timeseries_mean_and_error(x, std, axis=axis, **kwargs)
def gaussian_hinton(X, rows=None, cols=None, scale=1, fig=None):
    """
    Plot the Hinton diagram of a Gaussian node

    Parameters
    ----------
    X : node
        Node with Gaussian moments.
    rows : int, optional
        Axis of the mean array (shape ``X.get_shape(0)``) laid out along
        the rows of each diagram.  Negative values count from the end.
        Chosen automatically when ``None``.
    cols : int, optional
        Axis of the mean array laid out along the columns of each
        diagram.  Negative values count from the end.  Chosen
        automatically when ``None``.
    scale : float
        Number of standard deviations drawn as the error; ``0`` draws the
        mean only.
    fig : figure, optional
        Figure to draw into.  Defaults to ``plt.gcf()``.

    Raises
    ------
    ValueError
        If ``rows`` or ``cols`` is out of range, or if both name the same
        axis.
    """
    if fig is None:
        fig = plt.gcf()

    # Get mean and second moment
    X = X._convert(GaussianMoments)
    (x, xx) = X.get_moments()
    ndim = len(X.dims[0])
    shape = X.get_shape(0)
    size = len(shape)

    # Compute standard deviation. Round-off can make the variance slightly
    # negative; clamp it so neither std nor vmax below becomes NaN.
    xx = misc.get_diag(xx, ndim=ndim)
    std = np.sqrt(np.maximum(xx - x**2, 0))

    # Force explicit elements when broadcasting
    x = x * np.ones(shape)
    std = std * np.ones(shape)

    # NaN marks an axis that is selected automatically further below
    # (all comparisons with NaN are False, so it passes the checks).
    if rows is None:
        rows = np.nan
    if cols is None:
        cols = np.nan

    # Preprocess the axes to 0,...,size-1
    if rows < 0:
        rows += size
    if cols < 0:
        cols += size
    if rows < 0 or rows >= size:
        raise ValueError("Row axis invalid")
    if cols < 0 or cols >= size:
        raise ValueError("Column axis invalid")
    if rows == cols:
        raise ValueError("Row and column axes must differ")

    # Remove non-row and non-column axes that have length 1, shifting the
    # row/column indices down for every removed axis in front of them
    squeezed_shape = list(shape)
    for i in reversed(range(len(shape))):
        if shape[i] == 1 and i != rows and i != cols:
            squeezed_shape.pop(i)
            if i < cols:
                cols -= 1
            if i < rows:
                rows -= 1
    x = np.reshape(x, squeezed_shape)
    std = np.reshape(std, squeezed_shape)

    # A diagram needs two axes.  np.atleast_2d prepends the missing ones,
    # so shift the existing axis indices by the number of added axes.
    missing = 2 - np.ndim(x)
    if missing > 0:
        rows += missing
        cols += missing
        x = np.atleast_2d(x)
        std = np.atleast_2d(std)
    size = np.ndim(x)

    # Pick automatic axes among the last two, avoiding the other choice
    if np.isnan(cols):
        cols = size - 1 if rows != size - 1 else size - 2
    if np.isnan(rows):
        rows = size - 1 if cols != size - 1 else size - 2

    # Put the row and column axes to the end
    axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols]
    x = np.transpose(x, axes=axes)
    std = np.transpose(std, axes=axes)

    # One common vmax is passed to every diagram
    vmax = np.max(np.abs(x) + scale * std)
    if scale == 0:
        _subplots(_hinton, (x, 2), fig=fig, kwargs=dict(vmax=vmax))
    else:
        def plotfunc(z, e, **kwargs):
            return _hinton(z, error=e, **kwargs)
        _subplots(plotfunc, (x, 2), (scale * std, 2), fig=fig,
                  kwargs=dict(vmax=vmax))
def gaussian_hinton(X, rows=None, cols=None, scale=1, fig=None):
    """
    Plot the Hinton diagram of a Gaussian node

    Parameters
    ----------
    X : node
        Node with Gaussian moments.
    rows : int, optional
        Axis of the mean array (shape ``X.get_shape(0)``) laid out along
        the rows of each diagram.  Negative values count from the end.
        Chosen automatically when ``None``.
    cols : int, optional
        Axis of the mean array laid out along the columns of each
        diagram.  Negative values count from the end.  Chosen
        automatically when ``None``.
    scale : float
        Number of standard deviations drawn as the error; ``0`` draws the
        mean only.
    fig : figure, optional
        Figure to draw into.  Defaults to ``plt.gcf()``.

    Raises
    ------
    ValueError
        If ``rows`` or ``cols`` is out of range, or if both name the same
        axis.
    """
    if fig is None:
        fig = plt.gcf()

    # Get mean and second moment
    X = X._ensure_moments(X, GaussianMoments, ndim=0)
    (x, xx) = X.get_moments()
    ndim = len(X.dims[0])
    shape = X.get_shape(0)
    size = len(shape)

    # Compute standard deviation. Round-off can make the variance slightly
    # negative; clamp it so neither std nor vmax below becomes NaN.
    xx = misc.get_diag(xx, ndim=ndim)
    std = np.sqrt(np.maximum(xx - x**2, 0))

    # Force explicit elements when broadcasting
    x = x * np.ones(shape)
    std = std * np.ones(shape)

    # NaN marks an axis that is selected automatically further below
    # (all comparisons with NaN are False, so it passes the checks).
    if rows is None:
        rows = np.nan
    if cols is None:
        cols = np.nan

    # Preprocess the axes to 0,...,size-1
    if rows < 0:
        rows += size
    if cols < 0:
        cols += size
    if rows < 0 or rows >= size:
        raise ValueError("Row axis invalid")
    if cols < 0 or cols >= size:
        raise ValueError("Column axis invalid")
    if rows == cols:
        raise ValueError("Row and column axes must differ")

    # Remove non-row and non-column axes that have length 1, shifting the
    # row/column indices down for every removed axis in front of them
    squeezed_shape = list(shape)
    for i in reversed(range(len(shape))):
        if shape[i] == 1 and i != rows and i != cols:
            squeezed_shape.pop(i)
            if i < cols:
                cols -= 1
            if i < rows:
                rows -= 1
    x = np.reshape(x, squeezed_shape)
    std = np.reshape(std, squeezed_shape)

    # A diagram needs two axes.  np.atleast_2d prepends the missing ones,
    # so shift the existing axis indices by the number of added axes.
    missing = 2 - np.ndim(x)
    if missing > 0:
        rows += missing
        cols += missing
        x = np.atleast_2d(x)
        std = np.atleast_2d(std)
    size = np.ndim(x)

    # Pick automatic axes among the last two, avoiding the other choice
    if np.isnan(cols):
        cols = size - 1 if rows != size - 1 else size - 2
    if np.isnan(rows):
        rows = size - 1 if cols != size - 1 else size - 2

    # Put the row and column axes to the end
    axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols]
    x = np.transpose(x, axes=axes)
    std = np.transpose(std, axes=axes)

    # One common vmax is passed to every diagram
    vmax = np.max(np.abs(x) + scale * std)
    if scale == 0:
        _subplots(_hinton, (x, 2), fig=fig, kwargs=dict(vmax=vmax))
    else:
        def plotfunc(z, e, **kwargs):
            return _hinton(z, error=e, **kwargs)
        _subplots(plotfunc, (x, 2), (scale * std, 2), fig=fig,
                  kwargs=dict(vmax=vmax))
def gaussian_hinton(X, rows=None, cols=None, scale=1):
    """
    Plot the Hinton diagram of a Gaussian node

    The diagrams are drawn into the current figure as an M x N grid of
    subplots.  The two leading axes (after padding to four) index the
    grid, and the row/column axes are shown inside each diagram.

    Parameters
    ----------
    X : node
        Node with Gaussian moments.
    rows : int, optional
        Axis of the mean array (shape ``X.get_shape(0)``) laid out along
        the rows of each diagram.  Negative values count from the end.
        Chosen automatically when ``None``.
    cols : int, optional
        Axis of the mean array laid out along the columns of each
        diagram.  Negative values count from the end.  Chosen
        automatically when ``None``.
    scale : float
        Number of standard deviations drawn as the error; ``0`` draws the
        mean only.

    Raises
    ------
    ValueError
        If ``rows``/``cols`` is out of range or both name the same axis,
        or if more than four non-singleton axes remain.
    """
    # Get mean and second moment
    X = X._convert(GaussianMoments)
    (x, xx) = X.get_moments()
    ndim = len(X.dims[0])
    shape = X.get_shape(0)
    size = len(shape)

    # Compute standard deviation. Round-off can make the variance slightly
    # negative; clamp it so neither std nor vmax below becomes NaN.
    xx = misc.get_diag(xx, ndim=ndim)
    std = np.sqrt(np.maximum(xx - x**2, 0))

    # Force explicit elements when broadcasting
    x = x * np.ones(shape)
    std = std * np.ones(shape)

    # NaN marks an axis that is selected automatically further below
    # (all comparisons with NaN are False, so it passes the checks).
    if rows is None:
        rows = np.nan
    if cols is None:
        cols = np.nan

    # Preprocess the axes to 0,...,size-1
    if rows < 0:
        rows += size
    if cols < 0:
        cols += size
    if rows < 0 or rows >= size:
        raise ValueError("Row axis invalid")
    if cols < 0 or cols >= size:
        raise ValueError("Column axis invalid")
    if rows == cols:
        raise ValueError("Row and column axes must differ")

    # Remove non-row and non-column axes that have length 1, shifting the
    # row/column indices down for every removed axis in front of them
    squeezed_shape = list(shape)
    for i in reversed(range(len(shape))):
        if shape[i] == 1 and i != rows and i != cols:
            squeezed_shape.pop(i)
            if i < cols:
                cols -= 1
            if i < rows:
                rows -= 1
    x = np.reshape(x, squeezed_shape)
    std = np.reshape(std, squeezed_shape)

    # Check before padding: with more than four axes the offset below
    # would be negative and corrupt the row/column indices, so numpy
    # would fail with an unrelated error before this message.
    if np.ndim(x) > 4:
        raise ValueError("Can not plot arrays with over 4 axes")

    # Make explicit four axes; the indices are shifted by the number of
    # axes added in front, as in the original padding logic.
    missing = 4 - np.ndim(x)
    cols = cols + missing
    rows = rows + missing
    x = misc.atleast_nd(x, 4)
    std = misc.atleast_nd(std, 4)
    size = np.ndim(x)

    # Pick automatic axes among the last two, avoiding the other choice
    if np.isnan(cols):
        cols = size - 1 if rows != size - 1 else size - 2
    if np.isnan(rows):
        rows = size - 1 if cols != size - 1 else size - 2

    # Put the row and column axes to the end
    axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols]
    x = np.transpose(x, axes=axes)
    std = np.transpose(std, axes=axes)

    # The two leading axes give the subplot grid
    M = np.shape(x)[0]
    N = np.shape(x)[1]

    # One common vmax is passed to every diagram
    vmax = np.max(np.abs(x) + scale * std)

    for i in range(M):
        for j in range(N):
            plt.subplot(M, N, i * N + j + 1)
            if scale == 0:
                _hinton(x[i, j], vmax=vmax)
            else:
                _hinton(x[i, j], vmax=vmax, error=scale * std[i, j])