def __init__(self, vc: VarCollection, momentum: float = 0.999, debias: bool = False, eps: float = 1e-6):
    """Creates ExponentialMovingAverage instance with given hyperparameters.

    Args:
        vc: collection of variables whose values are averaged.
        momentum: the decay factor for the moving average.
        debias: bool indicating whether to use initialization bias correction.
        eps: small adjustment to prevent division by zero.
    """
    self.momentum = momentum
    self.debias = debias
    self.eps = eps
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])

    # Deduplicate variables and skip RandomState vars since they cannot be averaged.
    trainable, non_trainable = {}, {}  # Use dicts since they preserve insertion order (Python >= 3.6).
    for v in vc:
        if isinstance(v, RandomState):
            continue
        if isinstance(v, TrainRef):
            v = v.ref
        if isinstance(v, TrainVar):
            trainable[v] = True
        else:
            non_trainable[v] = True
    self.refs = ModuleList(list(non_trainable.keys()) + [TrainRef(v) for v in trainable.keys()])
    self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.refs)
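# A minimal standalone sketch of the debiased moving average that self.m and
# self.step support, written over plain arrays; `ema_update` and `ema_read`
# are illustrative assumptions, not the library's actual update methods.
def ema_update(avg, value, momentum=0.999):
    """One EMA step: decay the running average toward the new value."""
    return momentum * avg + (1 - momentum) * value

def ema_read(avg, step, momentum=0.999, debias=False, eps=1e-6):
    """Read the average; debias corrects the bias from zero initialization."""
    if debias:
        # At step 0 the denominator would be 0; eps prevents division by zero.
        return avg / (1 - momentum ** step + eps)
    return avg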
def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
    """Constructor for Adam optimizer class.

    Args:
        vc: collection of variables to optimize.
        beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
        beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
        eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
    """
    self.beta1 = beta1
    self.beta2 = beta2
    self.eps = eps
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
    self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
    self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
    self.v = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
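# A minimal standalone sketch of the bias-corrected Adam step that self.m,
# self.v and self.step support, written over plain arrays; `adam_step` is an
# illustrative assumption, not the library's actual update method.
import jax.numpy as jn

def adam_step(param, grad, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for step number t >= 1; returns the new (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad         # first moment (mean of gradients)
    v = beta2 * v + (1 - beta2) * grad * grad  # second moment (uncentered variance)
    m_hat = m / (1 - beta1 ** t)               # debias the zero-initialized moments
    v_hat = v / (1 - beta2 ** t)
    return param - lr * m_hat / (jn.sqrt(v_hat) + eps), m, v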
def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
    """Constructor for momentum optimizer class.

    Args:
        vc: collection of variables to optimize.
        momentum: the momentum hyperparameter.
        nesterov: bool indicating whether to use the Nesterov method.
    """
    self.momentum = momentum
    self.nesterov = nesterov
    self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
    self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
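# A minimal standalone sketch of the (optionally Nesterov) momentum step that
# self.m supports, written over plain arrays; `momentum_step` is an
# illustrative assumption, not the library's actual update method.
def momentum_step(param, grad, m, lr, momentum=0.9, nesterov=False):
    """One heavy-ball update; returns the new (param, m)."""
    m = momentum * m + grad                     # accumulate velocity
    d = grad + momentum * m if nesterov else m  # Nesterov looks one step ahead
    return param - lr * d, m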
def __init__(self, vc: VarCollection):
    """Constructor for SGD optimizer.

    Args:
        vc: collection of variables to optimize.
    """
    self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
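# A minimal standalone sketch of the plain SGD step this stateless constructor
# sets up; `sgd_step` is an illustrative assumption, not the library's API.
def sgd_step(param, grad, lr):
    """Vanilla gradient descent: step against the gradient."""
    return param - lr * grad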
def __init__(self, vc: VarCollection, momentum: float = 0.9, weight_decay: float = 1e-4, tc: float = 1e-3, eps: float = 1e-5):
    """Constructor for LARS optimizer.

    Args:
        vc: collection of variables to optimize.
        momentum: coefficient used for the moving average of the gradient.
        weight_decay: weight decay coefficient.
        tc: trust coefficient eta (< 1) for trust ratio computation.
        eps: epsilon used for trust ratio computation.
    """
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.tc = tc
    self.eps = eps
    self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
    self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
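# A minimal standalone sketch of the layer-wise trust-ratio update from the
# LARS paper (You et al., 2017) that these hyperparameters parameterize;
# `lars_step` is an illustrative assumption, not the library's actual method.
import jax.numpy as jn

def lars_step(param, grad, m, lr, momentum=0.9, weight_decay=1e-4, tc=1e-3, eps=1e-5):
    """One LARS update; returns the new (param, m)."""
    p_norm = jn.linalg.norm(param)
    g_norm = jn.linalg.norm(grad)
    # Trust ratio: scale each layer's step by tc * ||w|| / (||g|| + wd * ||w||).
    trust_ratio = tc * p_norm / (g_norm + weight_decay * p_norm + eps)
    m = momentum * m + lr * trust_ratio * (grad + weight_decay * param)
    return param - m, m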
def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
    """Constructor for an optimizer wrapping a base optimizer.

    Args:
        vc: collection of variables to optimize.
        base_optimizer: callable used to create the wrapped optimizer from vc.
        **kwargs: keyword arguments forwarded to base_optimizer.
    """
    self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
    self.base_optimizer = base_optimizer(vc, **kwargs)
    self.state = defaultdict(dict)
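# A minimal runnable sketch of the factory-callable pattern this constructor
# uses: `base_optimizer` is a class or function invoked with the variable
# collection plus forwarded kwargs. `make_sgd`, `Wrapper` and the toy params
# are illustrative assumptions, not names from the library.
from collections import defaultdict

def make_sgd(params, lr=0.1):
    """Toy factory standing in for an optimizer class such as SGD above."""
    return {'params': list(params), 'lr': lr}

class Wrapper:
    def __init__(self, params, base_optimizer, **kwargs):
        # The wrapper builds the inner optimizer itself and keeps its own
        # per-parameter scratch state, mirroring self.state = defaultdict(dict).
        self.base_optimizer = base_optimizer(params, **kwargs)
        self.state = defaultdict(dict)

opt = Wrapper([1.0, 2.0], base_optimizer=make_sgd, lr=0.01)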