Esempio n. 1
0
    def update_plot(self):
        """Refresh the cached per-layer weight statistics and redraw the canvas.

        For every trainable layer (``self.get_trainable_layers()``) and every
        parameter name in ``self.parameters``, the layer weights whose
        ``_``-separated name contains that parameter are column-stacked into a
        single array. The flattened array is stored in
        ``self.layers_stats["<layer>_<param>"]["values"]``. Then, for each
        entry ``s`` in ``self.stats``:

        * ``"raster"`` stores the stacked array, reshaped to 2-D if needed;
        * any other name is looked up on ``numpy`` with ``getattr`` and the
          result of applying it to the array is appended to that stat's list.

        Finally the model description from ``get_model_desc`` is written at
        the bottom-left of ``self.fig`` and the canvas is redrawn.

        NOTE(review): this method only updates ``self.layers_stats``; it does
        not modify any Axes artists, so the visible plots are presumably
        rebuilt elsewhere (e.g. ``draw_plot``) -- confirm.
        """
        layers = self.get_trainable_layers()

        for layer in layers:
            for param in self.parameters:
                weights = [
                    w for w in layer.weights if param in w.name.split("_")
                ]

                if not weights:
                    continue

                # numpy's stacking functions require a real sequence; passing
                # a generator is deprecated since NumPy 1.16 and rejected by
                # newer releases.
                val = numpy.column_stack([w.get_value() for w in weights])
                name = layer.name + "_" + param
                self.layers_stats[name]["values"] = val.ravel()
                for s in self.stats:
                    if s == "raster":
                        # Flatten trailing axes in Fortran (column-major)
                        # order so the stored raster is 2-D.
                        if val.ndim > 2:
                            val = val.reshape((val.shape[0], -1), order='F')
                        self.layers_stats[name][s] = val
                    else:
                        self.layers_stats[name][s].append(
                            getattr(numpy, s)(val))

        # Draw on self.fig explicitly (plt.figtext targets whatever figure is
        # current). Replace the previous description instead of stacking a new
        # Text artist on top of it on every update, since this method never
        # clears the figure.
        previous = getattr(self, "_model_desc_text", None)
        if previous is not None and previous in self.fig.texts:
            previous.remove()
        self._model_desc_text = self.fig.text(.02,
                                              .02,
                                              get_model_desc(self.model),
                                              wrap=True,
                                              fontsize=8)
        self.fig.tight_layout()
        self.fig.subplots_adjust(bottom=.2)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
Esempio n. 2
0
    def draw_plot(self):
        """Rebuild ``self.fig`` from scratch with one row per weight group.

        For every trainable layer and every parameter name in
        ``self.parameters``, the matching weights (those whose ``_``-separated
        name contains the parameter) are column-stacked into one array. Each
        such group gets:

        * a 50-bin histogram of the flattened values, followed by
        * one subplot per entry in ``self.stats``. ``"raster"`` shows the
          weights as an image with a colorbar. Any other name is looked up on
          ``numpy`` with ``getattr``; its result is appended to that stat's
          history list, and the whole history is plotted.

        The model description from ``get_model_desc`` is written at the bottom
        of the figure and the canvas is redrawn.
        """
        self.fig.clf()

        layers = self.get_trainable_layers()
        # NOTE(review): the grid has one row per key already present in
        # self.layers_stats; presumably that matches the number of
        # (layer, param) groups drawn below -- confirm.
        height = len(self.layers_stats)
        # One histogram column plus one column per requested statistic.
        width = len(self.stats) + 1

        plot_count = 1
        for layer in layers:
            for param in self.parameters:
                weights = [
                    w for w in layer.weights if param in w.name.split("_")
                ]

                if not weights:
                    continue

                # numpy's stacking functions require a real sequence; passing
                # a generator is deprecated since NumPy 1.16 and rejected by
                # newer releases.
                val = numpy.column_stack([w.get_value() for w in weights])
                name = layer.name + "_" + param

                # Histogram of every value in this weight group.
                self.layers_stats[name]["values"] = val.ravel()
                ax = self.fig.add_subplot(height, width, plot_count)
                ax.hist(self.layers_stats[name]["values"], bins=50)
                ax.set_title(name, fontsize=10)
                ax.grid(True)
                ax.tick_params(labelsize=8)
                plot_count += 1

                for s in self.stats:
                    axs = self.fig.add_subplot(height, width, plot_count)

                    if s == "raster":
                        # Flatten trailing axes in Fortran (column-major)
                        # order so imshow receives a 2-D array.
                        if val.ndim > 2:
                            val = val.reshape((val.shape[0], -1), order='F')
                        self.layers_stats[name][s] = val
                        m = axs.imshow(
                            self.layers_stats[name][s],
                            cmap='coolwarm',
                            interpolation='nearest',
                            aspect='auto',
                        )
                        # Attach the colorbar to this subplot explicitly
                        # instead of relying on matplotlib's default choice.
                        cbar = self.fig.colorbar(mappable=m, ax=axs)
                        cbar.ax.tick_params(labelsize=8)
                    else:
                        # `s` names a numpy function; keep the history of its
                        # values and plot the whole series.
                        self.layers_stats[name][s].append(
                            getattr(numpy, s)(val))
                        axs.plot(self.layers_stats[name][s])
                        axs.set_ylabel(s, fontsize="small")
                        axs.set_xlabel('epoch', fontsize="small")
                        axs.grid(True)

                    axs.set_title(name + " - " + s, fontsize=10)
                    axs.tick_params(labelsize=8)
                    plot_count += 1

        desc = get_model_desc(self.model)
        self.fig.text(.02,
                      .02,
                      desc,
                      verticalalignment='bottom',
                      wrap=True,
                      fontsize=8)
        self.fig.tight_layout()
        # Leave room at the bottom for the model description text.
        self.fig.subplots_adjust(bottom=.14)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
Esempio n. 3
0
    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's results to the history, redraw and save the plot.

        The first subplot shows training vs. validation loss. Each further
        subplot (one per dataset returned by ``self.get_metrics(logs)``) shows
        that dataset's metrics, plus up to two horizontal benchmark lines from
        ``self.benchmarks``. The figure is saved via ``self.save_plot()``.

        Args:
            epoch: Index of the finished epoch (not used by this method).
            logs: Mapping read with ``logs['loss']``, ``logs['val_loss']`` and
                ``logs["<dataset>.<metric>"]`` for every metric reported by
                ``self.get_metrics(logs)``. Defaults to a fresh empty dict.

        Raises:
            KeyError: if ``logs`` lacks any of the keys listed above.
        """
        # Fresh dict per call; the old ``logs={}`` default was one dict shared
        # across every call.
        if logs is None:
            logs = {}

        self.fig.clf()
        linewidth = 1.2
        custom_metrics_keys = self.get_metrics(logs)
        total_plots = len(custom_metrics_keys) + 1
        # One panel width per subplot (loss + one per dataset).
        self.fig.set_size_inches(self.width * total_plots,
                                 self.height,
                                 forward=True)

        ##################################################
        # First - Plot Models loss
        self.model_loss.append(logs['loss'])
        self.validation_loss.append(logs['val_loss'])

        ax = self.fig.add_subplot(1, total_plots, 1)
        ax.plot(self.model_loss, linewidth=linewidth)
        ax.plot(self.validation_loss, linewidth=linewidth)
        ax.set_title('model loss', fontsize=10)
        ax.set_ylabel('loss')
        ax.set_xlabel('epoch')
        ax.legend(['train', 'val'], loc='upper left', fancybox=True)
        # Pass the visibility flag positionally: the ``b=`` keyword was
        # deprecated in Matplotlib 3.5 and later removed (it is now
        # ``visible``). The positional form works on old and new releases.
        ax.grid(True, which='major', color='gray', linewidth=.5)
        ax.grid(True, which='minor', color='gray', linewidth=0.5)
        ax.tick_params(labelsize=10)

        ##################################################
        # Second - Plot Custom Metrics
        for i, (dataset_name, metrics) in enumerate(
                sorted(custom_metrics_keys.items())):
            axs = self.fig.add_subplot(1, total_plots, i + 2)
            axs.set_title(dataset_name, fontsize=10)
            axs.set_ylabel('score')
            axs.set_xlabel('epoch')
            if self.grid_ranges:
                axs.set_ylim(self.grid_ranges)

            metric_names = sorted(metrics)
            # Append this epoch's value to each metric's history, then plot it.
            for m in metric_names:
                entry = ".".join([dataset_name, m])
                self.custom_metrics[entry].append(logs[entry])
                axs.plot(self.custom_metrics[entry],
                         label=m,
                         linewidth=linewidth)

            axs.tick_params(labelsize=10)
            labels = list(metric_names)
            if self.benchmarks:
                # zip stops at the shorter list, so only the first two
                # benchmarks are drawn.
                for (label, benchmark), color in zip(self.benchmarks.items(),
                                                     ["y", "r"]):
                    axs.axhline(y=benchmark, linewidth=linewidth, color=color)
                    labels.append(label)
            axs.legend(labels, loc='upper left', fancybox=True)
            axs.grid(True, which='major', color='gray', linewidth=.5)
            axs.grid(True, which='minor', color='gray', linewidth=0.5)

        # NOTE: this mutates matplotlib's global rcParams, not just self.fig.
        plt.rcParams.update({'font.size': 10})

        desc = get_model_desc(self.model)
        self.fig.text(.02,
                      .02,
                      desc,
                      verticalalignment='bottom',
                      wrap=True,
                      fontsize=8)
        self.fig.tight_layout()
        # Leave room at the bottom for the model description text.
        self.fig.subplots_adjust(bottom=.18)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

        self.save_plot()