def update_plot(self):
    """Refresh the cached per-layer weight statistics and redraw the canvas.

    For every trainable layer (``self.get_trainable_layers()``) and every
    configured parameter name in ``self.parameters``, the layer weights whose
    ``_``-separated name tokens contain that parameter name are collected.
    Their current values (``w.get_value()``) are column-stacked, and the
    results are stored in ``self.layers_stats["<layer.name>_<param>"]``:

    * ``"values"`` -- the stacked values, flattened;
    * ``"raster"`` -- the stacked array itself, reshaped to 2-D in column-major
      order when it has more than two dimensions;
    * any other stat name ``s`` -- ``numpy.<s>(values)`` is appended to that
      stat's history list.

    Unlike ``draw_plot`` this does not clear the figure or rebuild any axes.
    It only updates the stored statistics, places the model description and
    redraws ``self.fig``.
    """
    for layer in self.get_trainable_layers():
        for param in self.parameters:
            weights = [w for w in layer.weights if param in w.name.split("_")]
            if not weights:
                continue
            # Pass a real sequence: newer NumPy releases reject generators in
            # the stacking functions (deprecated since NumPy 1.16).
            val = numpy.column_stack([w.get_value() for w in weights])
            name = layer.name + "_" + param
            layer_stats = self.layers_stats[name]
            layer_stats["values"] = val.ravel()
            for s in self.stats:
                if s == "raster":
                    # Flatten trailing dims so the array can be shown as an image.
                    if len(val.shape) > 2:
                        val = val.reshape((val.shape[0], -1), order='F')
                    layer_stats[s] = val
                else:
                    # ``s`` names a numpy reduction (e.g. "mean", "std").
                    layer_stats[s].append(getattr(numpy, s)(val))

    desc = get_model_desc(self.model)
    # This method never clears the figure, so drop earlier copies of the
    # description instead of stacking a new Text artist on every call. Draw on
    # self.fig explicitly; pyplot's current figure may be a different one.
    for text in list(self.fig.texts):
        if text.get_text() == desc:
            text.remove()
    self.fig.text(.02, .02, desc, verticalalignment='bottom', wrap=True,
                  fontsize=8)
    self.fig.tight_layout()
    self.fig.subplots_adjust(bottom=.2)
    self.fig.canvas.draw()
    self.fig.canvas.flush_events()
def draw_plot(self):
    """Clear ``self.fig`` and rebuild the full weight-statistics grid.

    The figure is laid out as ``len(self.layers_stats)`` rows by
    ``len(self.stats) + 1`` columns. For every trainable layer / configured
    parameter pair with matching weights (matched on the ``_``-separated
    tokens of ``w.name``), the values are column-stacked and drawn as follows:

    * a 50-bin histogram of the flattened values (also stored under
      ``"values"``);
    * one subplot per stat name ``s`` in ``self.stats``:

      - ``"raster"``: the stacked array (reshaped to 2-D, column-major, when
        it has more than two dimensions) is shown with ``imshow`` plus a
        colorbar;
      - otherwise: ``numpy.<s>(values)`` is appended to the stored history,
        and the history is plotted as a line.

    Finally, the model description is written at the bottom of the figure and
    the canvas is redrawn.
    """
    self.fig.clf()
    layers = self.get_trainable_layers()
    height = len(self.layers_stats)
    width = len(self.stats) + 1
    plot_count = 1
    for layer in layers:
        for param in self.parameters:
            weights = [w for w in layer.weights if param in w.name.split("_")]
            if not weights:
                continue
            # Pass a real sequence: newer NumPy releases reject generators in
            # the stacking functions (deprecated since NumPy 1.16).
            val = numpy.column_stack([w.get_value() for w in weights])
            name = layer.name + "_" + param
            layer_stats = self.layers_stats[name]
            layer_stats["values"] = val.ravel()

            # First column: histogram of the raw weight values.
            ax = self.fig.add_subplot(height, width, plot_count)
            ax.hist(layer_stats["values"], bins=50)
            ax.set_title(name, fontsize=10)
            ax.grid(True)
            ax.tick_params(labelsize=8)
            plot_count += 1

            # Remaining columns: one subplot per configured statistic.
            for s in self.stats:
                axs = self.fig.add_subplot(height, width, plot_count)
                if s == "raster":
                    # Flatten trailing dims so the array can be shown as an image.
                    if len(val.shape) > 2:
                        val = val.reshape((val.shape[0], -1), order='F')
                    layer_stats[s] = val
                    m = axs.imshow(
                        layer_stats[s],
                        cmap='coolwarm',
                        interpolation='nearest',
                        aspect='auto',
                    )
                    # Attach the colorbar to this subplot explicitly instead
                    # of relying on the "current axes" inference.
                    cbar = self.fig.colorbar(m, ax=axs)
                    cbar.ax.tick_params(labelsize=8)
                else:
                    # ``s`` names a numpy reduction (e.g. "mean", "std").
                    layer_stats[s].append(getattr(numpy, s)(val))
                    axs.plot(layer_stats[s])
                axs.set_ylabel(s, fontsize="small")
                axs.set_xlabel('epoch', fontsize="small")
                axs.grid(True)
                axs.set_title(name + " - " + s, fontsize=10)
                axs.tick_params(labelsize=8)
                plot_count += 1

    desc = get_model_desc(self.model)
    self.fig.text(.02, .02, desc, verticalalignment='bottom', wrap=True,
                  fontsize=8)
    self.fig.tight_layout()
    self.fig.subplots_adjust(bottom=.14)
    self.fig.canvas.draw()
    self.fig.canvas.flush_events()
def on_epoch_end(self, epoch, logs=None):
    """Record this epoch's losses/metrics and redraw the training figure.

    Presumably invoked as a Keras-style callback hook (``epoch``, ``logs``);
    ``epoch`` itself is not used. ``logs`` must contain ``'loss'`` and
    ``'val_loss'`` plus a ``"<dataset>.<metric>"`` entry for every metric
    returned by ``self.get_metrics(logs)`` (a mapping of dataset name to
    metric names).

    The figure gets one subplot for train/validation loss, followed by one
    subplot per dataset showing its metric histories, with optional
    horizontal benchmark lines from ``self.benchmarks``. The model description
    is written at the bottom, the canvas is redrawn and ``self.save_plot()``
    is called.

    Raises:
        KeyError: if a required entry is missing from ``logs``.
    """
    if logs is None:
        logs = {}
    linewidth = 1.2

    self.fig.clf()
    custom_metrics_keys = self.get_metrics(logs)
    total_plots = len(custom_metrics_keys) + 1
    # One column of width ``self.width`` per subplot.
    self.fig.set_size_inches(self.width * total_plots, self.height,
                             forward=True)

    ##################################################
    # First - plot the model loss
    self.model_loss.append(logs['loss'])
    self.validation_loss.append(logs['val_loss'])

    ax = self.fig.add_subplot(1, total_plots, 1)
    ax.plot(self.model_loss, linewidth=linewidth)
    ax.plot(self.validation_loss, linewidth=linewidth)
    ax.set_title('model loss', fontsize=10)
    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    ax.legend(['train', 'val'], loc='upper left', fancybox=True)
    # Visibility flag passed positionally: the old ``b=`` keyword was renamed
    # to ``visible`` in Matplotlib 3.5 and has since been removed.
    for which in ('major', 'minor'):
        ax.grid(True, which=which, color='gray', linewidth=.5)
    ax.tick_params(labelsize=10)

    ##################################################
    # Second - plot the custom metrics, one subplot per dataset
    for i, (dataset_name, metrics) in enumerate(
            sorted(custom_metrics_keys.items())):
        axs = self.fig.add_subplot(1, total_plots, i + 2)
        axs.set_title(dataset_name, fontsize=10)
        axs.set_ylabel('score')
        axs.set_xlabel('epoch')
        if self.grid_ranges:
            axs.set_ylim(self.grid_ranges)

        # Append the values to the corresponding history and plot them.
        metric_names = sorted(metrics)
        for m in metric_names:
            entry = ".".join([dataset_name, m])
            self.custom_metrics[entry].append(logs[entry])
            axs.plot(self.custom_metrics[entry], label=m, linewidth=linewidth)
        axs.tick_params(labelsize=10)

        labels = list(metric_names)
        if self.benchmarks:
            # NOTE: zip() with two colors means at most the first two
            # benchmarks are drawn.
            for (label, benchmark), color in zip(self.benchmarks.items(),
                                                 ["y", "r"]):
                axs.axhline(y=benchmark, linewidth=linewidth, color=color)
                labels.append(label)

        axs.legend(labels, loc='upper left', fancybox=True)
        for which in ('major', 'minor'):
            axs.grid(True, which=which, color='gray', linewidth=.5)

    plt.rcParams.update({'font.size': 10})

    desc = get_model_desc(self.model)
    self.fig.text(.02, .02, desc, verticalalignment='bottom', wrap=True,
                  fontsize=8)
    self.fig.tight_layout()
    self.fig.subplots_adjust(bottom=.18)
    self.fig.canvas.draw()
    self.fig.canvas.flush_events()
    self.save_plot()