def _show_dict(d):
    """Draw ``su.tiledict(d)`` on the current matplotlib axes.

    A ``d`` with three dimensions is treated as a grayscale dictionary and
    drawn with the ``'gray'`` colormap; otherwise the default colormap is used.
    """
    if d.ndim == 3:  # grayscale image
        plt.imshow(su.tiledict(d), cmap='gray')
    else:
        plt.imshow(su.tiledict(d))


def plot_and_save_statistics(solvers, args, show=True):
    """Save per-solver statistics and dictionary plots, then display them.

    For every ``(name, solver)`` pair in ``solvers`` this writes, inside the
    directory ``os.path.join(args.output_path, name)``:

    * ``<dname>.stats.npy`` -- ``solver.getitstat()`` converted with
      ``su.ntpl2array``;
    * ``<dname>.time_stats.pkl`` -- a pickled ``{'Time': ...}`` dict holding
      the ``Time`` field of ``solver.getitstat()``;
    * ``<dname>.pdf`` -- the image of ``su.tiledict`` applied to the squeezed
      dictionary returned by ``solver.getdict()``.

    ``dname`` is ``args.dataset``, with ``'.gray'`` appended when
    ``args.use_gray`` is true. The per-solver directories are not created
    here; they must already exist.

    Parameters
    ----------
    solvers : dict
        Mapping of solver name to solver object. Each solver must provide
        ``getitstat()`` and ``getdict()``.
    args : object
        Must expose ``dataset``, ``use_gray`` and ``output_path`` attributes.
    show : bool, optional
        When true (the default, matching the previous hard-coded behaviour),
        all dictionaries are finally drawn side by side in one figure and
        ``plt.show()`` is called. Pass ``False`` for non-interactive runs.
    """
    dname = args.dataset + '.gray' if args.use_gray else args.dataset
    for name, solver in solvers.items():
        out_dir = os.path.join(args.output_path, name)
        itstat = solver.getitstat()
        # save statistics
        np.save(os.path.join(out_dir, f'{dname}.stats.npy'),
                su.ntpl2array(itstat))
        # Time is additionally saved on its own; use a context manager so
        # the pickle file is flushed and closed deterministically.
        with open(os.path.join(out_dir, f'{dname}.time_stats.pkl'),
                  'wb') as f:
            pickle.dump({'Time': itstat.Time}, f)
        # save dictionary visualization
        plt.clf()
        _show_dict(solver.getdict().squeeze())
        plt.savefig(os.path.join(out_dir, f'{dname}.pdf'),
                    bbox_inches='tight')

    if not show:
        return
    # Side-by-side overview of every solver's dictionary.
    plt.clf()
    nsol = len(solvers)
    for i, (name, solver) in enumerate(solvers.items(), start=1):
        plt.subplot(1, nsol, i)
        _show_dict(solver.getdict().squeeze())
        plt.title(name)
    plt.show()
# Example #2
# 0
def snapshot_solver_stats(solver, path):
    """Write ``solver``'s iteration statistics into the directory ``path``.

    Nothing happens unless ``_cfg.SNAPSHOT`` is truthy. When it is:

    * ``solver.getitstat()`` converted with ``su.ntpl2array`` is stored as
      ``stats.npy``;
    * the ``Time`` field of ``solver.getitstat()`` is pickled separately,
      wrapped in a ``{'Time': ...}`` dict, into ``time_stats.pkl``.

    Returns
    -------
    dict or None
        The pickled ``{'Time': ...}`` dict when a snapshot was written,
        otherwise ``None``.
    """
    if not _cfg.SNAPSHOT:
        return None

    np.save(os.path.join(path, 'stats.npy'),
            su.ntpl2array(solver.getitstat()))

    result = {'Time': solver.getitstat().Time}
    pkl_path = os.path.join(path, 'time_stats.pkl')
    with open(pkl_path, 'wb') as fh:
        pickle.dump(result, fh)
    return result
# Example #3
# 0
 def test_01(self):
     """Round-trip a namedtuple through ntpl2array and array2ntpl.

     Namedtuple equality is plain tuple equality and ignores field names,
     so the field names are asserted explicitly in addition to the values.
     """
     nt = collections.namedtuple('NT', ('A', 'B', 'C'))
     t0 = nt(0, 1, 2)
     t0a = util.ntpl2array(t0)
     t1 = util.array2ntpl(t0a)
     assert t0 == t1
     # `==` would also accept a plain tuple or wrongly-named fields, so
     # check that the field names survived the round trip too.
     assert t1._fields == t0._fields
# Example #4
# 0
 def test_01(self):
     """Round-trip a namedtuple through ntpl2array and array2ntpl.

     Namedtuple equality is plain tuple equality and ignores field names,
     so the field names are asserted explicitly in addition to the values.
     """
     nt = collections.namedtuple('NT', ('A', 'B', 'C'))
     t0 = nt(0, 1, 2)
     t0a = util.ntpl2array(t0)
     t1 = util.array2ntpl(t0a)
     assert t0 == t1
     # `==` would also accept a plain tuple or wrongly-named fields, so
     # check that the field names survived the round trip too.
     assert t1._fields == t0._fields