def __init__(self, vc: VarCollection):
    """Constructor for SGD optimizer.

    Args:
        vc: collection of variables to optimize.
    """
    # Store references (not copies) to the trainable variables so that
    # updates are applied in place to the originals.
    self.train_vars = ModuleList([TrainRef(tv) for tv in vc.subset(TrainVar)])
def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
    """Constructor for Adam optimizer class.

    Args:
        vc: collection of variables to optimize.
        beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
        beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
        eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
    """
    self.beta1 = beta1
    self.beta2 = beta2
    self.eps = eps
    # Global step counter; the reduce function keeps a single entry
    # (presumably collapsing per-replica copies — confirm with caller).
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda s: s[0])
    # References to the trainable variables, plus zero-initialized first
    # (m) and second (v) moment accumulators, one per variable.
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
    self.v = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
    """Constructor for momentum optimizer class.

    Args:
        vc: collection of variables to optimize.
        momentum: the momentum hyperparameter.
        nesterov: bool indicating whether to use the Nesterov method.
    """
    self.momentum = momentum
    self.nesterov = nesterov
    # References to the trainable variables and one zero-initialized
    # velocity buffer per variable.
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
def __init__(self, vc: VarCollection, momentum: float = 0.9, weight_decay: float = 1e-4, tc: float = 1e-3,
             eps: float = 1e-5):
    """Constructor for LARS optimizer.

    Args:
        vc: collection of variables to optimize.
        momentum: coefficient used for the moving average of the gradient.
        weight_decay: weight decay coefficient.
        tc: trust coefficient eta ( < 1) for trust ratio computation.
        eps: epsilon used for trust ratio computation.
    """
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.tc = tc
    self.eps = eps
    # References to the trainable variables and one zero-initialized
    # momentum buffer per variable.
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars)
def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
    """Constructor for an optimizer that wraps a base optimizer.

    Args:
        vc: collection of variables to optimize.
        base_optimizer: callable (e.g. an optimizer class) invoked as
            ``base_optimizer(vc, **kwargs)`` to build the wrapped optimizer.
        **kwargs: extra keyword arguments forwarded to ``base_optimizer``.
    """
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    self.base_optimizer = base_optimizer(vc, **kwargs)
    # Scratch state; each missing key lazily maps to a fresh dict
    # (presumably keyed per variable — confirm with the update code).
    self.state = defaultdict(dict)