Ejemplo n.º 1
0
from monai.networks.nets import (
    se_resnet50,
    se_resnet101,
    se_resnet152,
    se_resnext50_32x4d,
    se_resnext101_32x4d,
    senet154,
)
from tests.utils import skip_if_quick

# One single-model test case per SE-Net variant, each built positionally with
# (3, 2, 2) exactly as the test expects (presumably spatial_dims, in_channels,
# num_classes — confirm against the factory signatures). The generator is
# consumed left to right while unpacking, so the models are constructed in the
# same order as the names they are bound to.
(
    TEST_CASE_1,
    TEST_CASE_2,
    TEST_CASE_3,
    TEST_CASE_4,
    TEST_CASE_5,
    TEST_CASE_6,
) = (
    [make_net(3, 2, 2)]
    for make_net in (
        senet154,
        se_resnet50,
        se_resnet101,
        se_resnet152,
        se_resnext50_32x4d,
        se_resnext101_32x4d,
    )
)

# Built at import time with pretrained=True. NOTE(review): fetching pretrained
# weights may require network access when this module is imported — confirm
# whether this construction should move into the test that consumes it.
TEST_CASE_PRETRAINED = [se_resnet50(2, 3, 2, pretrained=True)]


class TestSENET(unittest.TestCase):
    """Output-shape checks for the SE-Net family of networks."""

    @parameterized.expand([
        TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5,
        TEST_CASE_6
    ])
    def test_senet_shape(self, net):
        """A (2, 2, 64, 64, 64) input must yield a (2, 2) output for ``net``."""
        input_data = torch.randn(2, 2, 64, 64, 64)
        expected_shape = (2, 2)
        net.eval()
        with torch.no_grad():
            # Invoke the module itself rather than ``net.forward`` so that any
            # registered forward hooks run, as they would in normal use.
            result = net(input_data)
        # ``result.shape`` is a torch.Size (a tuple subclass), so it compares
        # equal to a plain tuple. Without this check the test can never fail.
        self.assertEqual(result.shape, expected_shape)
Ejemplo n.º 2
0
    se_resnet50,
    se_resnet101,
    se_resnet152,
    se_resnext50_32x4d,
    se_resnext101_32x4d,
    senet154,
)

# Run on the GPU when torch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One single-model test case per SE-Net variant, each built positionally with
# (3, 2, 2) and moved to ``device``. The generator is consumed left to right
# while unpacking, so models are constructed in the same order as the names.
(
    TEST_CASE_1,
    TEST_CASE_2,
    TEST_CASE_3,
    TEST_CASE_4,
    TEST_CASE_5,
    TEST_CASE_6,
) = (
    [make_net(3, 2, 2).to(device)]
    for make_net in (
        senet154,
        se_resnet50,
        se_resnet101,
        se_resnet152,
        se_resnext50_32x4d,
        se_resnext101_32x4d,
    )
)

# Built at import time with pretrained=True. NOTE(review): fetching pretrained
# weights may require network access when this module is imported — confirm.
TEST_CASE_PRETRAINED = [se_resnet50(2, 3, 2, pretrained=True).to(device)]


class TestSENET(unittest.TestCase):
    """Output-shape checks for the SE-Net family, run on ``device``."""

    @parameterized.expand([
        TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5,
        TEST_CASE_6
    ])
    def test_senet_shape(self, net):
        """A (2, 2, 64, 64, 64) input must yield a (2, 2) output for ``net``."""
        # The input must live on the same device the models were moved to.
        input_data = torch.randn(2, 2, 64, 64, 64).to(device)
        expected_shape = (2, 2)
        net.eval()
        with torch.no_grad():
            # Invoke the module itself rather than ``net.forward`` so that any
            # registered forward hooks run, as they would in normal use.
            result = net(input_data)
        # ``result.shape`` is a torch.Size (a tuple subclass), so it compares
        # equal to a plain tuple. Without this check the test can never fail.
        self.assertEqual(result.shape, expected_shape)
Ejemplo n.º 3
0
    se_resnet50,
    se_resnet101,
    se_resnet152,
    se_resnext50_32x4d,
    se_resnext101_32x4d,
    senet154,
)

# Constructor keyword arguments shared by every variant under test.
input_param = dict(spatial_dims=3, in_channels=2, num_classes=10)

# One single-model test case per SE-Net variant, all built from
# ``input_param``. The generator is consumed left to right while unpacking,
# so models are constructed in the same order as the names.
(
    TEST_CASE_1,
    TEST_CASE_2,
    TEST_CASE_3,
    TEST_CASE_4,
    TEST_CASE_5,
    TEST_CASE_6,
) = (
    [make_net(**input_param)]
    for make_net in (
        senet154,
        se_resnet50,
        se_resnet101,
        se_resnet152,
        se_resnext50_32x4d,
        se_resnext101_32x4d,
    )
)


class TestSENET(unittest.TestCase):
    """Output-shape checks for the SE-Net family of networks."""

    @parameterized.expand([
        TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5,
        TEST_CASE_6
    ])
    def test_senet154_shape(self, net):
        """A (5, 2, 64, 64, 64) input must yield a (5, 10) output.

        Despite the method name, this runs once per parametrized variant,
        not only for ``senet154``.
        """
        input_data = torch.randn(5, 2, 64, 64, 64)
        expected_shape = (5, 10)
        net.eval()
        with torch.no_grad():
            # Invoke the module itself rather than ``net.forward`` so that any
            # registered forward hooks run, as they would in normal use.
            result = net(input_data)
            # torch.Size is a tuple subclass, so it compares equal to a tuple.
            self.assertEqual(result.shape, expected_shape)