def __init__(
    self,
    learn: Learner,
    do_remove: bool = True,
    hMin=-1,
    hMax=1,
    nBins=100,
    useClasses=False,  # if true compute histogram of classes in the last layer
    liveChart=True,  # show live chart of last layer
    modulesId=-1,  # index or list of indices of the modules to keep
):
    """Configure histogram ranges and hook the selected modules of `learn.model`.

    Args:
        learn: learner whose `model` is flattened with `flatten_model` to get the
            candidate modules, and whose `data.c` is stored as `self.c`.
        do_remove: forwarded unchanged to the parent constructor.
        hMin: lower bound of the histogram range. `None` means `-hMax`; any
            other value, including 0, is used as given.
        hMax: upper bound of the histogram range; must be greater than the
            resolved `hMin`.
        nBins: number of histogram bins.
        useClasses: if true, compute the histogram of classes in the last layer.
        liveChart: if true, show a live chart of the last layer.
        modulesId: index or list of indices into the flattened module list
            selecting the modules to hook (negative indices count from the
            end). `None` or an empty list keeps every module.

    Raises:
        ValueError: if `hMax` is not strictly greater than the resolved `hMin`.
    """
    # Explicit `is None` test: the former `hMin or -hMax` replaced a legitimate
    # hMin=0 with -hMax.
    self.hMin = -hMax if hMin is None else hMin
    self.hMax = hMax
    if self.hMax <= self.hMin:
        raise ValueError(
            f"hMax ({self.hMax}) must be greater than hMin ({self.hMin})")
    self.nBins = nBins
    # Index of the bin in which the value 0 falls within [hMin, hMax].
    self.zero_bin = math.floor(-self.nBins * self.hMin / (self.hMax - self.hMin))
    self.liveChart = liveChart
    self.useClasses = useClasses
    allModules = list(flatten_model(learn.model))
    # Explicit `is None` test so that modulesId=0 selects the first module
    # instead of being mistaken for "no selection".
    ids = [] if modulesId is None else listify(modulesId)
    modules = [allModules[i] for i in ids] if ids else allModules
    # `allModules` is narrowed to exactly the modules that get hooked.
    self.allModules = modules
    self.c = learn.data.c  # number of classes
    super().__init__(learn, modules, do_remove)
def apply_tfms(self, tfms, duration: int = None, size_factor: tuple = (1, 1),
               do_resolve: bool = True, padding_mode: str = 'reflection'):
    """Apply `tfms` in order to a clone of this item and optionally rescale it.

    Args:
        tfms: a transform or list of transforms (made listy with `listify`).
        duration: passed to `self._get_duration_crop_target` to compute the
            `size` given to crop transforms (those whose `.tfm` is a `TfmCrop`).
        size_factor: multipliers for every dimension of the result except the
            first; a falsy value skips resizing. Each scaled dimension is
            rounded to the nearest integer.
        do_resolve: if true, call `resolve()` on every transform first.
        padding_mode: forwarded to crop transforms only.

    Returns:
        The transformed (and possibly resized) clone.
    """
    tfms = listify(tfms)
    if do_resolve:
        for tfm in tfms:
            tfm.resolve()
    x = self.clone()
    for tfm in tfms:
        # Test the wrapped transform directly instead of list membership,
        # which relied on `==` between transform objects.
        if isinstance(tfm.tfm, TfmCrop):
            crop_target = self._get_duration_crop_target(duration)
            x = tfm(x, size=crop_target, padding_mode=padding_mode)
        else:
            x = tfm(x)
    # Resizing is deliberately separate from cropping.
    if size_factor:
        # NOTE(review): the first dimension is left unscaled — presumably the
        # channel axis; confirm.
        _, *orig_size = x.shape
        orig_size = tuple(orig_size)
        # Round to integers: fractional factors (e.g. 0.5) previously produced
        # float sizes, which are not valid sizes for a resize.
        scaled = np.rint(np.asarray(orig_size) * np.asarray(size_factor)).astype(int)
        new_size = tuple(int(s) for s in scaled)
        if new_size != orig_size:
            # NOTE(review): the return value is ignored, so this assumes
            # `x.resize` works in place — confirm.
            x.resize(new_size)
            # TODO: a disabled line here rescaled `x.config._sr` by
            # new_size/orig_size; decide whether sample-rate metadata must follow.
    return x
def plotActsHist(self, cols=10, toDisplay=None, hScale=.05, showEpochs=False,
                 showLayerInfo=True, aspectAuto=True, showImage=True):
    """Plot the recorded activation histograms, one subplot per layer, on a grid.

    Args:
        cols: maximum number of grid columns; a falsy value means 3.
        toDisplay: index or list of indices of the layers to plot. `None` (or an
            empty list) plots every layer; index 0 is honoured.
        hScale: forwarded to `self.plotPerc`.
        showEpochs: if true, draw a blue vertical line at each entry of
            `stats_epoch` except the last, and overlay `plotPerc` curves for
            each slice of the histogram between consecutive entries.
        showLayerInfo: if true, print the index and module of every entry of
            `self.allModules` after plotting.
        aspectAuto: if true, set each subplot's aspect to 'auto'.
        showImage: if true, draw the histogram image returned by `getHistImg`.
    """
    histsTensor = self.activations_histogram.stats_hist
    hists = [histsTensor[i] for i in range(histsTensor.shape[0])]
    # Explicit `is None` test so that toDisplay=0 focuses on layer 0; listify
    # is evaluated once instead of on every loop iteration.
    layerIds = [] if toDisplay is None else listify(toDisplay)
    if layerIds:
        hists = [hists[i] for i in layerIds]  # optionally focus
    else:
        layerIds = list(range(len(hists)))

    n = len(hists)
    # With nothing recorded/selected the grid computation would divide by zero.
    if n:
        cols = min(cols or 3, n)
        rows = math.ceil(n / cols)
        fig = plt.figure(figsize=(20, rows * 4.5))
        grid = plt.GridSpec(rows, cols, figure=fig, left=None, bottom=None,
                            right=None, top=None, wspace=.25, hspace=.25)
        # Relative height of the value 0 inside [hMin, hMax]; same for all layers.
        ratioH = -self.hMin / (self.hMax - self.hMin)

        for i, (layerId, l) in enumerate(zip(layerIds, hists)):
            img = self.getHistImg(l, self.useClasses)
            dead = self.getMin(l, self.useClasses, self.zero_bin)
            main_ax = fig.add_subplot(grid[i // cols, i % cols])
            if showImage:
                main_ax.imshow(img)
            m = self.allModules[layerId]
            outShapeText = f' (out: {list(self.shape_out[m])})' if m in self.shape_out else ''
            title = f'L:{layerId}' + '\n' + splitAtFirstParenthesis(
                str(m), False, outShapeText)
            main_ax.set_title(title, fontsize=8, weight='bold')
            main_ax.set_yticks([])
            main_ax.set_ylabel(str(self.hMin) + " : " + str(self.hMax))
            if aspectAuto:
                main_ax.set_aspect('auto')
            imgH, imgW = img.shape[0], img.shape[1]
            zeroPosH = imgH * ratioH
            # Red curve: output of `getMin`, scaled by `l.shape[1]`
            # (presumably the share of "dead" activations — confirm).
            main_ax.plot(dead * l.shape[1], 'r', linewidth=2)
            main_ax.plot([0, imgW], [zeroPosH, zeroPosH], 'black')  # X Axis
            if showEpochs:
                epochEnds = self.activations_histogram.stats_epoch
                nEpochs = len(epochEnds)
                start = 0
                # Distinct loop variable: the former inner `i` shadowed the
                # subplot index of the outer loop.
                for epochIdx, end in enumerate(epochEnds):
                    if epochIdx < nEpochs - 1:
                        main_ax.plot([end, end], [0, imgH], color=[0, 0, 1])
                    domain = l[start:end]  # rolling window [start, end)
                    domain_mean = domain.mean(-1)  # mean on classes
                    if self.useClasses:
                        self.plotPerc(main_ax, domain, hScale, 1, start,
                                      colorById=True, addLabel=(epochIdx == 0))  # plot all
                        main_ax.legend(loc='upper left')
                    else:
                        self.plotPerc(main_ax, domain_mean, hScale, .5, start)
                    self.plotPerc(main_ax, domain_mean, hScale, 1, start, linewidth=1.5)
                    start = end
            main_ax.set_xlim([0, imgW])
            main_ax.set_ylim([0, imgH])
        plt.show()

    if showLayerInfo:
        for idx, module in enumerate(self.allModules):
            print('{:2} {}'.format(idx, module))