Example #1
    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Performs batch normalization of input tensor.

        Args:
            x: input tensor.
            training: if True compute batch normalization in training mode (accumulating batch statistics),
                otherwise compute in evaluation mode (using already accumulated batch statistics).

        Returns:
            Batch normalized tensor.
        """
        if training:
            m = x.mean(self.redux, keepdims=True)
            v = ((x - m)**2).mean(
                self.redux,
                keepdims=True)  # Note: x^2 - m^2 is not numerically stable.
            self.running_mean.value += (1 - self.momentum) * (
                m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (
                v - self.running_var.value)
        else:
            m, v = self.running_mean.value, self.running_var.value
        y = self.gamma.value * (
            x - m) * functional.rsqrt(v + self.eps) + self.beta.value
        return y
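
The inline note above about x^2 - m^2 being numerically unstable is easy to demonstrate in isolation. The sketch below is a minimal standalone check (plain jax.numpy, hand-picked float32 values, not part of objax) comparing the two-pass variance used in this example with the one-pass form that several of the later examples use:

import jax.numpy as jnp

# Data with a large offset relative to its spread, in float32.
x = jnp.array([10000.1, 10000.2, 10000.3], dtype=jnp.float32)
m = x.mean()

v_two_pass = ((x - m) ** 2).mean()     # stable, close to the true variance of ~0.0067
v_one_pass = (x ** 2).mean() - m ** 2  # subtracts two numbers near 1e8, so most of the
                                       # precision is lost and the result can even be negative
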
Example #2
File: adam.py  Project: qingliaowu/objax
    def __call__(self,
                 lr: float,
                 grads: List[JaxArray],
                 beta1: Optional[float] = None,
                 beta2: Optional[float] = None):
        """Updates variables and other state based on Adam algorithm.

        Args:
            lr: the learning rate.
            grads: the gradients to apply.
            beta1: optional, override the default beta1.
            beta2: optional, override the default beta2.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        if beta1 is None:
            beta1 = self.beta1
        if beta2 is None:
            beta2 = self.beta2
        self.step.value += 1
        lr *= jn.sqrt(1 - beta2 ** self.step.value) / (1 - beta1 ** self.step.value)
        for g, p, m, v in zip(grads, self.train_vars, self.m, self.v):
            m.value += (1 - beta1) * (g - m.value)
            v.value += (1 - beta2) * (g**2 - v.value)
            p.value -= lr * m.value * functional.rsqrt(v.value + self.eps)
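
The line that rescales lr folds Adam's bias corrections for both moment estimates into the learning rate. A minimal standalone check of that algebra (plain jax.numpy, hand-picked values, epsilon omitted because the two forms place it differently):

import jax.numpy as jnp

lr, beta1, beta2, step = 1e-3, 0.9, 0.999, 10
m = jnp.array([0.3, -0.1])
v = jnp.array([0.04, 0.02])

# Textbook form: correct the moment estimates, then update.
m_hat = m / (1 - beta1 ** step)
v_hat = v / (1 - beta2 ** step)
update_a = lr * m_hat / jnp.sqrt(v_hat)

# Form used above: fold the corrections into the learning rate.
lr_t = lr * jnp.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
update_b = lr_t * m / jnp.sqrt(v)

assert jnp.allclose(update_a, update_b)
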
Example #3
    def __call__(self, vals, mask, training):
        sum_dims = list(range(len(vals.shape[:-1])))
        x_or_zero = jnp.where(mask[..., None], vals, 0 * vals)
        smask = jax.device_put(scalar_mask(self.rep))
        if training:
            num_valid = mask.sum(sum_dims)
            m = x_or_zero.sum(sum_dims) / num_valid
            x2 = (x_or_zero ** 2).sum(sum_dims) / num_valid
            v = x2 - m ** 2
            v = jnp.where(smask, v, ragged_gather_scatter(x2, self.rep))
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
        else:
            m, v = self.running_mean.value, self.running_var.value
        y = jnp.where(smask,
                      self.gamma.value * (x_or_zero - m) * F.rsqrt(v + self.eps) + self.beta.value,
                      x_or_zero * F.rsqrt(v + self.eps))  # alternative: (x_or_zero - m) in the non-scalar branch
        return y, mask
Example #4
    def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray:
        if training:
            m = functional.parallel.pmean(x.mean(self.redux, keepdims=True))
            v = functional.parallel.pmean((x ** 2).mean(self.redux, keepdims=True) - m ** 2)
            if batch_norm_update:
                self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
                self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
        else:
            m, v = self.running_mean.value, self.running_var.value
        y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value
        return y
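
Example #4 differs from example #1 mainly in the functional.parallel.pmean calls, which average the batch statistics across devices so every replica normalizes with the same mean and variance. A minimal sketch of that averaging using plain jax.pmap and jax.lax.pmean (the objax wrapper is assumed to behave the same way):

import jax
import jax.numpy as jnp

def local_then_global_mean(x):
    local = x.mean()                                   # per-device batch statistic
    return jax.lax.pmean(local, axis_name='devices')   # average across all devices

n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
out = jax.pmap(local_then_global_mean, axis_name='devices')(x)
# Every device ends up with the same value, equal to the mean over the full batch.
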
Example #5
File: layers.py  Project: google/objax
    def __call__(self, x: JaxArray, training: bool = True) -> JaxArray:
        """Returns the results of applying group normalization to input x."""
        # The training argument is unused so that group norm can be used as a drop-in replacement for batch norm.
        del training
        group_shape = (-1, self.groups, self.nin // self.groups) + x.shape[2:]
        x = x.reshape(group_shape)
        mean = x.mean(axis=self.redux, keepdims=True)
        var = x.var(axis=self.redux, keepdims=True)
        x = (x - mean) * functional.rsqrt(var + self.eps)
        x = x.reshape((-1, self.nin) + group_shape[3:])
        x = x * self.gamma.value + self.beta.value
        return x
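
The reshaping in example #5 is easiest to follow with concrete sizes. A minimal sketch (plain jax.numpy, hypothetical shapes, the learned gamma/beta omitted) that normalizes each (sample, group) slice and restores the NCHW layout:

import jax.numpy as jnp

n, nin, h, w, groups = 2, 8, 4, 4, 4
x = jnp.arange(n * nin * h * w, dtype=jnp.float32).reshape(n, nin, h, w)

g = x.reshape(n, groups, nin // groups, h, w)    # split channels into groups
mean = g.mean(axis=(2, 3, 4), keepdims=True)     # statistics per (sample, group)
var = g.var(axis=(2, 3, 4), keepdims=True)
g = (g - mean) / jnp.sqrt(var + 1e-6)
y = g.reshape(n, nin, h, w)                      # back to the original NCHW layout
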
Example #6
File: adam.py  Project: spacexcorp/objax
    def __call__(self, lr: float, grads: List[JaxArray]):
        """Updates variables and other state based on Adam algorithm.

        Args:
            lr: the learning rate.
            grads: the gradients to apply.
        """
        assert len(grads) == len(self.train_vars), 'Expecting as many gradients as trainable variables'
        self.step.value += 1
        lr *= jn.sqrt(1 - self.beta2 ** self.step.value) / (1 - self.beta1 ** self.step.value)
        for g, p, m, v in zip(grads, self.train_vars, self.m, self.v):
            m.value += (1 - self.beta1) * (g - m.value)
            v.value += (1 - self.beta2) * (g ** 2 - v.value)
            p.value -= lr * m.value * functional.rsqrt(v.value + self.eps)
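
For context, here is a hedged sketch of how an optimizer __call__ like the ones in examples #2 and #6 is typically driven, assuming objax's standard GradValues and optimizer API; the model, loss, and data are made up for illustration:

import objax
import jax.numpy as jnp

model = objax.nn.Linear(3, 1)
opt = objax.optimizer.Adam(model.vars())

def loss_fn(x, y):
    return ((model(x) - y) ** 2).mean()

gv = objax.GradValues(loss_fn, model.vars())

x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
grads, values = gv(x, y)     # gradients for each trainable variable, plus the loss value
opt(lr=1e-3, grads=grads)    # invokes the optimizer's __call__ shown above
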
Example #7
    def __call__(self, x, training):  # TODO: support elementwise for regular reps
        # return x  # disabling batch norm here harms performance
        smask = jax.device_put(scalar_mask(self.rep))
        if training:
            m = x.mean(self.redux, keepdims=True)
            v = (x ** 2).mean(self.redux, keepdims=True) - m ** 2
            # For non-scalar indices, use the mean of squares rather than the variance.
            v = jnp.where(smask, v, ragged_gather_scatter((x ** 2).mean(self.redux), self.rep))
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
        else:
            m, v = self.running_mean.value, self.running_var.value
        y = jnp.where(smask,
                      self.gamma.value * (x - m) * F.rsqrt(v + self.eps) + self.beta.value,
                      x * F.rsqrt(v + self.eps))  # alternative: (x - m) * F.rsqrt(v + self.eps)
        return y
Example #8
    def __call__(self, vals, mask, training=True):
        sum_dims = list(range(len(vals.shape[:-1])))
        x_or_zero = jnp.where(mask[..., None], vals, 0 * vals)
        if training:
            num_valid = mask.sum(sum_dims)
            m = x_or_zero.sum(sum_dims) / num_valid
            v = (x_or_zero ** 2).sum(sum_dims) / num_valid - m ** 2
            self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
        else:
            m, v = self.running_mean.value, self.running_var.value
        return ((x_or_zero - m) * self.gamma.value * F.rsqrt(v + self.eps) + self.beta.value, mask)
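
The masked statistics shared by examples #3 and #8 can be checked on their own. A minimal sketch with plain jax.numpy and hand-made data, where the last row plays the role of padding:

import jax.numpy as jnp

vals = jnp.array([[1.0, 2.0],
                  [3.0, 4.0],
                  [100.0, 100.0]])                   # last row is padding
mask = jnp.array([True, True, False])

x_or_zero = jnp.where(mask[..., None], vals, 0.0)    # padded rows contribute nothing
num_valid = mask.sum()
m = x_or_zero.sum(axis=0) / num_valid                # [2.0, 3.0]: padding is ignored
v = (x_or_zero ** 2).sum(axis=0) / num_valid - m ** 2
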
Example #9
    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Performs batch normalization of input tensor.

        Args:
            x: input tensor.
            training: if True compute batch normalization in training mode (accumulating batch statistics),
                otherwise compute in evaluation mode (using already accumulated batch statistics).

        Returns:
            Batch normalized tensor.
        """
        shape = (1, -1, 1, 1)
        weight = self.weight.value.reshape(shape)
        bias = self.bias.value.reshape(shape)
        if training:
            mean = x.mean(self.redux, keepdims=True)
            var = (x ** 2).mean(self.redux, keepdims=True) - mean ** 2
            self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value)
        else:
            mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape)

        y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias
        return y
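
Example #9 keeps the running statistics as flat per-channel vectors, which is why the batch statistics are squeezed before the moving-average update and reshaped to (1, -1, 1, 1) in evaluation mode. A minimal sketch of that shape round trip (plain jax.numpy, hypothetical sizes):

import jax.numpy as jnp

n, c, h, w = 2, 3, 4, 4
x = jnp.ones((n, c, h, w))
redux = (0, 2, 3)

mean = x.mean(redux, keepdims=True)         # shape (1, 3, 1, 1): broadcasts with x
stored = mean.squeeze(axis=redux)           # shape (3,): one entry per channel
restored = stored.reshape((1, -1, 1, 1))    # shape (1, 3, 1, 1) again for evaluation
assert restored.shape == mean.shape
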