def test_simple(self, resnet_v2):
  """Checks that a small ResNet turns an image batch into per-class logits.

  Builds `resnet.ResNet` with one block per group and 10 outputs, runs it in
  training mode on a (2, 64, 64, 3) batch of ones, and asserts the result is
  present and shaped (2, 10).

  Args:
    resnet_v2: Forwarded as `resnet.ResNet(resnet_v2=...)`; presumably
      supplied by a parameterization decorator outside this view.
  """
  batch = jnp.ones([2, 64, 64, 3])
  net = resnet.ResNet([1, 1, 1, 1], 10, resnet_v2=resnet_v2)
  out = net(batch, is_training=True)
  # Redundant with the shape check below, but gives a clearer failure if
  # the model ever returns nothing.
  self.assertIsNotNone(out)
  self.assertEqual(out.shape, (2, 10))
def test_error_incorrect_args_block_list(self, list_length):
  """Checks that ResNet rejects a `blocks_per_group` list not of length 4.

  Args:
    list_length: Length of the block list passed to the constructor;
      presumably supplied by a parameterization decorator outside this view.
  """
  # Only the list's length is under test here, so a plain range suffices.
  block_list = list(range(list_length))
  bn_config = {"decay_rate": 0.9, "eps": 1e-5}
  with self.assertRaisesRegex(
      ValueError,
      f"blocks_per_group` must be of length 4 not {list_length}"):
    resnet.ResNet(block_list, 10, bn_config)
def test_error_incorrect_args_channel_list(self, list_length):
  """Checks that ResNet rejects a `channels_per_group_list` not of length 4.

  Args:
    list_length: Length of the channel list passed to the constructor;
      presumably supplied by a parameterization decorator outside this view.
  """
  # Only the list's length is under test here, so a plain range suffices.
  channel_list = list(range(list_length))
  bn_config = {"decay_rate": 0.9, "eps": 1e-5}
  with self.assertRaisesRegex(
      ValueError,
      f"channels_per_group_list` must be of length 4 not {list_length}"):
    resnet.ResNet([1, 1, 1, 1], 10, bn_config,
                  channels_per_group_list=channel_list)
def test_simple(self, resnet_v2, bottleneck):
  """Runs a small ResNet in both training and inference mode.

  A single model instance is called twice on a (2, 64, 64, 3) batch of
  ones, once with `is_training=True` and once with `False`. Each call must
  produce logits shaped (2, 10).

  Args:
    resnet_v2: Forwarded as `resnet.ResNet(resnet_v2=...)`.
    bottleneck: Forwarded as `resnet.ResNet(bottleneck=...)`.
    (Both presumably come from a parameterization decorator outside this
    view.)
  """
  images = jnp.ones([2, 64, 64, 3])
  net = resnet.ResNet(
      [1, 1, 1, 1],
      10,
      resnet_v2=resnet_v2,
      bottleneck=bottleneck,
  )
  for training_flag in (True, False):
    self.assertEqual(net(images, is_training=training_flag).shape, (2, 10))
def forward_fn(image):
  """Builds a small ResNet and evaluates it using local batch statistics.

  `resnet_v2` and `bottleneck` are free variables captured from the
  enclosing scope (not visible here).

  Args:
    image: Input batch fed straight to the model.

  Returns:
    Whatever the model returns for `is_training=False, test_local_stats=True`.
  """
  net = resnet.ResNet(
      [1, 1, 1, 1], 10, resnet_v2=resnet_v2, bottleneck=bottleneck)
  logits = net(image, is_training=False, test_local_stats=True)
  return logits