Example #1
def plot_sched(self, keys=None, figsize=None):
    assert hasattr(self, 'hps'), '[ERROR] You must run ParamSchedCallback to plot the sched.'
    keys = self.hps.keys() if keys is None else listify(keys)
    # Lay the plots out on a grid with at most two columns.
    rows, cols = (len(keys) + 1) // 2, min(2, len(keys))
    figsize = figsize or (6 * cols, 4 * rows)
    _, axs = plt.subplots(rows, cols, figsize=figsize)
    axs = axs.flatten() if len(keys) > 1 else listify(axs)
    # Plot each recorded hyper-parameter schedule on its own axis.
    for p, ax in zip(keys, axs):
        ax.plot(self.hps[p])
        ax.set_ylabel(p)
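
A hedged usage sketch: plot_sched only works once self.hps has been populated by the scheduling callback during a fit, so it is typically called after training. The learn name, the assumption that plot_sched lives on the scheduling callback, and the combined_cos arguments below follow the later examples and are illustrative, not part of this snippet.

    # Hypothetical usage: record an explicit schedule during a fit, then plot it.
    scheds = {'lr': combined_cos(0.25, 1e-4, 1e-3, 1e-5)}
    sched_cb = ParamScheduler(scheds)
    learn.fit(3, cbs=[sched_cb])
    sched_cb.plot_sched()            # plot every recorded hyper-parameter
    sched_cb.plot_sched(keys='lr')   # plot only the learning-rate curve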
Example #2
def fit_one_cycle(self, epochs, lr_max=None, div=25., div_final=1e5, pct_start=0.25, wd=None,
                  moms=None, cbs=None, reset_opt=False):
    # Fall back to the learner's default learning rate when lr_max is not given.
    lr = self.lr if lr_max is None else lr_max
    if reset_opt or not self.opt:
        self.create_opt()
    set_hyper(self.opt, 'lr', lr)
    # Read one lr_max back per parameter group, so discriminative LRs are scheduled correctly.
    lr_max = np.array([p['lr'] for p in self.opt.param_groups])
    # Warm up from lr_max/div to lr_max over the first pct_start of training,
    # then anneal down to lr_max/div_final; momentum follows the inverse pattern.
    scheds = {
        'lr': combined_cos(pct_start, lr_max / div, lr_max, lr_max / div_final),
        'momentum': combined_cos(pct_start, *(self.moms if moms is None else moms))
    }
    self.fit(epochs, cbs=[ParamScheduler(scheds)] + listify(cbs), reset_opt=reset_opt, wd=wd)
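
A minimal usage sketch, assuming a Learner instance named learn built as in Example #4 below; the argument values are illustrative only.

    # Hypothetical usage: five epochs of one-cycle training with a peak LR of 1e-3.
    learn.fit_one_cycle(5, lr_max=1e-3)

    # Override the momentum schedule and weight decay for a single run.
    learn.fit_one_cycle(5, lr_max=1e-3, moms=(0.9, 0.8, 0.9), wd=1e-2)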
Example #3
def set_hyper(opt, key, val):
    # A slice spreads the value across parameter groups (discriminative values).
    if isinstance(val, slice):
        if val.start:
            val = even_mults(val.start, val.stop, len(opt.param_groups))
        else:
            # No start given: the last group gets val.stop, earlier groups get val.stop/10.
            val = [val.stop / 10] * (len(opt.param_groups) - 1) + [val.stop]
    vs = listify(val)
    if len(vs) == 1:
        vs = vs * len(opt.param_groups)
    assert len(vs) == len(opt.param_groups), (
        f"Trying to set {len(vs)} values for {key} "
        f"but there are {len(opt.param_groups)} parameter groups.")
    for v, p in zip(vs, opt.param_groups):
        p[key] = v
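
To make the slice handling concrete, here is a small, hedged demo with a throwaway two-group PyTorch optimizer; it assumes set_hyper, even_mults and listify from these examples are importable, and the printed values are indicative only.

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    opt = torch.optim.SGD([{'params': model[0].parameters()},
                           {'params': model[1].parameters()}], lr=0.1)

    # slice(start, stop): even_mults spreads values from 1e-4 up to 1e-2 across the groups.
    set_hyper(opt, 'lr', slice(1e-4, 1e-2))
    print([g['lr'] for g in opt.param_groups])   # e.g. [0.0001, 0.01]

    # slice(stop) only: the last group gets 1e-2, earlier groups get 1e-2 / 10.
    set_hyper(opt, 'lr', slice(1e-2))
    print([g['lr'] for g in opt.param_groups])   # [0.001, 0.01]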
Example #4
    def __init__(self, model, data, loss_func,
                 opt_func, lr=1e-2, wd=None, moms=(0.95, 0.85, 0.95),
                 metrics=None, cbs=None, cb_funcs=None,
                 splitter=get_trainable_params,
                 path=Path('.'), model_dir='models'):
        # Training components and hyper-parameter defaults.
        self.model, self.data, self.loss_func = model, data, loss_func
        self.opt_func, self.lr, self.metrics = opt_func, lr, metrics
        self.splitter = splitter
        self.wd, self.moms = wd, moms
        self.path = path
        self.model_dir = self.path / model_dir

        # Mutable training state; the optimizer is created lazily via create_opt.
        self.opt = None
        self.cbs = []
        self.in_train = False
        self.epoch = 0
        self.epochs = 1
        self.loss = tensor(0.)
        self.logger = print

        # Register default callbacks first, then user callbacks and callback factories.
        self.add_cbs([cb() for cb in self._default_cbs])
        self.add_cbs(cbs)
        self.add_cbs([cbf() for cbf in listify(cb_funcs)])
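
A hedged construction sketch: the model, data, optimizer-factory and metric names below are placeholders for whatever the surrounding codebase provides, not part of the original code.

    # Hypothetical setup: wire concrete components into a Learner.
    learn = Learner(model=my_model, data=my_data,
                    loss_func=torch.nn.functional.cross_entropy,
                    opt_func=my_opt_func, lr=1e-2, metrics=[my_accuracy])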
Example #5
def remove_cbs(self, cbs):
    # Detach one callback or a list of callbacks from the learner.
    for cb in listify(cbs): self.cbs.remove(cb)
Example #6
def add_cbs(self, cbs):
    # Attach one callback or a list of callbacks; add_cb handles the per-callback wiring.
    for cb in listify(cbs): self.add_cb(cb)
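
Examples #5 and #6 pair up naturally; here is a hedged sketch of attaching a callback for one experiment and detaching it afterwards (the MyLoggingCallback name is hypothetical).

    # Hypothetical usage: add a callback, train, then remove it again.
    cb = MyLoggingCallback()
    learn.add_cbs(cb)          # listify() accepts a single callback or a list
    learn.fit(1)
    learn.remove_cbs(cb)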
Example #7
def __init__(self, metrics, in_train):
    # Track a list of metrics; in_train flags whether these stats cover training or validation.
    self.metrics, self.in_train = listify(metrics), in_train
    self.reset()