Пример #1
0
    def test_nested_list(self):
        """Check that layers nested inside an ``nn.ModuleList`` get shapes.

        Nested modules are looked up by container attribute name and then by
        index (``shape_dict["list"]["0"]``); a direct child is looked up by
        its attribute name (``shape_dict["pool"]``). Expected shapes are
        threaded from layer to layer in registration order, so each layer's
        expectation is computed from the previous layer's output shape.
        """
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                # Attribute name "list" is also the key checked in shape_dict.
                self.list = nn.ModuleList()
                self.list.append(nn.Conv2d(3, 32, kernel_size=3, padding=1))
                self.list.append(nn.Conv2d(32, 64, kernel_size=3, padding=1))
                self.list.append(nn.MaxPool2d(2, 2))
                self.pool = nn.MaxPool2d(2, 2)

        self.network = Net()
        shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                       self.network)
        expected_shape = cnn_shape.get_conv2d_output_shape(
            self.input.shape, self.network.list[0])
        self.assertEqual(shape_dict["list"]["0"], expected_shape)
        # list[1] (32 -> 64 channels) consumes list[0]'s output, not the raw
        # input; chaining keeps the expectation correct even if list[0] were
        # ever changed to alter the spatial size.
        expected_shape = cnn_shape.get_conv2d_output_shape(
            expected_shape, self.network.list[1])
        self.assertEqual(shape_dict["list"]["1"], expected_shape)
        expected_shape = cnn_shape.get_maxpool2d_output_shape(
            expected_shape, self.network.list[2])
        self.assertEqual(shape_dict["list"]["2"], expected_shape)
        # The top-level pool runs after the whole ModuleList.
        expected_shape = cnn_shape.get_maxpool2d_output_shape(
            expected_shape, self.network.pool)
        self.assertEqual(shape_dict["pool"], expected_shape)
Пример #2
0
 def test_returns_correct_shapes_for_all_layers(self):
     """Check the per-index shapes reported for ``self.module_list``.

     Layers are keyed by their string index ("0", "1", "2"). Expected shapes
     are threaded from each layer's output into the next, matching how layer
     "2" is already checked against layer "1"'s output.
     NOTE(review): assumes ``self.module_list`` is a sequential stack
     (conv, conv, maxpool) set up elsewhere, e.g. in ``setUp``.
     """
     shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                    self.module_list)
     expected_shape = cnn_shape.get_conv2d_output_shape(
         self.input.shape, self.module_list[0])
     self.assertEqual(shape_dict["0"], expected_shape)
     # Layer 1 consumes layer 0's output, not the raw input shape.
     expected_shape = cnn_shape.get_conv2d_output_shape(
         expected_shape, self.module_list[1])
     self.assertEqual(shape_dict["1"], expected_shape)
     expected_shape = cnn_shape.get_maxpool2d_output_shape(
         expected_shape, self.module_list[2])
     self.assertEqual(shape_dict["2"], expected_shape)
Пример #3
0
 def test_returns_correct_shapes_for_all_layers(self):
     """Verify the reported shape of each named encoder/decoder layer.

     The expected shape is threaded through the per-layer helper functions
     in order, so each layer is checked against the output of the one before.
     """
     network = self.network
     shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape, network)
     # (attribute name on the network, helper computing that layer's output)
     layer_checks = (
         ("encoder_conv1", cnn_shape.get_conv2d_output_shape),
         ("encoder_pool1", cnn_shape.get_maxpool2d_output_shape),
         ("decoder_upsample1", cnn_shape.get_upsample_output_shape),
         ("decoder_tconv1", cnn_shape.get_conv_transpose2d_output_shape),
     )
     running_shape = self.input.shape
     for layer_name, shape_fn in layer_checks:
         running_shape = shape_fn(running_shape, getattr(network, layer_name))
         self.assertEqual(shape_dict[layer_name], running_shape)
Пример #4
0
 def test_shape_for_complex_pool(self):
     """Compare the computed max-pool shape with a real forward pass.

     Uses a pool with a non-square kernel/stride and asymmetric padding.
     """
     layer = nn.MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=(1, 0))
     output_shape = cnn_shape.get_maxpool2d_output_shape(
         self.input.shape, layer)
     # Call the module itself rather than .forward() so registered hooks run.
     expected_shape = layer(self.input).shape
     self.assertEqual(output_shape, expected_shape)
Пример #5
0
 def test_shape_for_dilated_pool(self):
     """Compare the computed max-pool shape with a real forward pass.

     Uses a pool with dilation, stride and padding all set.
     """
     layer = nn.MaxPool2d(kernel_size=3, stride=2, dilation=2, padding=1)
     output_shape = cnn_shape.get_maxpool2d_output_shape(
         self.input.shape, layer)
     # Call the module itself rather than .forward() so registered hooks run.
     expected_shape = layer(self.input).shape
     self.assertEqual(output_shape, expected_shape)
Пример #6
0
 def test_shape_for_basic_pool(self):
     """Compare the computed max-pool shape with a real forward pass.

     Uses a plain 2x2 pool with stride 2 and no padding.
     """
     layer = nn.MaxPool2d(kernel_size=2, stride=2)
     output_shape = cnn_shape.get_maxpool2d_output_shape(
         self.input.shape, layer)
     # Call the module itself rather than .forward() so registered hooks run.
     expected_shape = layer(self.input).shape
     self.assertEqual(output_shape, expected_shape)
Пример #7
0
 def test_output_type(self):
     """The max-pool shape helper must hand back a ``torch.Size``."""
     pool = nn.MaxPool2d(kernel_size=2, padding=1)
     result = cnn_shape.get_maxpool2d_output_shape(self.input.shape, pool)
     failure_msg = f"Output should be of type {torch.Size}"
     self.assertIsInstance(result, torch.Size, failure_msg)