Ejemplo n.º 1
0
Archivo: log.py Proyecto: certik/sfepy
    def __call__(self, *args, **kwargs):
        """Record one value per logged quantity and update the plot.

        Positional arguments are the values to log, one per name in
        `self.seq_data_names`; their count must equal `self.n_arg`. A numpy
        array argument is accepted only when it holds exactly one element,
        and that element is what gets stored.

        Keyword arguments
        -----------------
        finished : bool
            If true, call `self.terminate()` and return without logging.
        save_figure : str
            If non-empty and the plot queue exists, a ``["save",
            save_figure]`` request is put into the plot queue.
        x : sequence
            One x value per subplot group (``x[ig]`` for ``ig`` in
            ``range(self.n_gr)``). A missing ``x``, or an entry that is
            None, falls back to the call counter `self.n_calls`.

        Raises
        ------
        IndexError
            If the number of positional arguments differs from `self.n_arg`.
        ValueError
            If an array argument does not contain exactly one element.
        """
        finished = kwargs.get("finished", False)
        save_figure = kwargs.get("save_figure", "")
        x_values = kwargs.get("x", None)

        if save_figure and (self.plot_queue is not None):
            self.plot_queue.put(["save", save_figure])

        if finished:
            self.terminate()
            return

        ls = len(args), self.n_arg
        if ls[0] != ls[1]:
            msg = "log called with wrong number of arguments! (%d != %d)" % ls
            raise IndexError(msg)

        for ii, name in enumerate(self.seq_data_names):
            aux = args[ii]
            if isinstance(aux, nm.ndarray):
                # Use the total element count: len() only inspects the first
                # axis and would let e.g. a (1, 3) array through as a
                # "scalar". flat[0] also handles 0-d and (1, 1) arrays.
                if aux.size != 1:
                    raise ValueError("can log only scalars (%s)" % aux)
                aux = aux.flat[0]
            key = name_to_key(name, ii)
            self.data[key].append(aux)

        for ig in range(self.n_gr):
            # Compare with None instead of relying on truthiness, so that a
            # legitimate x value of 0 (e.g. the initial time) is kept.
            if (x_values is not None) and (x_values[ig] is not None):
                self.x_values[ig].append(x_values[ig])
            else:
                self.x_values[ig].append(self.n_calls)

        if self.is_plot and self.can_plot:
            if self.n_calls == 0:
                # First call (n_calls == 0): start the plotting process and
                # register its termination at interpreter exit.
                atexit.register(self.terminate)

                self.plot_queue = Queue()
                self.plotter = ProcessPlotter(self.aggregate)
                self.plot_process = Process(
                    target=self.plotter,
                    args=(
                        self.plot_queue,
                        self.data_names,
                        self.igs,
                        self.seq_data_names,
                        self.yscales,
                        self.xaxes,
                        self.yaxes,
                    ),
                )
                self.plot_process.daemon = True
                self.plot_process.start()

            self.plot_data()

        self.n_calls += 1
Ejemplo n.º 2
0
Archivo: log.py Proyecto: certik/sfepy
class Log(Struct):
    """Log data and (optionally) plot them in the second process via
    ProcessPlotter."""

    @staticmethod
    def from_conf(conf, data_names):
        """Create an empty Log from a configuration.

        Parameters
        ----------
        conf : dict or object
            Configuration. A dict is queried with `conf.get()`, any other
            object with `conf.get_default_attr()`. Recognized entries:
            ``is_plot`` (default True), ``yscales`` (default ``"linear"``
            for every logged quantity), ``xaxes`` (default ``"iteration"``
            for every quantity), ``yaxes`` (default ``""`` for every
            quantity) and ``aggregate`` (default 100, passed to
            ProcessPlotter).
        data_names : tuple of lists, or list
            Names grouped by subplots:
            ``([name1, name2, ...], [name3, name4, ...], ...)``, where the
            names are strings to display in (sub)plot legends. A non-tuple
            value is treated as a single group.

        Returns
        -------
        obj : Log
            The new log, with no data recorded yet.
        """
        if not isinstance(data_names, tuple):
            data_names = (data_names,)

        obj = Log(data_names=data_names, seq_data_names=[], igs=[], data={},
                  x_values={}, n_calls=0, plot_queue=None)

        # Flatten the grouped names; ii is the global index of a quantity,
        # ig the index of its subplot group.
        ii = 0
        for ig, names in enumerate(obj.data_names):
            obj.x_values[ig] = []
            for name in names:
                key = name_to_key(name, ii)
                obj.data[key] = []
                obj.igs.append(ig)
                obj.seq_data_names.append(name)
                ii += 1
        obj.n_arg = len(obj.igs)

        obj.n_gr = len(obj.data_names)

        if isinstance(conf, dict):
            get = conf.get
        else:
            get = conf.get_default_attr

        obj.is_plot = get("is_plot", True)
        obj.yscales = get("yscales", ["linear"] * obj.n_arg)
        obj.xaxes = get("xaxes", ["iteration"] * obj.n_arg)
        obj.yaxes = get("yaxes", [""] * obj.n_arg)
        obj.aggregate = get("aggregate", 100)

        obj.can_plot = (pylab is not None) and (Process is not None)

        if obj.is_plot and (not obj.can_plot):
            output("warning: log plot is disabled, install pylab (GTKAgg)")
            output("         and multiprocessing")

        return obj

    def __call__(self, *args, **kwargs):
        """Record one value per logged quantity and update the plot.

        Positional arguments are the values to log, one per name in
        `self.seq_data_names`; their count must equal `self.n_arg`. A numpy
        array argument is accepted only when it holds exactly one element,
        and that element is what gets stored.

        Keyword arguments
        -----------------
        finished : bool
            If true, call `self.terminate()` and return without logging.
        save_figure : str
            If non-empty and the plot queue exists, a ``["save",
            save_figure]`` request is put into the plot queue.
        x : sequence
            One x value per subplot group (``x[ig]`` for ``ig`` in
            ``range(self.n_gr)``). A missing ``x``, or an entry that is
            None, falls back to the call counter `self.n_calls`.

        Raises
        ------
        IndexError
            If the number of positional arguments differs from `self.n_arg`.
        ValueError
            If an array argument does not contain exactly one element.
        """
        finished = kwargs.get("finished", False)
        save_figure = kwargs.get("save_figure", "")
        x_values = kwargs.get("x", None)

        if save_figure and (self.plot_queue is not None):
            self.plot_queue.put(["save", save_figure])

        if finished:
            self.terminate()
            return

        ls = len(args), self.n_arg
        if ls[0] != ls[1]:
            msg = "log called with wrong number of arguments! (%d != %d)" % ls
            raise IndexError(msg)

        for ii, name in enumerate(self.seq_data_names):
            aux = args[ii]
            if isinstance(aux, nm.ndarray):
                # Use the total element count: len() only inspects the first
                # axis and would let e.g. a (1, 3) array through as a
                # "scalar". flat[0] also handles 0-d and (1, 1) arrays.
                if aux.size != 1:
                    raise ValueError("can log only scalars (%s)" % aux)
                aux = aux.flat[0]
            key = name_to_key(name, ii)
            self.data[key].append(aux)

        for ig in range(self.n_gr):
            # Compare with None instead of relying on truthiness, so that a
            # legitimate x value of 0 (e.g. the initial time) is kept.
            if (x_values is not None) and (x_values[ig] is not None):
                self.x_values[ig].append(x_values[ig])
            else:
                self.x_values[ig].append(self.n_calls)

        if self.is_plot and self.can_plot:
            if self.n_calls == 0:
                # First call since creation or since terminate() reset the
                # counter: start the plotting process and register its
                # termination at interpreter exit.
                atexit.register(self.terminate)

                self.plot_queue = Queue()
                self.plotter = ProcessPlotter(self.aggregate)
                self.plot_process = Process(
                    target=self.plotter,
                    args=(
                        self.plot_queue,
                        self.data_names,
                        self.igs,
                        self.seq_data_names,
                        self.yscales,
                        self.xaxes,
                        self.yaxes,
                    ),
                )
                self.plot_process.daemon = True
                self.plot_process.start()

            self.plot_data()

        self.n_calls += 1

    def terminate(self):
        """Stop the plotting process, if one is running.

        Puts ``None`` into the plot queue (presumably the stop request for
        ProcessPlotter - confirm), waits for the process to finish and
        resets the call counter, so the next logging call starts a new
        plotting process. Does nothing when plotting is disabled or no
        plotting process is active, e.g. when the `atexit` hook runs after
        an explicit termination, or when called before any data was logged.
        """
        if not (self.is_plot and self.can_plot):
            return
        if self.plot_queue is None:
            # No plotting process was started, or it was already terminated.
            return

        self.plot_queue.put(None)
        self.plot_process.join()
        # Forget the queue so that later save_figure requests are not sent
        # to a process that no longer exists.
        self.plot_queue = None
        self.n_calls = 0
        output("terminated")

    def plot_data(self):
        """Send all logged data to the plotting process.

        Puts ``["clear"]`` into the plot queue. Then, for every logged
        quantity, it puts ``["iseq", ii]`` followed by ``["plot", x, y]``,
        with x and y converted to numpy arrays. It finishes with
        ``["legends"]`` and ``["continue"]``.
        """
        put = self.plot_queue.put

        put(["clear"])
        for ii, name in enumerate(self.seq_data_names):
            key = name_to_key(name, ii)
            try:
                put(["iseq", ii])
                put(["plot",
                     nm.array(self.x_values[self.igs[ii]]),
                     nm.array(self.data[key])])
            except Exception:
                # Report which quantity failed before propagating the error.
                # A single pre-formatted string prints identically under
                # Python 2 and 3.
                print("%s %s %s" % (ii, name, self.data[key]))
                raise
        put(["legends"])
        put(["continue"])