Example #1
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.999,
                 debias: bool = False,
                 eps: float = 1e-6):
        """Creates ExponentialMovingAverage instance with given hyperparameters.

        Args:
            vc: collection of variables for which to maintain moving averages.
            momentum: the decay factor for the moving average.
            debias: bool indicating whether to use initialization bias correction.
            eps: small adjustment to prevent division by zero.
        """
        self.momentum = momentum
        self.debias = debias
        self.eps = eps
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        # Deduplicate variables and skip RandomState vars since they cannot be averaged.
        # Use dicts since they preserve insertion order (Python >= 3.6).
        trainable, non_trainable = {}, {}
        for v in vc:
            if isinstance(v, RandomState):
                continue
            if isinstance(v, TrainRef):
                v = v.ref
            if isinstance(v, TrainVar):
                trainable[v] = True
            else:
                non_trainable[v] = True
        self.refs = ModuleList(
            list(non_trainable.keys()) +
            [TrainRef(v) for v in trainable.keys()])
        self.m = ModuleList(
            StateVar(jn.zeros_like(x.value)) for x in self.refs)
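
For reference, a minimal sketch (illustrative only, not part of objax) of the EMA recurrence that the m buffers above implement:

    def ema_update(m, x, momentum=0.999):
        # Incremental form of m_t = momentum * m_{t-1} + (1 - momentum) * x_t, with m_0 = 0.
        return m + (1 - momentum) * (x - m)
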
Example #2
File: adam.py  Project: spacexcorp/objax
    def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        """Constructor for Adam optimizer class.

        Args:
            vc: collection of variables to optimize.
            beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
            beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
            eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
        """
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
        self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
        self.v = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
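
The step, m, and v buffers above hold the state for the standard Adam update. Below is a minimal NumPy sketch of one such step (names are illustrative; this is not objax's actual __call__):

    import numpy as np

    def adam_step(theta, grad, m, v, step, lr, beta1=0.9, beta2=0.999, eps=1e-8):
        step += 1
        m = beta1 * m + (1 - beta1) * grad            # first-moment moving average
        v = beta2 * v + (1 - beta2) * grad ** 2       # second-moment moving average
        lr_t = lr * np.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)  # bias correction
        theta = theta - lr_t * m / (np.sqrt(v) + eps)
        return theta, m, v, step
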
Example #3
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 nesterov: bool = False):
        """Constructor for momentum optimizer class.

        Args:
            vc: collection of variables to optimize.
            momentum: the momentum hyperparameter.
            nesterov: bool indicating whether to use the Nesterov method.
        """
        self.momentum = momentum
        self.nesterov = nesterov
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
        self.m = ModuleList(
            StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
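
The m buffers above store per-variable velocity. A minimal sketch of the classical and Nesterov updates they support (illustrative only, not objax's actual __call__):

    def momentum_step(theta, grad, m, lr, momentum=0.9, nesterov=False):
        m = momentum * m + grad                         # accumulate velocity
        if nesterov:
            theta = theta - lr * (grad + momentum * m)  # look-ahead variant
        else:
            theta = theta - lr * m
        return theta, m
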
Example #4
File: sgd.py  Project: utkarshgiri/objax
    def __init__(self, vc: VarCollection):
        """Constructor for SGD optimizer.

        Args:
            vc: collection of variables to optimize.
        """
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
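
Since SGD keeps no state beyond the TrainRef list, its update is simply theta -= lr * grad per variable. A hedged usage sketch, assuming the class above is exposed as objax.optimizer.SGD:

    import objax

    model = objax.nn.Linear(3, 1)
    opt = objax.optimizer.SGD(model.vars())
    gv = objax.GradValues(lambda x, y: ((model(x) - y) ** 2).mean(), model.vars())

    def train_op(x, y, lr=0.1):
        grads, values = gv(x, y)
        opt(lr, grads)   # applies v.value -= lr * g to each trainable variable
        return values
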
Example #5
File: lars.py  Project: peterjliu/objax
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 tc: float = 1e-3,
                 eps: float = 1e-5):
        """Constructor for LARS optimizer.

        Args:
            vc: collection of variables to optimize.
            momentum: coefficient used for the moving average of the gradient.
            weight_decay: weight decay coefficient.
            tc: trust coefficient eta ( < 1) for trust ratio computation.
            eps: epsilon used for trust ratio computation.
        """
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.tc = tc
        self.eps = eps
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
        self.m = ModuleList(
            StateVar(jn.zeros_like(x.value)) for x in self.train_vars)
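
The tc, weight_decay, and eps hyperparameters above feed the layer-wise trust ratio from the LARS paper. A minimal NumPy sketch of that computation (one common formulation; illustrative only):

    import numpy as np

    def lars_trust_ratio(w, g, weight_decay=1e-4, tc=1e-3, eps=1e-5):
        w_norm = np.linalg.norm(w)
        g_norm = np.linalg.norm(g)
        # Local learning-rate scale: tc * ||w|| / (||g|| + wd * ||w|| + eps).
        return tc * w_norm / (g_norm + weight_decay * w_norm + eps)
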
Example #6
class ExponentialMovingAverage(Module):
    """Maintains exponential moving averages for each variable from provided VarCollection."""
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.999,
                 debias: bool = False,
                 eps: float = 1e-6):
        """Creates ExponentialMovingAverage instance with given hyperparameters.
        
        Args:
            momentum: the decay factor for the moving average.
            debias: bool indicating whether to use initialization bias correction.
            eps: small adjustment to prevent division by zero.
        """
        self.momentum = momentum
        self.debias = debias
        self.eps = eps
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        # Deduplicate variables and skip RandomState vars since they cannot be averaged.
        trainable, non_trainable = {}, {
        }  # Use dicts since they are ordered since python >= 3.6
        for v in vc:
            if isinstance(v, RandomState):
                continue
            if isinstance(v, TrainRef):
                v = v.ref
            if isinstance(v, TrainVar):
                trainable[v] = True
            else:
                non_trainable[v] = True
        self.refs = ModuleList(
            list(non_trainable.keys()) +
            [TrainRef(v) for v in trainable.keys()])
        self.m = ModuleList(
            StateVar(jn.zeros_like(x.value)) for x in self.refs)

    def __call__(self):
        """Updates the moving average."""
        self.step.value += 1
        for ref, m in zip(self.refs, self.m):
            m.value += (1 - self.momentum) * (ref.value - m.value)

    def refs_and_values(self) -> Tuple[VarCollection, List[JaxArray]]:
        """Returns the VarCollection of variables affected by Exponential Moving Average (EMA) and
        their corresponding EMA values."""
        if self.debias:
            step = self.step.value
            debias = 1 / (1 - (1 - self.eps) * self.momentum**step)
            tensors = [m.value * debias for m in self.m]
        else:
            tensors = self.m.vars().tensors()
        return self.refs.vars(), tensors

    def replace_vars(self, f: Callable):
        """Returns a function that acts as f called when variables are replaced by their averages.

        Args:
            f: function to be called on the stored averages.

        Returns:
            A function that returns the output of calling f with stored variables replaced by
            their moving averages.
        """
        def wrap(*args, **kwargs):
            refs, new_values = self.refs_and_values()
            original_values = refs.tensors()
            refs.assign(new_values)
            output = f(*args, **kwargs)
            refs.assign(original_values)
            return output

        return wrap
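
A hedged usage sketch of replace_vars (model, train_op, and test_batch are hypothetical): fold the current weights into the averages after each training step, then evaluate with the averaged weights without permanently mutating the model.

    ema = ExponentialMovingAverage(model.vars(), momentum=0.999, debias=True)

    def train_step(x, y, lr):
        train_op(x, y, lr)   # regular optimizer update
        ema()                # fold the new weights into the moving averages

    eval_with_ema = ema.replace_vars(lambda x: model(x))
    preds = eval_with_ema(test_batch)   # runs with averaged weights, then restores originals
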
Example #7
    def __init__(self, vc: VarCollection, base_optimizer: Callable, **kwargs):
        """Wraps base_optimizer (built from vc and **kwargs) and keeps per-variable state."""
        self.train_vars = ModuleList(TrainRef(x) for x in vc.subset(TrainVar))
        self.base_optimizer = base_optimizer(vc, **kwargs)
        self.state = defaultdict(dict)
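
A hedged sketch of how such a wrapper might delegate updates (the __call__ below is hypothetical; only the constructor above comes from the source):

    def __call__(self, lr, grads):
        # Delegate the actual parameter update to the wrapped optimizer instance;
        # self.state remains available for extra per-variable bookkeeping.
        self.base_optimizer(lr, grads)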