Esempio n. 1
0
class Plotter(object):
    """Manage the grid layout of a figure's axes and provide shortcuts for simple charts.

    Axes are arranged on a ``rows`` x ``cols`` grid (see ``layout``, ``spacing``,
    ``padding`` and ``redraw``).  Each plotting method draws on the axes named by
    its ``axes`` argument: the default ``ACTIVE_AXES`` targets the plotter's active
    axes, and ``None`` creates a new one (see ``get_axes``).  The aim is to make
    simple graphs even easier to produce than with matplotlib directly."""
    ACTIVE_AXES = "active axes"  # Sentinel compared by identity in get_axes: "use the active axes"
    # Default gap on each side of an axes, as a fraction of that axes' width / height.
    default_hspace = 1.0 / 8.0
    default_vspace = 1.0 / 8.0
    # Default figure margin on each side, as a fraction of the figure's width / height.
    default_hpadding = 0.05
    default_vpadding = 0.05

    def __init__(self, title=None, rows=0, cols=0,
                 hspace=None, vspace=None,
                 hpadding=None, vpadding=None):
        """Create a plotter that owns a new, empty figure.

        title -- figure-level title; ``None`` means an empty title
        rows, cols -- initial grid dimensions; with the 0 defaults the first
                      ``add_axes`` call computes a layout
        hspace, vspace -- gap around each axes as a fraction of its width /
                          height; ``None`` uses ``default_hspace`` /
                          ``default_vspace``
        hpadding, vpadding -- figure margin on each side as a fraction of the
                              figure; ``None`` uses ``default_hpadding`` /
                              ``default_vpadding``
        """
        def _or_default(value, default):
            return default if value is None else value

        hspace = _or_default(hspace, self.default_hspace)
        vspace = _or_default(vspace, self.default_vspace)
        hpadding = _or_default(hpadding, self.default_hpadding)
        vpadding = _or_default(vpadding, self.default_vpadding)
        self.figure = Figure()
        # Back-reference from the figure to the plotter that manages it.
        self.figure.plotter = self
        self.title = self.figure.suptitle(_or_default(title, ""))
        self.rows, self.cols = rows, cols
        self.hspace, self.vspace = hspace, vspace
        self.hpadding, self.vpadding = hpadding, vpadding
        self.active_axes = None

    def update(self):
        """Show changes in the figure."""
        self.figure.show()

    def save(self, *args, **kwargs):
        """Save the figure. Please refer to matplotlib's documentation."""
        self.figure.savefig(*args, **kwargs)

    def close(self):
        """Close this plotter's figure by handing it to ``close_figure``."""
        figure = self.figure
        close_figure(figure)

    def layout(self, rows=None, cols=None, update=True):
        """Change the layout of the figure, i.e. the number of rows and columns."""
        n = len(self.figure.axes)
        if rows is None and cols is None:
            rows = int(round(sqrt(n)))
            cols = rows + (1 if n > rows**2 else 0)
        elif rows is None:
            rows = int(ceil(float(n) / cols))
        elif cols is None:
            cols = int(ceil(float(n) / rows))
        elif rows * cols < n:
            raise ValueError("insufficient cells")
        self.rows = rows
        self.cols = cols
        self.redraw(update)

    def spacing(self, hspace=None, vspace=None, update=True):
        """Change the spacing between axes in the figure."""
        if hspace is None and vspace is None:
            return
        if hspace is not None:
            if not 0.0 <= hspace <= 1.0:
                raise ValueError("illegal horizontal spacing (must be in [0, 1])")
            self.hspace = hspace
        if vspace is not None:
            if not 0.0 <= vspace <= 1.0:
                raise ValueError("illegal vertical spacing (must be in [0, 1])")
            self.vspace = vspace
        self.redraw(update)

    def padding(self, hpadding=None, vpadding=None, update=True):
        if hpadding is None and vpadding is None:
            return
        if hpadding is not None:
            if not 0.0 <= hpadding < 0.5:
                raise ValueError("illegal horizontal spacing (must be in [0, 0.5))")
            self.hpadding = hpadding
        if vpadding is not None:
            if not 0.0 <= vpadding < 0.5:
                raise ValueError("illegal vertical spacing (must be in [0, 0.5))")
            self.vpadding = vpadding
        self.redraw(update)

    def redraw(self, update=True):
        """Redraw the axes in the figure. This is usually used only after changes to layout
        or spacing in the figure."""
        total_width = self.cols * (1.0 + 2.0*self.hspace)
        total_height = self.rows * (1.0 + 2.0*self.vspace)
        plot_width = (1.0 - 2.0*self.hpadding) / total_width
        plot_height = (1.0 - 2.0*self.vpadding) / total_height
        space_width = self.hspace * plot_width
        space_height = self.vspace * plot_height
        # Reposition the axes according to the new dimensions
        n = len(self.figure.axes)
        index = 0
        y_pos = 1.0 - plot_height - space_height - self.vpadding
        for _ in xrange(self.rows):
            x_pos = space_width + self.hpadding
            for x in xrange(self.cols):
                axes = self.figure.axes[index]
                axes.set_position([x_pos, y_pos, plot_width, plot_height])
                index += 1
                x_pos += plot_width + 2 * space_width
                if index == n:
                    break
            y_pos -= plot_height + 2 * space_height
            if index == n:
                break
        if update:
            self.update()

    def set_title(self, title, update=True):
        """Set the title of the figure (not axes!)."""
        self.title.set_text("" if title is None else title)
        if update: self.update()

    def set_size(self, width, height, inches=False):
        """Sets the size of the image in pixels or inches (if 'inches' is true)."""
        dpi = self.figure.get_dpi()
        if not inches:
            width  /= dpi
            height /= dpi
        self.figure.set_size_inches(width, height, forward=True)

    def add_axes(self, make_active=True):
        """Add a new axes to the figure and place it in the right position. If the current
        grid of graphs cannot accommodate the new axes, the figure layout is recalculated."""
        axes = self.figure.add_subplot(1, 1, 1, label=str(len(self.figure.axes)))
        if make_active:
            self.active_axes = axes
        rows, cols = None, None
        if self.rows * self.cols >= len(self.figure.axes):
            rows, cols = self.rows, self.cols
        self.layout(rows, cols, update=False)
        return axes

    def config_axes(self, axes=ACTIVE_AXES, title=None, xlabel=None, ylabel=None,
                    xlimit=None, ylimit=None, legend=None, grid=None, update=True):
        """A many-in-one configuration method. Saves a few boring lines of matplotlib code."""
        axes = self.get_axes(axes)
        if title  is not None: axes.set_title(title)
        if xlabel is not None: axes.set_xlabel(xlabel)
        if ylabel is not None: axes.set_ylabel(ylabel)
        if xlimit is not None: axes.set_xlim(xlimit)
        if ylimit is not None: axes.set_ylim(ylimit)
        if legend is not None: axes.legend(loc=legend)
        if grid   is not None: axes.grid(bool(grid))
        if update:
            self.update()

    def get_axes(self, axes=ACTIVE_AXES):
        """This method can be used to check if a given axes belongs to the figure, retrieve
        the currently active axes, or add new axes to the figure."""
        if axes is Plotter.ACTIVE_AXES:
            if self.active_axes is None:
                self.add_axes(make_active=True)
            return self.active_axes
        if axes is None:
            return self.add_axes(make_active=False)
        if axes in self.figure.axes:
            return axes
        raise Exception("axes does not belong to this Plotter")

    def set_active(self, axes):
        """Set the plotter's active axes, i.e. the default target of plotting commands."""
        if axes not in self.figure.axes:
            raise Exception("axes does not belong to this Plotter")
        self.active_axes = axes

    # -------------------------------------------------
    # Plotting methods
    @contextmanager
    def plotting_on(self, axes, update=True):
        yield self.get_axes(axes)
        if update:
            self.update()

    def legend(self, axes=ACTIVE_AXES, update=True, **kwargs):
        """Add a legend to the given axes."""
        with self.plotting_on(axes, update) as axes:
            axes.legend(**kwargs)
        return axes

    def pie_chart(self, values, freqs, axes=ACTIVE_AXES, update=True, **kwargs):
        with self.plotting_on(axes, update) as axes:
            axes.pie(freqs, labels=values, **kwargs)
        return axes

    def bar_chart(self, values, freqs, axes=ACTIVE_AXES, update=True, **kwargs):
        with self.plotting_on(axes, update) as axes:
            data = self.__prepare_bar_chart(values, freqs)
            axes.xaxis.set_ticks(data.xtick_locs)
            axes.xaxis.set_ticklabels(data.xtick_labels)
            axes.bar(data.left, data.height, width=data.bar_width, **kwargs)
        return axes

    def pareto_chart(self, values, freqs, axes=ACTIVE_AXES, update=True, **kwargs):
        with self.plotting_on(axes, update) as axes:
            data = self.__prepare_pareto_chart(values, freqs)
            axes.xaxis.set_ticks(data.xtick_locs)
            axes.xaxis.set_ticklabels(data.xtick_labels)
            axes.bar(data.left, data.height, width=data.bar_width, **kwargs)
            axes.plot(data.xs, data.ys, "r-", label="Cumulative frequency")
        return axes

    def histogram(self, values, freqs=None, bins=10, axes=ACTIVE_AXES, update=True, **kwargs):
        with self.plotting_on(axes, update) as axes:
            data = self.__prepare_histogram(values, freqs, bins)
            axes.bar(data.left, data.height, width=data.bar_width, **kwargs)
        return axes

    def box_plot(self, values, axes=ACTIVE_AXES, update=True, **kwargs):
        with self.plotting_on(axes, update) as axes:
            axes.boxplot(values, **kwargs)
        return axes

    def run_chart(self, times, values, numeric=True, axes=ACTIVE_AXES, update=True, **kwargs):
        """A run chart of a time series."""
        with self.plotting_on(axes, update) as axes:
            data = self.__prepare_run_chart(list(times), list(values), numeric)
            if not numeric:
                axes.yaxis.set_ticks(data.ytick_locs)
                axes.yaxis.set_ticklabels(data.ytick_labels)
            axes.plot(data.xs, data.ys, **kwargs)
        return axes

    def line_plot(self, xs, ys, axes=ACTIVE_AXES, update=True, **kwargs):
        """A simple 2D line plot."""
        with self.plotting_on(axes, update) as axes:
            axes.plot(list(xs), list(ys), **kwargs)
        return axes

    def function_plot(self, function, start=0, stop=1.0, observations=100,
                      axes=ACTIVE_AXES, update=True, **kwargs):
        """Make a quick plot of a function on a given interval."""
        xs, ys = [], []
        dx = float(stop - start) / (observations - 1)
        x = start
        for i in xrange(observations):
            xs.append(x)
            ys.append(function(x))
            x += dx
        return self.line_plot(xs, ys, axes=axes, update=True, **kwargs)

    # -------------------------------------------------
    # Preparation of data for plotting
    def __prepare_bar_chart(self, values, freqs, bar_width=1.0, bar_space=0.5):
        left = [(bar_width + bar_space) * x for x in xrange(len(freqs))]
        xtick_locs = [l + bar_width / 2.0 for l in left]
        return Namespace(left=left, height=freqs,
                         bar_width=bar_width,
                         xtick_locs=xtick_locs,
                         xtick_labels=values)

    def __prepare_pareto_chart(self, values, freqs, bar_width=1.0):
        total = float(sum(freqs))
        items = sorted([(f / total, v) for f, v in zip(freqs, values)], reverse=True)
        height = []
        left = [bar_width * x for x in xrange(len(items))]
        xtick_locs = [l + bar_width / 2.0 for l in left]
        xtick_labels = []
        xs = [bar_width * x for x in xrange(len(items) + 1)]
        ys = [0.0]
        for f, v in items:
            height.append(f)
            xtick_labels.append(v)
            ys.append(ys[-1] + f)
        return Namespace(left=left, height=height,
                         xs=xs, ys=ys,
                         bar_width=bar_width,
                         xtick_locs=xtick_locs,
                         xtick_labels=xtick_labels)

    def __prepare_histogram(self, values, freqs, bins):
        if freqs is None:
            freqs = [1.0] * len(values)
        items = sorted(zip(values, freqs))
        minimum = items[ 0][0]
        maximum = items[-1][0]
        total_freq = sum(freqs)
        bin_span = float(maximum - minimum) / bins
        bin_end = [minimum + bin_span * (x + 1) for x in xrange(bins)]
        bin_end[-1] = maximum
        bin_freq = [0.0] * bins
        cur_bin = 0
        for v, f in items:
            while v > bin_end[cur_bin]:
                cur_bin += 1
            bin_freq[cur_bin] += f / (bin_span * total_freq)
        return Namespace(left=[end - bin_span for end in bin_end],
                         height=bin_freq,
                         bar_width=bin_span)

    def __prepare_run_chart(self, times, values, numeric):
        xs = []
        ys = []
        prev_y = values[0]
        for y, t in zip(values, times):
            xs.extend((t, t))
            ys.extend((prev_y, y))
            prev_y = y
        ytick_locs = None
        ytick_labels = None
        if not numeric:
            # map objects to integer y values if the tseries is not numeric
            y_set = sorted(set(ys))
            y_mapping = dict(zip(y_set, xrange(len(y_set))))
            ys = [y_mapping[y] for y in ys]
            # prepare y ticks explaining the translation from objects to integers
            yticks = sorted((i, v) for v, i in y_mapping.iteritems())
            ytick_locs   = [i for i, _ in yticks]
            ytick_labels = [v for _, v in yticks]
        return Namespace(xs=xs, ys=ys,
                         ytick_locs=ytick_locs,
                         ytick_labels=ytick_labels)