Example #1
from typing import Tuple

import numpy as np
import jax.numpy as jn

from objax import random
from objax.typing import JaxArray


def orthogonal(shape: Tuple[int, ...],
               gain: float = 1,
               axis: int = -1) -> JaxArray:
    """Returns a uniformly distributed orthogonal tensor from
    `Exact solutions to the nonlinear dynamics of learning in deep linear neural networks
    <https://openreview.net/forum?id=_wzZwKpTDF_9C>`_.

    Args:
        shape: shape of the output tensor.
        gain: optional scaling factor.
        axis: the orthogonalization axis.

    Returns:
        An orthogonally initialized tensor.
        These tensors will be row-orthonormal along the axis specified by
        ``axis``. If the rank of the weight is greater than 2, all other
        dimensions are flattened and the result is row-orthonormal along the
        final dimension. Note that this only holds if the ``axis`` dimension
        is the larger one; otherwise the tensor is transposed (equivalently,
        it is column-orthonormal instead of row-orthonormal).
        If the shape is not square, the matrices will have orthonormal rows or
        columns depending on which side is smaller.
    """
    n_rows = shape[axis]
    n_cols = np.prod(shape) // n_rows
    matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
    norm_dst = random.normal(matrix_shape)
    q_mat, r_mat = np.linalg.qr(norm_dst)
    # Enforce Q is uniformly distributed
    q_mat *= np.sign(np.diag(r_mat))
    if n_rows < n_cols:
        q_mat = q_mat.T
    q_mat = np.reshape(q_mat, (n_rows, ) + tuple(np.delete(shape, axis)))
    q_mat = np.moveaxis(q_mat, 0, axis)
    return gain * jn.array(q_mat)
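A quick usage sketch (the kernel shape below is an assumption for illustration, not from the original): orthogonalize a convolution kernel along its output-channel axis, then check orthonormality of the flattened matrix.

w = orthogonal((3, 3, 16, 32), axis=-1)       # hypothetical HWIO conv kernel
flat = np.reshape(w, (-1, 32))                # flatten all non-axis dims: (144, 32)
print(np.allclose(flat.T @ flat, np.eye(32), atol=1e-5))  # True: columns orthonormal

Since the output-channel dimension (32) is smaller than the flattened remainder (144), the result is column-orthonormal, matching the docstring's caveat.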
Example #2
    def __call__(self, *args):
        """Returns the computed DP-SGD gradients.

        Returns:
            A tuple (gradients, value of f)."""
        batch = args[0].shape[0]
        assert batch % self.microbatch == 0
        num_microbatches = batch // self.microbatch
        # Gradients are averaged over microbatches, so adding noise with stddev
        # (l2_norm_clip * noise_multiplier) / num_microbatches is equivalent to
        # adding stddev (l2_norm_clip * noise_multiplier) noise to their sum.
        stddev = self.l2_norm_clip * self.noise_multiplier / num_microbatches
        # Clip each microbatch gradient, then add per-parameter Gaussian noise.
        g, v = self.clipped_grad(*[self.reshape_microbatch(x) for x in args])
        g = [gx + random.normal(gx.shape, stddev=stddev, generator=self.keygen) for gx in g]
        return g, v
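To make the noise calibration concrete, a worked example with assumed hyperparameters (none of these numbers come from the original):

l2_norm_clip, noise_multiplier = 1.0, 1.1     # assumed DP-SGD settings
batch, microbatch = 256, 16
num_microbatches = batch // microbatch        # 16
stddev = l2_norm_clip * noise_multiplier / num_microbatches
print(stddev)                                 # 1.1 / 16 = 0.06875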
Example #3
File: init.py  Project: spacexcorp/objax
def kaiming_normal(shape: Tuple[int, ...], gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Kaiming He normal initializer from
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_.

    Args:
        shape: shape of the output tensor.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).
    """
    return random.normal(shape, stddev=gain * kaiming_normal_gain(shape))
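The listing does not show kaiming_normal_gain itself; below is a minimal sketch of the standard He fan-in rule it presumably implements (an assumption; verify against objax's init.py):

import numpy as np

def kaiming_normal_gain(shape):
    # He normal: stddev proportional to sqrt(1 / fan_in), where the fan-in is
    # the product of all dimensions except the last (the output dimension).
    fan_in = np.prod(shape[:-1])
    return np.sqrt(1 / fan_in)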
Example #4
File: init.py  Project: spacexcorp/objax
def xavier_normal(shape: Tuple[int, ...], gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Xavier Glorot normal initializer from
    `Understanding the difficulty of training deep feedforward neural networks
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    Args:
        shape: shape of the output tensor.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).
    """
    return random.normal(shape, stddev=gain * xavier_normal_gain(shape))
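Similarly, xavier_normal_gain is not shown in this listing; a sketch of the Glorot rule it presumably follows (again an assumption, not the confirmed objax body):

import numpy as np

def xavier_normal_gain(shape):
    # Glorot normal: stddev proportional to sqrt(2 / (fan_in + fan_out)).
    fan_in, fan_out = np.prod(shape[:-1]), shape[-1]
    return np.sqrt(2 / (fan_in + fan_out))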
Example #5
File: gradient.py  Project: srxzr/objax
    def __call__(self, *args):
        """Returns the computed DP-SGD gradients.

        Returns:
            A tuple (gradients, value of f)."""
        batch = args[0].shape[0]
        assert batch % self.microbatch == 0
        num_microbatches = batch // self.microbatch
        stddev = self.l2_norm_clip * self.noise_multiplier / num_microbatches
        g, v = self.private_grad(*[self.reshape_microbatch(x) for x in args])
        # Unlike Example #2, this variant averages the per-microbatch clipped
        # gradients (and loss values) over the leading axis before adding noise.
        g, v = jax.tree_map(functools.partial(jn.mean, axis=0), (g, v))
        g = [gx + random.normal(gx.shape, stddev=stddev, generator=self.keygen)
             for gx in g]
        return g, v
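A hedged end-to-end sketch of how such a DP-SGD gradient object is typically wired up in objax; the class name PrivateGradValues and its argument order follow objax's privacy module but should be treated as assumptions here:

import objax

model = objax.nn.Linear(10, 2)

def loss(x, y):
    # Simple squared-error objective, chosen only for illustration.
    return ((model(x) - y) ** 2).mean()

# Assumed constructor; check objax.privacy.dpsgd for the exact signature.
gv = objax.privacy.dpsgd.PrivateGradValues(
    loss, model.vars(),
    noise_multiplier=1.1, l2_norm_clip=1.0, microbatch=16)

# g, v = gv(x_batch, y_batch)  # noisy gradients and loss, as in __call__ above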