예제 #1
0
 def __init__(self, params, steppers, **defaults):
     """Set up the steppers, parameter groups and per-group hyper-parameters.

     Args:
         params: parameters to optimize (may be a generator). If its first
             element is a list, it is treated as a sequence of parameter
             groups; otherwise all parameters form a single group.
         steppers: stepper function(s), normalized with ``listify``.
         **defaults: hyper-parameter values. Any ``_defaults`` declared by
             the steppers that are missing here are added by
             ``maybe_update``.

     Raises:
         ValueError: if *params* contains no elements.
     """
     self.steppers = listify(steppers)
     # Add each stepper's '_defaults' entries that the caller did not pass.
     maybe_update(self.steppers, defaults, get_defaults)
     # Materialize params: it might be a one-shot generator.
     self.param_groups = list(params)
     if not self.param_groups:
         # Previously this surfaced as an opaque IndexError on [0] below.
         raise ValueError("optimizer got an empty parameter list")
     # A flat sequence of parameters becomes a single parameter group.
     if not isinstance(self.param_groups[0], list):
         self.param_groups = [self.param_groups]
     # A separate (shallow) copy of the hyper-parameters for every group.
     self.hypers = [dict(defaults) for _ in self.param_groups]
예제 #2
0
def adam_opt(xtra_step=None, **kwargs):
    """Return a factory for an Adam-style ``StatefulOptimizer``.

    The returned ``partial`` pre-binds these steppers, in order:
    ``adam_step``, ``weight_decay``, then any *xtra_step* entries.
    It also pre-binds these state updaters: ``AverageGrad(dampening=True)``,
    ``AverageSqrGrad()`` and ``StepCount()``. Calling the factory with
    the parameters creates the optimizer.

    Args:
        xtra_step: optional extra stepper(s), normalized with ``listify``.
        **kwargs: forwarded unchanged to ``StatefulOptimizer``.
    """
    step_fns = [adam_step, weight_decay] + listify(xtra_step)
    state_fns = [AverageGrad(dampening=True), AverageSqrGrad(), StepCount()]
    return partial(
        StatefulOptimizer,
        steppers=step_fns,
        stateupdaters=state_fns,
        **kwargs,
    )
예제 #3
0
def compose(x, funcs, *args, order_key="_order", **kwargs):
    """Thread *x* through *funcs*, applied in ascending order of their order key.

    The functions are sorted by ``getattr(func, order_key, 0)``, so functions
    without that attribute sort as 0. The sort is stable: functions with equal
    keys keep their original relative order. Each function receives the
    previous result as its first argument, followed by ``*args`` and
    ``**kwargs``.

    Args:
        x: initial value passed to the first function.
        funcs: function(s) to apply, normalized with ``listify``.
        *args: extra positional arguments forwarded to every function.
        order_key: name of the attribute used as the sort key.
        **kwargs: extra keyword arguments forwarded to every function.

    Returns:
        The result of the last function applied, or *x* itself if no
        functions are applied.
    """
    def _order(func):
        # Functions lacking the attribute default to order 0.
        return getattr(func, order_key, 0)

    for func in sorted(listify(funcs), key=_order):
        # *args used to be accepted but silently discarded; forward it so
        # the signature behaves as advertised.
        x = func(x, *args, **kwargs)
    return x
예제 #4
0
    def __init__(self,
                 model,
                 data,
                 loss_func,
                 opt_func,
                 lr=1e-2,
                 cbs=None,
                 cb_funcs=None):
        """Store the training components and register the callbacks.

        Args:
            model: the model to train.
            data: the data used for training.
            loss_func: the loss function.
            opt_func: stored as ``self.opt_func``; presumably used later to
                build ``self.opt``, which starts as ``None``.
            lr: learning rate, ``1e-2`` by default.
            cbs: callback instance(s), registered via ``add_cbs``.
            cb_funcs: zero-argument callable(s); each is called and the
                resulting callback is registered via ``add_cbs``, after *cbs*.
        """
        # Core training components.
        self.model, self.data = model, data
        self.loss_func, self.opt_func = loss_func, opt_func
        self.lr = lr
        # Runtime state; logging goes through the built-in print by default.
        self.in_train, self.logger, self.opt = False, print, None

        # Callbacks are registered here directly, which avoids a separate
        # set_runner step.
        self.cbs = []
        self.add_cbs(cbs)
        self.add_cbs(make_cb() for make_cb in listify(cb_funcs))
예제 #5
0
파일: stats.py 프로젝트: uchange/ulangel
 def __init__(self, metrics, in_train):
     """Store the metric function(s) and the training-phase flag.

     Args:
         metrics: metric function(s), normalized to a list via ``listify``.
         in_train: stored as ``self.in_train``; presumably ``True`` when
             these stats track the training phase — confirm with callers.
     """
     self.metrics, self.in_train = listify(metrics), in_train
예제 #6
0
 def remove_cbs(self, cbs):
     """Remove each given callback from ``self.cbs``.

     Args:
         cbs: callback(s) to remove, normalized with ``listify``.

     Raises:
         ValueError: from ``list.remove`` (assuming ``self.cbs`` is a list)
             when a callback is not present.
     """
     # NOTE(review): only self.cbs is updated here; if add_cb stores the
     # callback anywhere else, that reference is not cleared — confirm.
     to_remove = listify(cbs)
     for callback in to_remove:
         self.cbs.remove(callback)
예제 #7
0
 def add_cbs(self, cbs):
     """Register each given callback by delegating to ``self.add_cb``.

     Args:
         cbs: callback(s) to register, normalized with ``listify``.
     """
     pending = listify(cbs)
     for callback in pending:
         self.add_cb(callback)
예제 #8
0
 def __init__(self, params, steppers, stateupdaters=None, **defaults):
     """Initialise an optimizer that also owns state updaters and a state dict.

     Args:
         params: parameters or parameter groups, passed to the base class.
         steppers: stepper function(s), passed to the base class.
         stateupdaters: optional state-updater object(s), normalized with
             ``listify``.
         **defaults: hyper-parameters. ``maybe_update`` may extend them with
             the state updaters' defaults before they reach the base class.
     """
     updaters = listify(stateupdaters)
     self.stateupdaters = updaters
     # Merge the state updaters' own defaults into the hyper-parameters
     # before calling the base-class constructor.
     maybe_update(updaters, defaults, get_defaults)
     super().__init__(params, steppers, **defaults)
     # Optimizer state container; presumably filled in during stepping.
     self.state = dict()