Beispiel #1
0
import numpy as np
import tensorflow as tf
from sleap.nn.system import use_cpu_only; use_cpu_only()  # hide GPUs for test

from sleap.nn.architectures import leap
from sleap.nn.config import LEAPConfig

class LeapTests(tf.test.TestCase):
    def test_leap_cnn_reference(self):
        # Reference implementation from the original paper.
        arch = leap.LeapCNN(
            filters=64,
            filters_rate=2,
            down_blocks=3,
            down_convs_per_block=3,
            up_blocks=3,
            up_interpolate=False,
            up_convs_per_block=2,
        )
        x_in = tf.keras.layers.Input((192, 192, 1))
        x, x_mid = arch.make_backbone(x_in)
        model = tf.keras.Model(x_in, x)
        param_counts = [
            np.prod(train_var.shape) for train_var in model.trainable_weights
        ]

        with self.subTest("number of layers"):
            self.assertEqual(len(model.layers), 40)
        with self.subTest("number of trainable weights"):
            self.assertEqual(len(model.trainable_weights), 36)
        with self.subTest("trainable parameter count"):
Beispiel #2
0
import numpy as np
import tensorflow as tf
from sleap.nn.system import use_cpu_only

# Hide GPUs for these tests. This call comes before the
# `sleap.nn.architectures` import below; presumably that ordering is
# intentional so TensorFlow device visibility is set first — confirm.
use_cpu_only()

from sleap.nn.architectures import encoder_decoder


class EncoderDecoderTests(tf.test.TestCase):
    def test_simple_conv_block(self):
        """Check the structure of a SimpleConvBlock built without batch norm.

        The block is configured with three 3x3 convolutions of 16 filters
        (with bias, ReLU activation) and a pooling stride of 2, then applied
        to an 8x8x1 input. The test checks the resulting model's layer count,
        trainable weight count, total parameter count and output shape.
        """
        n_convs = 3
        n_filters = 16
        ksize = 3
        in_channels = 1

        conv_block = encoder_decoder.SimpleConvBlock(
            pooling_stride=2,
            num_convs=n_convs,
            filters=n_filters,
            kernel_size=ksize,
            use_bias=True,
            batch_norm=False,
            batch_norm_before_activation=True,
            activation="relu",
        )
        inputs = tf.keras.Input((8, 8, in_channels))
        net = tf.keras.Model(inputs, conv_block.make_block(inputs))

        # Every conv has a k*k*c_in*c_out kernel plus a c_out bias. The first
        # conv reads the input channels and the rest read n_filters channels:
        # 160 + 2 * 2320 = 4800.
        channels_in = [in_channels] + [n_filters] * (n_convs - 1)
        expected_params = sum(
            ksize * ksize * c_in * n_filters + n_filters for c_in in channels_in
        )

        # Input layer + two layers per conv (presumably conv + activation)
        # + one trailing layer (presumably the pooling layer).
        self.assertEqual(len(net.layers), 1 + 2 * n_convs + 1)
        # One kernel and one bias per conv.
        self.assertEqual(len(net.trainable_weights), 2 * n_convs)
        self.assertEqual(net.count_params(), expected_params)
        # Stride-2 pooling halves the 8x8 spatial size.
        self.assertAllEqual(net.output.shape, (None, 4, 4, n_filters))

    def test_simple_conv_block_bn(self):