Example #1
0
 def __init__(self,
              vc: VarCollection,
              momentum: float = 0.999,
              debias: bool = False,
              eps: float = 1e-6):
     """Creates ExponentialMovingAverage instance with given hyperparameters.

     Args:
         vc: collection of variables whose moving averages are tracked.
         momentum: the decay factor for the moving average.
         debias: bool indicating whether to use initialization bias correction.
         eps: small adjustment to prevent division by zero.
     """
     self.momentum = momentum
     self.debias = debias
     self.eps = eps
     self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
     # Partition variables into trainable / non-trainable buckets, skipping
     # RandomState vars since they cannot be averaged. Plain dicts double as
     # ordered sets (insertion-ordered since python >= 3.6), deduplicating
     # variables that appear more than once in the collection.
     trainable = {}
     non_trainable = {}
     for var in vc:
         if isinstance(var, RandomState):
             continue
         if isinstance(var, TrainRef):
             var = var.ref
         bucket = trainable if isinstance(var, TrainVar) else non_trainable
         bucket[var] = True
     # Non-trainable vars first, then trainable ones wrapped in TrainRef.
     self.refs = ModuleList(
         list(non_trainable) + [TrainRef(v) for v in trainable])
     # One zero-initialized accumulator per tracked variable.
     self.m = ModuleList(
         StateVar(jn.zeros_like(x.value)) for x in self.refs)
Example #2
0
    def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        """Constructor for Adam optimizer class.

        Args:
            vc: collection of variables to optimize.
            beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
            beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
            eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
        """
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        # Update counter; the reduce function keeps a single copy (the first
        # element) when the state is aggregated.
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
        # First (m) and second (v) moment accumulators, zero-initialized with
        # the same shape as each trainable variable.
        self.m = ModuleList([StateVar(jn.zeros_like(r.value)) for r in self.train_vars])
        self.v = ModuleList([StateVar(jn.zeros_like(r.value)) for r in self.train_vars])
Example #3
0
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 nesterov: bool = False):
        """Constructor for momentum optimizer class.

        Args:
            vc: collection of variables to optimize.
            momentum: the momentum hyperparameter.
            nesterov: bool indicating whether to use the Nesterov method.
        """
        self.momentum = momentum
        self.nesterov = nesterov
        # References to the trainable subset of the collection.
        self.train_vars = ModuleList(TrainRef(v) for v in vc.subset(TrainVar))
        # Velocity accumulator, one zero-initialized entry per trainable var.
        self.m = ModuleList([
            StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars])
Example #4
0
    def __init__(self, vc: VarCollection):
        """Constructor for SGD optimizer.

        Args:
            vc: collection of variables to optimize.
        """
        # Wrap each trainable variable from the collection in a TrainRef.
        refs = [TrainRef(tv) for tv in vc.subset(TrainVar)]
        self.train_vars = ModuleList(refs)
Example #5
0
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 tc: float = 1e-3,
                 eps: float = 1e-5):
        """Constructor for LARS optimizer.

        Args:
            vc: collection of variables to optimize.
            momentum: coefficient used for the moving average of the gradient.
            weight_decay: weight decay coefficient.
            tc: trust coefficient eta ( < 1) for trust ratio computation.
            eps: epsilon used for trust ratio computation.
        """
        self.momentum, self.weight_decay = momentum, weight_decay
        self.tc, self.eps = tc, eps
        # References to the trainable subset of the collection.
        self.train_vars = ModuleList(TrainRef(v) for v in vc.subset(TrainVar))
        # Gradient moving-average state, zero-initialized per variable.
        self.m = ModuleList([
            StateVar(jn.zeros_like(ref.value)) for ref in self.train_vars])
Example #6
0
 def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
     """Constructor for an optimizer wrapper delegating to a base optimizer.

     Args:
         vc: collection of variables to optimize.
         base_optimizer: callable invoked as ``base_optimizer(vc, **kwargs)``
             to build the underlying optimizer instance.
         **kwargs: extra keyword arguments forwarded to ``base_optimizer``.
     """
     self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
     self.base_optimizer = base_optimizer(vc, **kwargs)
     # Auxiliary per-key state: accessing a missing key creates an empty dict.
     # NOTE(review): what the keys represent isn't visible here — presumably
     # per-variable bookkeeping added by subclasses; confirm against callers.
     self.state = defaultdict(dict)