def test_slice_axis(self):
    """Checks that a ``slice`` axis and the equivalent explicit axes agree.

    The two layers are configured differently on purpose: one normalizes over
    a ``slice`` and the other over an explicit tuple of ints. Both paths are
    applied to the same rank-4 input with the same call-time ``scale`` and
    ``offset``, and their outputs must match.
    """
    # Previously both layers were constructed identically, so the test only
    # compared a layer against itself and never exercised the slice path.
    # NOTE(review): assumes the slice is resolved against range(inputs.ndim),
    # so slice(1, None) on a rank-4 input means axes (1, 2, 3); confirm
    # against GroupNorm's axis handling.
    slice_layer = group_norm.GroupNorm(
        groups=5, axis=slice(1, None), create_scale=False, create_offset=False)
    axis_layer = group_norm.GroupNorm(
        groups=5, axis=(1, 2, 3), create_scale=False, create_offset=False)

    inputs = np.random.uniform(0, 10, [3, 4, 4, 5]).astype(np.float32)
    scale = np.random.normal(size=(5, ), loc=1.0)
    offset = np.random.normal(size=(5, ))

    slice_outputs = slice_layer(inputs, scale, offset)
    axis_outputs = axis_layer(inputs, scale, offset)

    self.assertAllClose(slice_outputs, axis_outputs)
def test_valid_data_format_channels_last(self, data_format):
    """Each parameterized ``data_format`` yields ``channel_index == -1``.

    The ``data_format`` values come from a parameterization outside this
    method; presumably they are channels-last spellings.
    """
    layer = group_norm.GroupNorm(
        groups=5,
        data_format=data_format,
        create_scale=False,
        create_offset=False,
    )
    self.assertEqual(layer.channel_index, -1)
def test_data_format_agnostic_var(self):
    """Default and NCHW layers with created scale/offset agree after transposing.

    The same random input is fed to both layers, once as-is and once with the
    last axis moved to position 1. The NCHW result is transposed back before
    the comparison.
    """
    nhwc_layer = group_norm.GroupNorm(
        groups=5, create_scale=True, create_offset=True)
    nchw_layer = group_norm.GroupNorm(
        groups=5, create_scale=True, create_offset=True, data_format="NCHW")

    x_nhwc = np.random.uniform(0, 10, [3, 4, 4, 10]).astype(np.float32)
    out_nhwc = nhwc_layer(x_nhwc)

    # Move the last axis to position 1, then undo that permutation on the output.
    out_nchw = nchw_layer(jnp.transpose(x_nhwc, [0, 3, 1, 2]))
    out_nchw_as_nhwc = jnp.transpose(out_nchw, [0, 2, 3, 1])

    self.assertAllClose(out_nhwc, out_nchw_as_nhwc)
def test_no_offset_beta_init_provided(self):
    """Supplying ``offset_init`` together with ``create_offset=False`` raises."""
    expected_message = "Cannot set `offset_init` if `create_offset=False`."
    with self.assertRaisesRegex(ValueError, expected_message):
        group_norm.GroupNorm(
            groups=5,
            create_scale=True,
            create_offset=False,
            offset_init=jnp.zeros,
        )
def test_no_scale_and_init_provided(self):
    """Supplying ``scale_init`` together with ``create_scale=False`` raises."""
    expected_message = "Cannot set `scale_init` if `create_scale=False`."
    with self.assertRaisesRegex(ValueError, expected_message):
        group_norm.GroupNorm(
            groups=5,
            create_scale=False,
            create_offset=True,
            scale_init=jnp.ones,
        )
def test_invalid_axis(self, axis):
    """A parameterized ``axis`` of an unsupported kind is rejected at construction."""
    expected_message = "`axis` should be an int, slice or iterable of ints."
    with self.assertRaisesRegex(ValueError, expected_message):
        group_norm.GroupNorm(
            groups=5,
            axis=axis,
            create_scale=False,
            create_offset=False,
        )
def test_simple_case(self):
    """An all-ones input normalizes to exactly zero everywhere.

    No scale or offset is created or passed, so only the normalization
    itself affects the output.
    """
    layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False)
    normalized = layer(jnp.ones([2, 3, 3, 10]))
    for element in np.nditer(normalized):
        self.assertEqual(element, 0.0)
def test_create_offset_and_offset_provided(self):
    """Passing ``offset`` at call time conflicts with ``create_offset=True``."""
    layer = group_norm.GroupNorm(
        groups=5, create_offset=True, create_scale=False)
    expected_message = (
        "Cannot pass `offset` at call time if `create_offset=True`.")
    with self.assertRaisesRegex(ValueError, expected_message):
        layer(jnp.ones([2, 3, 5]), offset=jnp.ones([4]))
def test_invalid_data_format(self, data_format):
    """An unrecognized ``data_format`` is rejected at construction time.

    The expected message embeds the offending ``data_format`` verbatim.
    NOTE(review): the value is used unescaped as a regex, so any parameterized
    values are assumed to contain no regex metacharacters.
    """
    expected_message = (
        f"Unable to extract channel information from '{data_format}'.")
    with self.assertRaisesRegex(ValueError, expected_message):
        group_norm.GroupNorm(
            groups=5,
            data_format=data_format,
            create_scale=False,
            create_offset=False,
        )
def test_simple_case_tensor(self):
    """Call-time scale 0.5 and offset 2.0 on an all-ones input give 2.0 everywhere.

    The all-ones input normalizes to zero, so only the offset survives in the
    output.
    """
    layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False)
    result = layer(
        jnp.ones([2, 3, 3, 10]),
        constant(0.5, shape=(10, )),
        constant(2.0, shape=(10, )),
    )
    for element in np.nditer(result):
        self.assertEqual(element, 2.0)
def test_simple_case_var(self):
    """Created scale/offset initialized to 0.5/2.0 give 2.0 everywhere on ones."""
    layer = group_norm.GroupNorm(
        groups=5,
        create_scale=True,
        create_offset=True,
        scale_init=initializers.Constant(0.5),
        offset_init=initializers.Constant(2.0),
    )
    result = layer(jnp.ones([2, 3, 3, 10]))
    for element in np.nditer(result):
        self.assertEqual(element, 2.0)
def test_incompatible_groups_and_tensor(self, shape):
    """Calling with a parameterized ``shape`` raises the channel-divisibility error.

    The input is built outside the ``assertRaisesRegex`` context so that only
    the layer call can satisfy the expected error.
    """
    layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False)
    ones_input = jnp.ones(shape)
    expected_message = (
        "The number of channels must be divisible by the number of groups")
    with self.assertRaisesRegex(ValueError, expected_message):
        layer(ones_input)
def test_data_format_agnostic_tensor(self):
    """Default and NCHW layers agree when given call-time scale/offset.

    The random draws (input, then scale, then offset) happen in the same order
    as before, so seeded runs are unchanged.
    """
    nhwc_layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False)
    nchw_layer = group_norm.GroupNorm(
        groups=5, data_format="NCHW", create_scale=False, create_offset=False)

    x_nhwc = np.random.uniform(0, 10, [3, 4, 4, 10]).astype(np.float32)
    gamma = np.random.normal(size=(10, ), loc=1.0)
    beta = np.random.normal(size=(10, ))

    out_nhwc = nhwc_layer(x_nhwc, gamma, beta)

    # The per-channel parameters are reshaped to (10, 1, 1), presumably so
    # they line up with the channel axis of the transposed input.
    out_nchw = nchw_layer(
        jnp.transpose(x_nhwc, [0, 3, 1, 2]),
        jnp.reshape(gamma, (10, 1, 1)),
        jnp.reshape(beta, (10, 1, 1)),
    )
    out_nchw_as_nhwc = jnp.transpose(out_nchw, [0, 2, 3, 1])

    self.assertAllClose(out_nhwc, out_nchw_as_nhwc, rtol=1e-5)
def test3ddata_format_agnostic(self):
    """Rank-3 variant: default and NCW layers agree after transposing.

    The input has shape (3, 4, 10). For the NCW layer its last two axes are
    swapped, and the per-channel parameters are reshaped to (10, 1).
    """
    nwc_layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False)
    ncw_layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False, data_format="NCW")

    x_nwc = np.random.uniform(0, 10, [3, 4, 10]).astype(np.float32)
    gamma = np.random.normal(size=(10, ), loc=1.0)
    beta = np.random.normal(size=(10, ))

    out_nwc = nwc_layer(x_nwc, gamma, beta)
    out_ncw = ncw_layer(
        jnp.transpose(x_nwc, [0, 2, 1]),
        jnp.reshape(gamma, [-1, 1]),
        jnp.reshape(beta, [-1, 1]),
    )
    out_ncw_as_nwc = jnp.transpose(out_ncw, [0, 2, 1])

    self.assertAllClose(out_nwc, out_ncw_as_nwc, atol=1e-5, rtol=1e-5)
def test_rank_changes(self):
    """A second call with a different input rank is rejected.

    The first call uses a rank-4 input. The second call uses a rank-5 input
    and must fail with the rank-change error.
    """
    layer = group_norm.GroupNorm(
        groups=5, create_scale=False, create_offset=False)
    rank4_input = jnp.ones([2, 3, 3, 5])
    gamma = constant(0.5, shape=(5, ))
    beta = constant(2.0, shape=(5, ))

    layer(rank4_input, gamma, beta)

    expected_message = (
        "The rank of the inputs cannot change between calls, the original")
    with self.assertRaisesRegex(ValueError, expected_message):
        layer(jnp.ones([2, 3, 3, 4, 5]), gamma, beta)