def plot_pseudospectra_single(perf_data, contour=False, levels=False, invert=True, linewidth=3.0):
    """Plot the pseudospectrum of one network's weight matrix in a new figure.

    The plot covers the region [-3, 3] x [-3, 3] of the complex plane, sampled
    on a 50-point grid. It overlays the real/imaginary axes and the unit
    circle, and hides the ticks.

    Parameters
    ----------
    perf_data : object
        Must expose ``net.W`` (a matrix whose ``shape[0]`` gives the number of
        eigenvalues to mark) and ``eigen_values`` (indexable; the first
        ``W.shape[0]`` entries are plotted).
    contour : bool
        If True, draw with ``plot_pseudospectra_contour`` (no colorbar, linear
        scale, ``invert=False``) and do not mark eigenvalues. Otherwise draw
        with ``plot_pseudospectra`` (log scale, colorbar) and mark each
        eigenvalue with a white-filled black circle.
    levels :
        Forwarded to ``plot_pseudospectra_contour`` only.
    invert : bool
        Forwarded to ``plot_pseudospectra`` only.
    linewidth : float
        Forwarded to ``plot_pseudospectra_contour`` only.
    """
    import numpy as np

    W = perf_data.net.W
    N = W.shape[0]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    if contour:
        plot_pseudospectra_contour(W, bounds=[-3, 3, -3, 3], npts=50, ax=ax, colorbar=False,
                                   invert=False, log=False, levels=levels, linewidth=linewidth)
    else:
        plot_pseudospectra(W, bounds=[-3, 3, -3, 3], npts=50, ax=ax, colorbar=True,
                           log=True, invert=invert)

    # Draw on `ax` explicitly. The plotting helpers may change pyplot's
    # "current axes" (e.g. by adding a colorbar), so gca() is not reliable here.
    ax.axhline(0.0, color='k')
    ax.axvline(0.0, color='k')
    ax.add_patch(pylab.Circle((0.0, 0.0), radius=1.00, fc='gray', fill=False))
    ax.set_xticks([])
    ax.set_yticks([])

    if not contour:
        # One vectorized call instead of one Line2D artist per eigenvalue.
        evs = np.asarray(perf_data.eigen_values)[:N]
        ax.plot(evs.real, evs.imag, 'ko', markerfacecolor='w')
def plot_pseudospectra_by_perf(pdata, perf_attr='logit_perf', contour=False, levels=None, invert=True,
                               num_plots=25):
    """Plot grids of pseudospectra for the first and last networks in `pdata`.

    Two figures are produced. One shows the first ``num_plots`` entries
    (titled "Top") and the other shows the last ``num_plots`` entries (titled
    "Bottom"). ``plt.show()`` is then called once.

    NOTE(review): `pdata` is presumably already sorted best-first by
    performance. This function does not sort it, and `perf_attr` is currently
    unused; it is kept for interface compatibility.

    Parameters
    ----------
    pdata : sequence
        Supports ``len``, slicing and indexing. Each element must expose ``W``
        (a matrix whose ``shape[0]`` gives the number of eigenvalues to mark)
        and ``eigen_values``.
    perf_attr : str
        Unused (see note above).
    contour : bool
        If True, draw with ``plot_pseudospectra_contour`` and skip the
        eigenvalue markers. Otherwise draw with ``plot_pseudospectra`` and
        mark the eigenvalues.
    levels :
        Forwarded to ``plot_pseudospectra_contour`` only.
    invert : bool
        Forwarded to ``plot_pseudospectra`` only.
    num_plots : int
        Panels per figure. This is clamped to ``len(pdata)``. If the result is
        0 or less, nothing is drawn.
    """
    import math

    import numpy as np

    num_plots = min(num_plots, len(pdata))
    if num_plots <= 0:
        return

    # Smallest near-square grid that holds every panel (5x5 for 25).
    nrows = int(math.ceil(math.sqrt(num_plots)))
    ncols = int(math.ceil(num_plots / float(nrows)))

    groups = [
        (0, 'Pseudospectra of Top %d Networks' % num_plots),
        (len(pdata) - num_plots, 'Pseudospectra of Bottom %d Networks' % num_plots),
    ]

    for offset, title in groups:
        fig = plt.figure()
        fig.subplots_adjust(wspace=0.1, hspace=0.1)
        for k, p in enumerate(pdata[offset:offset + num_plots]):
            W = p.W
            N = W.shape[0]
            # matplotlib subplot indices are 1-based.
            ax = fig.add_subplot(nrows, ncols, k + 1)
            if contour:
                plot_pseudospectra_contour(W, bounds=[-3, 3, -3, 3], npts=50, ax=ax, colorbar=False,
                                           invert=False, log=False, levels=levels)
            else:
                plot_pseudospectra(W, bounds=[-3, 3, -3, 3], npts=50, ax=ax, colorbar=True,
                                   log=True, invert=invert)

            # Draw on `ax` directly rather than via pyplot's current axes,
            # which the helpers (e.g. colorbar creation) may have changed.
            ax.axhline(0.0, color='k')
            ax.axvline(0.0, color='k')
            ax.add_patch(pylab.Circle((0.0, 0.0), radius=1.00, fc='gray', fill=False))
            ax.set_xticks([])
            ax.set_yticks([])

            if not contour:
                evs = np.asarray(p.eigen_values)[:N]
                ax.plot(evs.real, evs.imag, 'ko', markerfacecolor='w')

        fig.suptitle(title)

    plt.show()