Beispiel #1
0
    def __init__(self, params):
        """Initializes the GroupNorm layer, then validates its hyper-parameters.

        Args:
          params: Layer hyper-parameters. They are forwarded to the base-class
            constructor and then read back through `self.params` for checking.

        Raises:
          Whatever the `asserts` helpers raise when a check fails (for example,
          `asserts.gt` raises ValueError by default).
        """
        super().__init__(params)
        hparams = self.params

        # Per-field sanity checks.
        asserts.not_none(hparams.name)
        asserts.gt(hparams.num_groups, 0)
        asserts.gt(hparams.min_group_size, 0)
        # min_group_size must not exceed dim and must divide it evenly.
        asserts.le(hparams.min_group_size, hparams.dim)
        asserts.eq(hparams.dim % hparams.min_group_size, 0)

        # Divisibility by num_groups is only enforced when dim is at least
        # num_groups; smaller dims skip this check entirely.
        if hparams.dim >= hparams.num_groups:
            group_remainder = hparams.dim % hparams.num_groups
            asserts.eq(
                group_remainder,
                0,
                msg=(f'p.dim({hparams.dim}) is not dividable by '
                     f'p.num_groups({hparams.num_groups})'),
            )

        asserts.in_set(hparams.input_rank, (3, 4))
Beispiel #2
0
 def test_gt_raises(self, value1, value2):
     """Checks that asserts.gt raises ValueError ending with the default message.

     The expected message is passed to assertRaisesRegex as a pattern, so it is
     escaped first: otherwise the trailing '.' would act as a wildcard and any
     regex metacharacters in the interpolated values would corrupt the pattern.
     The appended '$' keeps the original "message ends with this text" check.
     """
     import re  # Local import: only this test needs regex escaping.

     expected_msg = (f'`value1={value1}` must be strictly greater than '
                     f'`value2={value2}`.')
     with self.assertRaisesRegex(ValueError, re.escape(expected_msg) + '$'):
         asserts.gt(value1, value2)
Beispiel #3
0
 def test_gt(self, value1, value2):
     """Checks that asserts.gt does not raise for the given (value1, value2)."""
     lhs, rhs = value1, value2
     # The test passes as long as this call returns without raising.
     asserts.gt(lhs, rhs)
Beispiel #4
0
 def _get_default_paddings(self, inputs: JTensor) -> JTensor:
     """Builds an all-zero paddings tensor matching `inputs`.

     Every leading dimension of `inputs` is kept, the last dimension is
     replaced by 1, and the dtype is copied from `inputs`.

     Args:
       inputs: Input tensor; its rank must be greater than 1.

     Returns:
       A zeros tensor of shape `inputs.shape[:-1] + (1,)` with `inputs.dtype`.
     """
     rank = len(inputs.shape)
     asserts.gt(rank, 1)
     paddings_shape = (*inputs.shape[:-1], 1)
     return jnp.zeros(paddings_shape, dtype=inputs.dtype)