def __init__(self, vc: VarCollection, momentum: float = 0.999, debias: bool = False, eps: float = 1e-6):
    """Creates ExponentialMovingAverage instance with given hyperparameters.

    Args:
        vc: collection of variables whose exponential moving average is maintained.
        momentum: the decay factor for the moving average.
        debias: bool indicating whether to use initialization bias correction.
        eps: small adjustment to prevent division by zero.
    """
    self.momentum = momentum
    self.debias = debias
    self.eps = eps
    # Step counter; the reduce function keeps a single replica's value when synced.
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
    # Deduplicate variables and skip RandomState vars since they cannot be averaged.
    # Dicts serve as ordered sets (insertion order is preserved since Python 3.6).
    trainable, non_trainable = {}, {}
    for v in vc:
        if isinstance(v, RandomState):
            continue
        if isinstance(v, TrainRef):
            v = v.ref  # Average the referenced TrainVar, not the reference itself.
        if isinstance(v, TrainVar):
            trainable[v] = True
        else:
            non_trainable[v] = True
    # Wrap trainable vars in TrainRef so this module does not own them twice.
    self.refs = ModuleList(
        list(non_trainable.keys()) + [TrainRef(v) for v in trainable])
    # One moving-average accumulator per tracked variable, initialized to zeros.
    self.m = ModuleList(StateVar(jn.zeros_like(x.value)) for x in self.refs)
def __init__(self, vc: VarCollection, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
    """Constructor for Adam optimizer class.

    Args:
        vc: collection of variables to optimize.
        beta1: value of Adam's beta1 hyperparameter. Defaults to 0.9.
        beta2: value of Adam's beta2 hyperparameter. Defaults to 0.999.
        eps: value of Adam's epsilon hyperparameter. Defaults to 1e-8.
    """
    self.beta1 = beta1
    self.beta2 = beta2
    self.eps = eps
    # Step counter; the reduce function keeps a single replica's value.
    self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
    # Only TrainVar variables are optimized; hold them via TrainRef.
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    # First (m) and second (v) moment accumulators, one per trainable variable.
    self.m = ModuleList(StateVar(jn.zeros_like(tv.value)) for tv in self.train_vars)
    self.v = ModuleList(StateVar(jn.zeros_like(tv.value)) for tv in self.train_vars)
def __init__(self, dims: Iterable[int], redux: Iterable[int], momentum: float = 0.999, eps: float = 1e-6):
    """Creates a BatchNorm module instance.

    Args:
        dims: shape of the batch normalization state variables.
        redux: list of indices of reduction axes. Batch norm statistics are
            computed by averaging over these axes.
        momentum: value used to compute exponential moving average of batch statistics.
        eps: small value which is used for numerical stability.
    """
    super().__init__()
    self.momentum = momentum
    self.eps = eps
    self.redux = tuple(redux)
    shape = tuple(dims)
    # Running statistics updated during training, consumed at inference time.
    self.running_mean = StateVar(jn.zeros(shape))
    self.running_var = StateVar(jn.ones(shape))
    # Learnable shift (beta) and scale (gamma) applied after normalization.
    self.beta = TrainVar(jn.zeros(shape))
    self.gamma = TrainVar(jn.ones(shape))
def __init__(self, shape: Tuple[int, ...], buffer_size: int, init_value: float = 0):
    """Creates a MovingAverage module instance.

    Args:
        shape: shape of the input tensor.
        buffer_size: buffer size for moving average.
        init_value: initial value for moving average buffer.
    """
    # Ring-buffer-style state holding the last ``buffer_size`` entries,
    # every slot starting at ``init_value``.
    self.buffer = StateVar(jn.zeros((buffer_size, *shape)) + init_value)
def __init__(self, shape: Tuple[int, ...], momentum: float = 0.999, init_value: float = 0):
    """Creates an ExponentialMovingAverage module instance.

    Args:
        shape: shape of the input tensor.
        momentum: momentum for exponential decrease of accumulated value.
        init_value: initial value for exponential moving average.
    """
    self.momentum = momentum
    # Accumulated average, with every element starting at ``init_value``.
    self.avg = StateVar(jn.zeros(shape) + init_value)
def __init__(self, vc: VarCollection, momentum: float = 0.9, nesterov: bool = False):
    """Constructor for momentum optimizer class.

    Args:
        vc: collection of variables to optimize.
        momentum: the momentum hyperparameter.
        nesterov: bool indicating whether to use the Nesterov method.
    """
    self.momentum = momentum
    self.nesterov = nesterov
    # Only TrainVar variables are optimized; hold them via TrainRef.
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    # Per-variable velocity accumulators, initialized to zeros.
    self.m = ModuleList(StateVar(jn.zeros_like(tv.value)) for tv in self.train_vars)
def __init__(self, vc: VarCollection, momentum: float = 0.9, weight_decay: float = 1e-4, tc: float = 1e-3, eps: float = 1e-5):
    """Constructor for LARS optimizer.

    Args:
        vc: collection of variables to optimize.
        momentum: coefficient used for the moving average of the gradient.
        weight_decay: weight decay coefficient.
        tc: trust coefficient eta ( < 1) for trust ratio computation.
        eps: epsilon used for trust ratio computation.
    """
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.tc = tc
    self.eps = eps
    # Only TrainVar variables are optimized; hold them via TrainRef.
    self.train_vars = ModuleList(TrainRef(tv) for tv in vc.subset(TrainVar))
    # Per-variable momentum accumulators, initialized to zeros.
    self.m = ModuleList(StateVar(jn.zeros_like(tv.value)) for tv in self.train_vars)
def init_state(self, batch_size):
    """Initialize hidden state for input batch of size ``batch_size``.

    Args:
        batch_size: number of sequences in the batch; the state is reset to
            zeros of shape ``(batch_size, self.nstate)``.
    """
    zeros = jn.zeros((batch_size, self.nstate))
    self.state = StateVar(zeros)