def test_mask_network(self, mask_net, mask_net_channels, instance_embedding_dim,
                      input_channels, use_instance_embedding):
    """Runs a tf.function-wrapped MaskHeadNetwork on batches of 2 and 0.

    The arguments are presumably supplied by a parameterized decorator outside
    this view (TODO confirm). A batch of 2 must yield shape (2, 32, 32) with
    every value strictly between -inf and inf; a batch of 0 must yield shape
    (0, 32, 32).
    """
    head = deepmac_meta_arch.MaskHeadNetwork(
        mask_net,
        num_init_channels=mask_net_channels,
        use_instance_embedding=use_instance_embedding)
    traced_call = tf.function(head.__call__)

    def run(batch_size):
        # All-zero inputs; only the leading dimension changes between calls.
        embedding = tf.zeros((batch_size, instance_embedding_dim))
        spatial_input = tf.zeros((batch_size, 32, 32, input_channels))
        return traced_call(embedding, spatial_input, training=True)

    first = run(2)
    self.assertEqual(first.shape, (2, 32, 32))
    values = first.numpy()
    self.assertAllGreater(values, -np.inf)
    self.assertAllLess(values, np.inf)

    # Invoke the same traced function again with identically shaped inputs.
    self.assertEqual(run(2).shape, (2, 32, 32))

    # An empty batch must still produce a correctly shaped output.
    self.assertEqual(run(0).shape, (0, 32, 32))
def test_fc_tf_function(self):
    """Checks the 'fully_connected' head under tf.function yields (2, 32, 32)."""
    head = deepmac_meta_arch.MaskHeadNetwork('fully_connected', 8, mask_size=32)
    traced_call = tf.function(head.__call__)
    embedding = tf.zeros((2, 4))
    spatial_input = tf.zeros((2, 32, 32, 8))
    result = traced_call(embedding, spatial_input, training=True)
    self.assertEqual(result.shape, (2, 32, 32))
def test_mask_network_resnet_tf_function(self):
    """Checks the 'resnet8' head under tf.function yields (2, 32, 32)."""
    head = deepmac_meta_arch.MaskHeadNetwork('resnet8')
    traced_call = tf.function(head.__call__)
    embedding = tf.zeros((2, 4))
    spatial_input = tf.zeros((2, 32, 32, 16))
    result = traced_call(embedding, spatial_input, training=True)
    self.assertEqual(result.shape, (2, 32, 32))
def test_mask_network_params_resnet4(self):
    """Pins the trainable-parameter count of a 'resnet4' head with 8 channels."""
    head = deepmac_meta_arch.MaskHeadNetwork('resnet4', num_init_channels=8)
    # Call once before counting; presumably this is what creates the weights
    # (lazy layer building) — confirm.
    head(tf.zeros((2, 16)), tf.zeros((2, 32, 32, 16)), training=True)
    per_weight_sizes = [
        tf.reduce_prod(tf.shape(weight)) for weight in head.trainable_weights
    ]
    total = tf.reduce_sum(per_weight_sizes)
    self.assertEqual(total.numpy(), 8665)
def test_mask_network_embedding_projection_small(self):
    """Checks 'embedding_projection' output shape and range for 1e6 inputs.

    The output must have shape (2, 32, 32), and every value must lie strictly
    between -inf and inf.
    """
    head = deepmac_meta_arch.MaskHeadNetwork(
        'embedding_projection',
        num_init_channels=-1,
        use_instance_embedding=False)
    traced_call = tf.function(head.__call__)
    # Every embedding entry is 1e6; presumably this probes overflow in the
    # projection (confirm).
    large_embedding = 1e6 + tf.zeros((2, 7))
    spatial_input = tf.zeros((2, 32, 32, 7))
    result = traced_call(large_embedding, spatial_input, training=True)
    self.assertEqual(result.shape, (2, 32, 32))
    values = result.numpy()
    self.assertAllGreater(values, -np.inf)
    self.assertAllLess(values, np.inf)
def test_mask_network_resnet(self):
    """Checks a 'resnet4' head called directly (no tf.function) yields (2, 32, 32)."""
    head = deepmac_meta_arch.MaskHeadNetwork('resnet4')
    embedding = tf.zeros((2, 4))
    spatial_input = tf.zeros((2, 32, 32, 16))
    result = head(embedding, spatial_input, training=True)
    self.assertEqual(result.shape, (2, 32, 32))