Exemplo n.º 1
0
    def test_nested_list(self):
        """Layers nested inside a ModuleList get per-index output shapes.

        Each expected shape is derived from the expected output shape of the
        layer that precedes it. The test therefore follows the same chain of
        shapes that flows through ``list[0] -> list[1] -> list[2] -> pool``.
        """
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.list = nn.ModuleList()
                self.list.append(nn.Conv2d(3, 32, kernel_size=3, padding=1))
                self.list.append(nn.Conv2d(32, 64, kernel_size=3, padding=1))
                self.list.append(nn.MaxPool2d(2, 2))
                self.pool = nn.MaxPool2d(2, 2)

        self.network = Net()
        shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                       self.network)
        # Children of the ModuleList are reported under the container's
        # attribute name ("list"), keyed by their index as a string.
        expected_shape = cnn_shape.get_conv2d_output_shape(
            self.input.shape, self.network.list[0])
        self.assertEqual(shape_dict["list"]["0"], expected_shape)
        # list[1] is Conv2d(32, 64): it consumes list[0]'s 32-channel output,
        # not the raw input, so chain from the previous expected shape.
        expected_shape = cnn_shape.get_conv2d_output_shape(
            expected_shape, self.network.list[1])
        self.assertEqual(shape_dict["list"]["1"], expected_shape)
        expected_shape = cnn_shape.get_maxpool2d_output_shape(
            expected_shape, self.network.list[2])
        self.assertEqual(shape_dict["list"]["2"], expected_shape)
        # The top-level pool comes after the whole ModuleList.
        expected_shape = cnn_shape.get_maxpool2d_output_shape(
            expected_shape, self.network.pool)
        self.assertEqual(shape_dict["pool"], expected_shape)
Exemplo n.º 2
0
 def test_returns_correct_shapes_for_all_layers(self):
     """Each entry of the module list reports its chained output shape.

     Each expected shape is computed from the expected output of the layer
     before it. Layer 2 was already checked this way; layer 1 now is too.
     """
     shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                    self.module_list)
     expected_shape = cnn_shape.get_conv2d_output_shape(
         self.input.shape, self.module_list[0])
     self.assertEqual(shape_dict["0"], expected_shape)
     # Layer 1 is fed by layer 0's output, not by the raw input tensor.
     expected_shape = cnn_shape.get_conv2d_output_shape(
         expected_shape, self.module_list[1])
     self.assertEqual(shape_dict["1"], expected_shape)
     expected_shape = cnn_shape.get_maxpool2d_output_shape(
         expected_shape, self.module_list[2])
     self.assertEqual(shape_dict["2"], expected_shape)
Exemplo n.º 3
0
 def test_returns_correct_shapes_for_all_layers(self):
     """Each listed layer's shape equals its helper applied to the prior shape."""
     shapes = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                self.network)
     # (attribute name, per-layer shape helper) in chaining order.
     layer_checks = (
         ("encoder_conv1", cnn_shape.get_conv2d_output_shape),
         ("encoder_pool1", cnn_shape.get_maxpool2d_output_shape),
         ("decoder_upsample1", cnn_shape.get_upsample_output_shape),
         ("decoder_tconv1", cnn_shape.get_conv_transpose2d_output_shape),
     )
     running_shape = self.input.shape
     for layer_name, shape_helper in layer_checks:
         layer = getattr(self.network, layer_name)
         running_shape = shape_helper(running_shape, layer)
         self.assertEqual(shapes[layer_name], running_shape)
Exemplo n.º 4
0
 def test_ignores_modules_not_being_layers(self):
     """Appending a Dropout module must not make shape computation raise.

     The test passes as long as the call returns; its result is not inspected.
     """
     dropout = nn.Dropout()
     self.module_list.append(dropout)
     cnn_shape.get_layer_output_shapes(self.input.shape, self.module_list)
Exemplo n.º 5
0
 def test_throws_when_unsupported_layer_type_encountered(self):
     """An AdaptiveAvgPool2d in the module list must trigger a KeyError."""
     unsupported = nn.AdaptiveAvgPool2d(16)
     self.module_list.append(unsupported)
     self.assertRaises(KeyError, cnn_shape.get_layer_output_shapes,
                       self.input.shape, self.module_list)
Exemplo n.º 6
0
 def test_result_dict_has_len_equal_to_layer_count(self):
     """The result holds exactly one entry per module in the list."""
     # Count before the call, matching the order of the original check.
     expected_entries = len(self.module_list)
     result = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                self.module_list)
     self.assertEqual(len(result), expected_entries)
Exemplo n.º 7
0
 def test_return_type(self):
     """get_layer_output_shapes returns a plain dict."""
     self.assertIsInstance(
         cnn_shape.get_layer_output_shapes(self.input.shape, self.module_list),
         dict)
Exemplo n.º 8
0
 def test_ignores_modules_not_being_layers(self):
     """Registering a Dropout submodule must not make shape computation raise.

     The test passes as long as the call returns; its result is not inspected.
     """
     dropout_layer = nn.Dropout()
     self.network.add_module("dropout", dropout_layer)
     cnn_shape.get_layer_output_shapes(self.input.shape, self.network)
Exemplo n.º 9
0
 def test_throws_when_unsupported_layer_type_encountered(self):
     """A registered AdaptiveAvgPool2d submodule must trigger a KeyError."""
     unsupported = nn.AdaptiveAvgPool2d(16)
     self.network.add_module("unsupported", unsupported)
     self.assertRaises(KeyError, cnn_shape.get_layer_output_shapes,
                       self.input.shape, self.network)
Exemplo n.º 10
0
 def test_result_dict_has_len_equal_to_layer_count(self):
     """One shape entry is produced per submodule of the network.

     named_modules() yields the network itself first, so that entry is
     subtracted from the count.
     """
     submodule_count = len(list(self.network.named_modules())) - 1
     result = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                self.network)
     self.assertEqual(len(result), submodule_count)