Example #1
0
 def __init__(self,
              vc: VarCollection,
              momentum: float = 0.999,
              debias: bool = False,
              eps: float = 1e-6):
     """Creates ExponentialMovingAverage instance with given hyperparameters.

     Args:
         vc: collection of variables whose moving averages will be tracked.
         momentum: the decay factor for the moving average.
         debias: bool indicating whether to use initialization bias correction.
         eps: small adjustment to prevent division by zero.
     """
     self.momentum = momentum
     self.debias = debias
     self.eps = eps
     # Update counter; reduce picks the first replica's value.
     self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
     # Dicts deduplicate variables while keeping insertion order (python >= 3.6).
     # RandomState variables are skipped since they cannot be averaged.
     trainable = {}
     non_trainable = {}
     for var in vc:
         if isinstance(var, RandomState):
             continue
         if isinstance(var, TrainRef):
             var = var.ref
         bucket = trainable if isinstance(var, TrainVar) else non_trainable
         bucket[var] = True
     # Non-trainable variables come first, then references to trainable ones.
     ordered = list(non_trainable) + [TrainRef(v) for v in trainable]
     self.refs = ModuleList(ordered)
     # One zero-initialized average slot per tracked variable.
     self.m = ModuleList(StateVar(jn.zeros_like(r.value)) for r in self.refs)
Example #2
0
File: adam.py  Project: spacexcorp/objax
    def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
        """Constructor for Adam optimizer class.

        Args:
            vc: collection of variables to optimize.
            beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
            beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
            eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
        """
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        # Step counter; reduce picks the first replica's value.
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        # References to the trainable variables this optimizer updates.
        self.train_vars = ModuleList([TrainRef(v) for v in vc.subset(TrainVar)])
        # First (m) and second (v) moment accumulators, one per variable.
        self.m = ModuleList([StateVar(jn.zeros_like(v.value)) for v in self.train_vars])
        self.v = ModuleList([StateVar(jn.zeros_like(v.value)) for v in self.train_vars])
Example #3
0
    def __init__(self, dims: Iterable[int], redux: Iterable[int], momentum: float = 0.999, eps: float = 1e-6):
        """Creates a BatchNorm module instance.

        Args:
            dims: shape of the batch normalization state variables.
            redux: list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
            momentum: value used to compute exponential moving average of batch statistics.
            eps: small value which is used for numerical stability.
        """
        super().__init__()
        shape = tuple(dims)
        self.momentum = momentum
        self.eps = eps
        self.redux = tuple(redux)
        # Running statistics are non-trainable state; beta/gamma are the
        # trainable affine parameters applied after normalization.
        self.running_mean = StateVar(jn.zeros(shape))
        self.running_var = StateVar(jn.ones(shape))
        self.beta = TrainVar(jn.zeros(shape))
        self.gamma = TrainVar(jn.ones(shape))
Example #4
0
    def __init__(self, shape: Tuple[int, ...], buffer_size: int, init_value: float = 0):
        """Creates a MovingAverage module instance.

        Args:
            shape: shape of the input tensor.
            buffer_size: buffer size for moving average.
            init_value: initial value for moving average buffer.
        """
        # Buffer holding the last buffer_size values, pre-filled with init_value.
        full_shape = (buffer_size,) + shape
        self.buffer = StateVar(jn.zeros(full_shape) + init_value)
Example #5
0
    def __init__(self, shape: Tuple[int, ...], momentum: float = 0.999, init_value: float = 0):
        """Creates an ExponentialMovingAverage module instance.

        Args:
            shape: shape of the input tensor.
            momentum: momentum for exponential decrease of accumulated value.
            init_value: initial value for exponential moving average.
        """
        # Accumulated average starts at init_value everywhere.
        self.avg = StateVar(jn.zeros(shape) + init_value)
        self.momentum = momentum
Example #6
0
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 nesterov: bool = False):
        """Constructor for momentum optimizer class.

        Args:
            vc: collection of variables to optimize.
            momentum: the momentum hyperparameter.
            nesterov: bool indicating whether to use the Nesterov method.
        """
        self.momentum = momentum
        self.nesterov = nesterov
        # References to the trainable variables this optimizer updates,
        # plus one zero-initialized momentum accumulator per variable.
        self.train_vars = ModuleList([TrainRef(v) for v in vc.subset(TrainVar)])
        self.m = ModuleList([StateVar(jn.zeros_like(v.value)) for v in self.train_vars])
Example #7
0
File: lars.py  Project: peterjliu/objax
    def __init__(self,
                 vc: VarCollection,
                 momentum: float = 0.9,
                 weight_decay: float = 1e-4,
                 tc: float = 1e-3,
                 eps: float = 1e-5):
        """Constructor for LARS optimizer.

        Args:
            vc: collection of variables to optimize.
            momentum: coefficient used for the moving average of the gradient.
            weight_decay: weight decay coefficient.
            tc: trust coefficient eta ( < 1) for trust ratio computation.
            eps: epsilon used for trust ratio computation.
        """
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.tc = tc
        self.eps = eps
        # References to the trainable variables this optimizer updates,
        # plus one zero-initialized momentum accumulator per variable.
        self.train_vars = ModuleList([TrainRef(v) for v in vc.subset(TrainVar)])
        self.m = ModuleList([StateVar(jn.zeros_like(v.value)) for v in self.train_vars])
Example #8
0
File: gru.py  Project: aterzis-google/objax
 def init_state(self, batch_size):
     """Initialize hidden state for input batch of size ``batch_size``."""
     # Fresh all-zero hidden state with one row per batch element.
     shape = (batch_size, self.nstate)
     self.state = StateVar(jn.zeros(shape))