def __init__(self, params):
    """Constructs the GroupNorm layer after validating its parameters.

    The checks run in a fixed order, so the first violated constraint is
    the one reported:

    * `name` must not be None.
    * `num_groups` and `min_group_size` must both be greater than 0.
    * `min_group_size` must be at most `dim` and must divide `dim` evenly.
    * When `dim >= num_groups`, `num_groups` must divide `dim` evenly.
    * `input_rank` must be either 3 or 4.

    Args:
      params: Layer parameters, forwarded unchanged to the parent
        constructor and then read back through `self.params`.

    Raises:
      Whatever the `asserts` helpers raise on a failed check (presumably
      ValueError -- the helpers are defined outside this block; confirm).
    """
    super().__init__(params)
    hparams = self.params

    # Basic sanity checks on the individual fields.
    asserts.not_none(hparams.name)
    asserts.gt(hparams.num_groups, 0)
    asserts.gt(hparams.min_group_size, 0)

    # The minimum group size has to fit into, and evenly divide, `dim`.
    asserts.le(hparams.min_group_size, hparams.dim)
    asserts.eq(hparams.dim % hparams.min_group_size, 0)

    # Divisibility by `num_groups` is only enforced when there are at least
    # as many channels as requested groups.
    if hparams.dim >= hparams.num_groups:
        not_divisible_msg = (
            'p.dim({0}) is not dividable by p.num_groups({1})'.format(
                hparams.dim, hparams.num_groups))
        asserts.eq(hparams.dim % hparams.num_groups, 0, msg=not_divisible_msg)

    asserts.in_set(hparams.input_rank, (3, 4))
def test_gt_raises(self, value1, value2):
    """Checks that `asserts.gt` raises ValueError with the expected message.

    Args:
      value1: Left-hand operand passed to `asserts.gt`.
      value2: Right-hand operand passed to `asserts.gt`.
    """
    # NOTE(review): this string is used as a regular expression. The values
    # are interpolated without escaping and the final `.` before `$` is a
    # regex wildcard, so the match is slightly looser than a literal compare.
    expected_regex = (
        f'`value1={value1}` must be strictly greater than `value2={value2}`.$'
    )
    with self.assertRaisesRegex(ValueError, expected_regex):
        asserts.gt(value1, value2)
def test_gt(self, value1, value2):
    """Checks that `asserts.gt` accepts the given pair without raising.

    The test passes as long as the call returns normally; no return value
    is inspected.

    Args:
      value1: Left-hand operand passed to `asserts.gt`.
      value2: Right-hand operand passed to `asserts.gt`.
    """
    asserts.gt(value1, value2)
def _get_default_paddings(self, inputs: JTensor) -> JTensor:
    """Builds an all-zero paddings tensor matching `inputs`.

    Args:
      inputs: Input tensor; `asserts.gt` is used to require more than one
        dimension.

    Returns:
      A zero-filled tensor with the dtype of `inputs` and the same shape as
      `inputs`, except that the last dimension has size 1.
    """
    dims = list(inputs.shape)
    asserts.gt(len(dims), 1)
    # Keep every leading dimension and collapse the trailing one to size 1.
    padding_shape = [*dims[:-1], 1]
    return jnp.zeros(padding_shape, dtype=inputs.dtype)