Пример #1
0
  def test_simple(self, resnet_v2):
    """Runs one training-mode forward pass and checks the output shape.

    Args:
      resnet_v2: Forwarded unchanged to `resnet.ResNet`; supplied by the
        test parameterization (not visible here).
    """
    batch = jnp.ones([2, 64, 64, 3])
    net = resnet.ResNet([1, 1, 1, 1], 10, resnet_v2=resnet_v2)

    out = net(batch, is_training=True)
    self.assertIsNotNone(out)
    # Leading dim matches the input's first dim (2); the trailing 10 matches
    # the second constructor argument (presumably the number of classes).
    expected_shape = (2, 10)
    self.assertEqual(out.shape, expected_shape)
Пример #2
0
 def test_error_incorrect_args_block_list(self, list_length):
   """Checks that ResNet rejects a block list whose length is not 4.

   Args:
     list_length: Length of the block list passed to `resnet.ResNet`;
       supplied by the test parameterization (not visible here).
   """
   # Only the length is varied here; the element values are arbitrary.
   block_list = list(range(list_length))
   # The message contains no regex metacharacters, so assertRaisesRegex
   # effectively matches it literally.
   expected_message = "blocks_per_group` must be of length 4 not {}".format(
       list_length)
   with self.assertRaisesRegex(ValueError, expected_message):
     resnet.ResNet(block_list, 10, {"decay_rate": 0.9, "eps": 1e-5})
Пример #3
0
 def test_error_incorrect_args_channel_list(self, list_length):
   """Checks that ResNet rejects a channel list whose length is not 4.

   Args:
     list_length: Length of the list passed as `channels_per_group_list`;
       supplied by the test parameterization (not visible here).
   """
   # Only the length is varied here; the element values are arbitrary.
   channel_list = list(range(list_length))
   # The message contains no regex metacharacters, so assertRaisesRegex
   # effectively matches it literally.
   expected_message = (
       "channels_per_group_list` must be of length 4 not {}".format(
           list_length))
   with self.assertRaisesRegex(ValueError, expected_message):
     resnet.ResNet([1, 1, 1, 1], 10, {"decay_rate": 0.9, "eps": 1e-5},
                   channels_per_group_list=channel_list)
Пример #4
0
  def test_simple(self, resnet_v2, bottleneck):
    """Checks the output shape in both training and inference mode.

    Args:
      resnet_v2: Forwarded unchanged to `resnet.ResNet`.
      bottleneck: Forwarded unchanged to `resnet.ResNet`.
      (Both are supplied by the test parameterization, not visible here.)
    """
    batch = jnp.ones([2, 64, 64, 3])
    net = resnet.ResNet(
        [1, 1, 1, 1], 10, resnet_v2=resnet_v2, bottleneck=bottleneck)

    # The same module instance is called once per mode, training first.
    expected_shape = (2, 10)
    for training in (True, False):
      out = net(batch, is_training=training)
      self.assertEqual(out.shape, expected_shape)
Пример #5
0
 def forward_fn(image):
   """Builds a small ResNet and applies it to `image` with training disabled.

   `resnet_v2` and `bottleneck` are free variables read from the enclosing
   scope (not visible in this snippet).
   """
   blocks_per_group = [1, 1, 1, 1]
   net = resnet.ResNet(blocks_per_group, 10,
                       resnet_v2=resnet_v2, bottleneck=bottleneck)
   # NOTE(review): is_training=False together with test_local_stats=True
   # presumably makes normalization use the current batch's statistics
   # during evaluation — confirm against the ResNet implementation.
   return net(image, is_training=False, test_local_stats=True)