def plot_sched(self, keys=None, figsize=None):
    """Plot the values stored in `self.hps`, one subplot per key.

    Args:
        keys: Key or keys of `self.hps` to plot; when None, every key of
            `self.hps` is plotted.
        figsize: Figure size passed to `plt.subplots`; when falsy it
            defaults to 6 units wide per column and 4 units tall per row.

    Raises:
        AssertionError: If `self` has no `hps` attribute.
    """
    assert hasattr(self, 'hps'), '[ERROR] You must run ParamSchedCallback to plot the sched.'

    if keys is None:
        keys = self.hps.keys()
    else:
        keys = listify(keys)

    # Grid layout: at most two columns, with as many rows as needed.
    n_keys = len(keys)
    rows = (n_keys + 1) // 2
    cols = min(2, n_keys)

    if not figsize:
        figsize = (6 * cols, 4 * rows)

    _, axs = plt.subplots(rows, cols, figsize=figsize)

    # With a single key `plt.subplots` hands back one axes object rather
    # than an array, so it is wrapped instead of flattened.
    if n_keys > 1:
        axs = axs.flatten()
    else:
        axs = listify(axs)

    for name, ax in zip(keys, axs):
        ax.plot(self.hps[name])
        ax.set_ylabel(name)
def fit_one_cycle(self, epochs, lr_max=None, div=25., div_final=1e5, pct_start=0.25,
                  wd=None, moms=None, cbs=None, reset_opt=False):
    """Run `self.fit` with 'lr' and 'momentum' schedules built by `combined_cos`.

    Args:
        epochs: Forwarded to `self.fit`.
        lr_max: Learning rate written to the optimizer before the schedule
            is built; `self.lr` is used when None.
        div: The first 'lr' schedule value is `lr_max / div`.
        div_final: The last 'lr' schedule value is `lr_max / div_final`.
        pct_start: First argument given to `combined_cos` for both schedules.
        wd: Forwarded to `self.fit`.
        moms: Values unpacked into the momentum `combined_cos` call;
            `self.moms` is used when None.
        cbs: Extra callbacks appended after the `ParamScheduler`.
        reset_opt: When true, a new optimizer is created here; the flag is
            also forwarded to `self.fit`.
    """
    peak_lr = self.lr if lr_max is None else lr_max

    if reset_opt or not self.opt:
        self.create_opt()
    set_hyper(self.opt, 'lr', peak_lr)

    # Read the rates back from the optimizer so there is one entry per
    # parameter group, whatever form `peak_lr` was given in.
    group_lrs = np.array([group['lr'] for group in self.opt.param_groups])

    lr_sched = combined_cos(pct_start, group_lrs / div, group_lrs, group_lrs / div_final)
    mom_values = self.moms if moms is None else moms
    mom_sched = combined_cos(pct_start, *mom_values)
    scheds = {'lr': lr_sched, 'momentum': mom_sched}

    callbacks = [ParamScheduler(scheds)] + listify(cbs)
    self.fit(epochs, cbs=callbacks, reset_opt=reset_opt, wd=wd)
def set_hyper(opt, key, val):
    """Write hyper-parameter `key` into every parameter group of `opt`.

    Args:
        opt: Object exposing `param_groups`, a sequence of dict-like groups.
        key: Name of the hyper-parameter to set in each group.
        val: A single value (applied to every group), a sequence with one
            value per group, or a `slice`:
            - `slice(start, stop)` with a truthy `start` is expanded with
              `even_mults(start, stop, n_groups)`.
            - a slice without a truthy `start` gives `stop / 10` to every
              group except the last, which gets `stop`.

    Raises:
        AssertionError: If the number of values does not match the number
            of parameter groups.
    """
    n_groups = len(opt.param_groups)

    if isinstance(val, slice):
        if val.start:
            val = even_mults(val.start, val.stop, n_groups)
        else:
            # The final value must be wrapped in a list: concatenating a
            # list with a bare scalar raises TypeError.
            val = [val.stop / 10] * (n_groups - 1) + [val.stop]

    vs = listify(val)
    if len(vs) == 1:
        vs = vs * n_groups

    # The message reports `key`; the previous text referenced an undefined
    # name, which turned the intended AssertionError into a NameError.
    assert len(vs) == n_groups, (
        f"Trying to set {len(vs)} values for {key} but there are {n_groups} parameter groups."
    )

    for v, group in zip(vs, opt.param_groups):
        group[key] = v
def __init__(self, model, data, loss_func, opt_func, lr=1e-2, wd=None, moms=(0.95, 0.85, 0.95),
             metrics=None, cbs=None, cb_funcs=None, splitter=get_trainable_params,
             path=Path('.'), model_dir='models'):
    """Store the training configuration and register the initial callbacks.

    Args:
        model, data, loss_func, opt_func: Stored as same-named attributes.
        lr, wd, moms: Default hyper-parameters, stored as same-named attributes.
        metrics: Stored as `self.metrics`.
        cbs: Callback instance(s) handed to `self.add_cbs`.
        cb_funcs: Callable(s) invoked with no arguments; each result is
            registered as a callback.
        splitter: Stored as `self.splitter`.
        path: Base path; `self.model_dir` is `path / model_dir`.
        model_dir: Name joined onto `path` to form `self.model_dir`.
    """
    # Objects and settings supplied by the caller.
    self.model = model
    self.data = data
    self.loss_func = loss_func
    self.opt_func = opt_func
    self.lr = lr
    self.metrics = metrics
    self.splitter = splitter
    self.wd = wd
    self.moms = moms
    self.path = path
    self.model_dir = self.path / model_dir

    # Initial run state; the optimizer is not created here.
    self.opt = None
    self.cbs = []
    self.in_train = False
    self.epoch = 0
    self.epochs = 1
    self.loss = tensor(0.)
    self.logger = print

    # Registration order: class defaults, then instances, then factories.
    default_callbacks = [cb() for cb in self._default_cbs]
    self.add_cbs(default_callbacks)
    self.add_cbs(cbs)
    factory_callbacks = [cbf() for cbf in listify(cb_funcs)]
    self.add_cbs(factory_callbacks)
def remove_cbs(self, cbs):
    """Remove each callback in `cbs` from `self.cbs`.

    Args:
        cbs: A callback or collection of callbacks; normalised with `listify`.

    Raises:
        ValueError: If a callback is not present in `self.cbs`
            (raised by `list.remove`).
    """
    # Iterate over a snapshot: if `listify` returns the caller's list as-is
    # and that list is `self.cbs` itself (e.g. `remove_cbs(self.cbs)`),
    # removing while iterating would skip every other callback.
    for cb in list(listify(cbs)):
        self.cbs.remove(cb)
def add_cbs(self, cbs):
    """Register every callback in `cbs` by calling `self.add_cb` on each.

    Args:
        cbs: A callback or collection of callbacks; normalised with `listify`.
    """
    pending = listify(cbs)
    for callback in pending:
        self.add_cb(callback)
def __init__(self, metrics, in_train):
    """Store the metrics and the training flag, then call `self.reset()`.

    Args:
        metrics: Metric or collection of metrics; normalised with `listify`.
        in_train: Stored unchanged as `self.in_train`.
    """
    self.metrics = listify(metrics)
    self.in_train = in_train
    self.reset()