def __init__(self, vc: VarCollection, momentum: float = 0.999, debias: bool = False, eps: float = 1e-6):
    """Creates ExponentialMovingAverage instance with given hyperparameters.

    Args:
        vc: collection of variables to compute moving averages for.
        momentum: the decay factor for the moving average.
        debias: bool indicating whether to use initialization bias correction.
        eps: small adjustment to prevent division by zero.
    """
    self.momentum = momentum
    self.debias = debias
    self.eps = eps
    # Update counter; reduce keeps a single copy when the value is replicated.
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
    # Deduplicate variables and skip RandomState vars since they cannot be averaged.
    # Dicts are used because they preserve insertion order since Python >= 3.6.
    trainable, non_trainable = {}, {}
    for v in vc:
        if isinstance(v, RandomState):
            continue
        if isinstance(v, TrainRef):
            v = v.ref  # Dereference so the underlying TrainVar is deduplicated.
        if isinstance(v, TrainVar):
            trainable[v] = True
        else:
            non_trainable[v] = True
    # Trainable vars are wrapped in TrainRef; iterating a dict yields its keys.
    self.refs = ModuleList(list(non_trainable) + [TrainRef(v) for v in trainable])
    # One moving-average accumulator per tracked variable, initialized to zeros.
    self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.refs)
def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
    """Constructor for Adam optimizer class.

    Args:
        vc: collection of variables to optimize.
        beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
        beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
        eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
    """
    self.beta1 = beta1
    self.beta2 = beta2
    self.eps = eps
    # Update counter; reduce keeps a single copy when the value is replicated.
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
    refs = [TrainRef(var) for var in vc.subset(TrainVar)]
    self.train_vars = ModuleList(refs)
    # First (m) and second (v) moment accumulators, one per trainable variable.
    self.m = ModuleList(StateVar(jn.zeros_like(r.value)) for r in refs)
    self.v = ModuleList(StateVar(jn.zeros_like(r.value)) for r in refs)
def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
    """Constructor for momentum optimizer class.

    Args:
        vc: collection of variables to optimize.
        momentum: the momentum hyperparameter.
        nesterov: bool indicating whether to use the Nesterov method.
    """
    self.momentum = momentum
    self.nesterov = nesterov
    refs = [TrainRef(var) for var in vc.subset(TrainVar)]
    self.train_vars = ModuleList(refs)
    # One momentum accumulator per trainable variable, initialized to zeros.
    self.m = ModuleList(StateVar(jn.zeros_like(r.value)) for r in refs)
def __init__(self, vc: VarCollection):
    """Constructor for SGD optimizer.

    Args:
        vc: collection of variables to optimize.
    """
    # Wrap each trainable variable from the collection in a TrainRef.
    self.train_vars = ModuleList(map(TrainRef, vc.subset(TrainVar)))
def __init__(self, vc: VarCollection, momentum: float = 0.9, weight_decay: float = 1e-4,
             tc: float = 1e-3, eps: float = 1e-5):
    """Constructor for LARS optimizer.

    Args:
        vc: collection of variables to optimize.
        momentum: coefficient used for the moving average of the gradient.
        weight_decay: weight decay coefficient.
        tc: trust coefficient eta ( < 1) for trust ratio computation.
        eps: epsilon used for trust ratio computation.
    """
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.tc = tc
    self.eps = eps
    refs = [TrainRef(var) for var in vc.subset(TrainVar)]
    self.train_vars = ModuleList(refs)
    # One gradient moving-average accumulator per trainable variable.
    self.m = ModuleList(StateVar(jn.zeros_like(r.value)) for r in refs)
class ExponentialMovingAverage(Module):
    """Maintains exponential moving averages for each variable from provided VarCollection."""

    def __init__(self, vc: VarCollection, momentum: float = 0.999, debias: bool = False, eps: float = 1e-6):
        """Creates ExponentialMovingAverage instance with given hyperparameters.

        Args:
            vc: collection of variables to compute moving averages for.
            momentum: the decay factor for the moving average.
            debias: bool indicating whether to use initialization bias correction.
            eps: small adjustment to prevent division by zero.
        """
        self.momentum = momentum
        self.debias = debias
        self.eps = eps
        # Update counter; reduce keeps a single copy when the value is replicated.
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        # Deduplicate variables and skip RandomState vars since they cannot be averaged.
        # Dicts are used because they preserve insertion order since Python >= 3.6.
        trainable, non_trainable = {}, {}
        for v in vc:
            if isinstance(v, RandomState):
                continue
            if isinstance(v, TrainRef):
                v = v.ref  # Dereference so the underlying TrainVar is deduplicated.
            if isinstance(v, TrainVar):
                trainable[v] = True
            else:
                non_trainable[v] = True
        # Trainable vars are wrapped in TrainRef; iterating a dict yields its keys.
        self.refs = ModuleList(list(non_trainable) + [TrainRef(v) for v in trainable])
        # One moving-average accumulator per tracked variable, initialized to zeros.
        self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.refs)

    def __call__(self):
        """Updates the moving average with the current value of each tracked variable."""
        self.step.value += 1
        for ref, m in zip(self.refs, self.m):
            # Equivalent to m = momentum * m + (1 - momentum) * ref.
            m.value += (1 - self.momentum) * (ref.value - m.value)

    def refs_and_values(self) -> Tuple[VarCollection, List[JaxArray]]:
        """Returns the VarCollection of variables affected by Exponential Moving Average (EMA)
        and their corresponding EMA values."""
        if self.debias:
            step = self.step.value
            # Correct the bias toward the zero initialization of the averages;
            # eps keeps the denominator non-zero when step == 0.
            debias = 1 / (1 - (1 - self.eps) * self.momentum ** step)
            # Fixed: the original zipped self.refs here but never used the ref.
            tensors = [m.value * debias for m in self.m]
        else:
            tensors = self.m.vars().tensors()
        return self.refs.vars(), tensors

    def replace_vars(self, f: Callable):
        """Returns a function that acts as f called when variables are replaced by their averages.

        Args:
            f: function to be called on the stored averages.

        Returns:
            A function that returns the output of calling f with stored variables replaced by
            their moving averages.
        """

        def wrap(*args, **kwargs):
            refs, new_values = self.refs_and_values()
            original_values = refs.tensors()
            refs.assign(new_values)
            # Fixed: restore the original values even if f raises, so an exception
            # inside f no longer leaves the averaged values assigned to the variables.
            try:
                return f(*args, **kwargs)
            finally:
                refs.assign(original_values)

        return wrap
def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
    """Builds the wrapped base optimizer over the given variable collection.

    Args:
        vc: collection of variables to optimize.
        base_optimizer: callable that constructs the underlying optimizer from vc.
        **kwargs: extra keyword arguments forwarded to base_optimizer.
    """
    self.train_vars = ModuleList(TrainRef(v) for v in vc.subset(TrainVar))
    self.base_optimizer = base_optimizer(vc, **kwargs)
    # Per-key auxiliary state; missing entries default to an empty dict on first access.
    self.state = defaultdict(dict)