Пример #1
0
    def test_nested_list(self):
        """Layers nested inside an ``nn.ModuleList`` get per-index output shapes.

        ``get_layer_output_shapes`` reports nested containers as a nested dict:
        the outer key is the container's attribute name (``"list"``) and the
        inner key is the child's index as a string. Expected shapes are built
        by chaining each layer from the previous layer's expected output, in
        registration order.
        """
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                # The attribute name "list" becomes the outer key in shape_dict.
                self.list = nn.ModuleList()
                # Assumes self.input has 3 channels — TODO confirm in setUp.
                self.list.append(nn.Conv2d(3, 32, kernel_size=3, padding=1))
                self.list.append(nn.Conv2d(32, 64, kernel_size=3, padding=1))
                self.list.append(nn.MaxPool2d(2, 2))
                # Registered after the list, so it consumes the list's output.
                self.pool = nn.MaxPool2d(2, 2)

        self.network = Net()
        shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                       self.network)
        layers = self.network.list

        expected_shape = cnn_shape.get_conv2d_output_shape(
            self.input.shape, layers[0])
        self.assertEqual(shape_dict["list"]["0"], expected_shape)
        # The second conv (32 -> 64 channels) consumes the first conv's output,
        # so chain from it rather than restarting from the raw input shape.
        expected_shape = cnn_shape.get_conv2d_output_shape(
            expected_shape, layers[1])
        self.assertEqual(shape_dict["list"]["1"], expected_shape)
        expected_shape = cnn_shape.get_maxpool2d_output_shape(
            expected_shape, layers[2])
        self.assertEqual(shape_dict["list"]["2"], expected_shape)
        expected_shape = cnn_shape.get_maxpool2d_output_shape(
            expected_shape, self.network.pool)
        self.assertEqual(shape_dict["pool"], expected_shape)
Пример #2
0
 def test_returns_correct_shapes_for_all_layers(self):
     """Every layer in ``self.module_list`` is reported with its chained output shape.

     Each expected shape is computed by feeding the previous layer's expected
     output into the matching ``cnn_shape`` helper. This mirrors applying the
     layers one after another.
     """
     # Assumes self.module_list is [Conv2d, Conv2d, MaxPool2d], per the helpers
     # used below — the list is built in setUp, which is not visible here.
     shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                    self.module_list)
     # Keys are the layers' indices within the module list, as strings.
     expected_shape = cnn_shape.get_conv2d_output_shape(
         self.input.shape, self.module_list[0])
     self.assertEqual(shape_dict["0"], expected_shape)
     # Chain from layer 0's output instead of restarting from the raw input.
     expected_shape = cnn_shape.get_conv2d_output_shape(
         expected_shape, self.module_list[1])
     self.assertEqual(shape_dict["1"], expected_shape)
     expected_shape = cnn_shape.get_maxpool2d_output_shape(
         expected_shape, self.module_list[2])
     self.assertEqual(shape_dict["2"], expected_shape)
Пример #3
0
 def test_shape_for_downsampling_conv(self):
     """Predicted shape matches the real output of an unpadded 5x5 convolution.

     With ``padding=0`` the spatial size shrinks, so this exercises the
     non-size-preserving branch of the shape formula.
     """
     layer = nn.Conv2d(self.input_channels,
                       out_channels=8,
                       kernel_size=5,
                       padding=0)
     output_shape = cnn_shape.get_conv2d_output_shape(
         self.input.shape, layer)
     # Call the module itself (not .forward) so registered hooks run as in real use.
     expected_shape = layer(self.input).shape
     self.assertEqual(output_shape, expected_shape)
Пример #4
0
 def test_shape_for_size_preserving_conv(self):
     """Predicted shape matches the real output of a 3x3 conv with padding 1.

     Only the channel count should change for this configuration. The
     prediction is compared against the shape of an actual forward pass.
     """
     layer = nn.Conv2d(self.input_channels,
                       out_channels=4,
                       kernel_size=3,
                       padding=1)
     output_shape = cnn_shape.get_conv2d_output_shape(
         self.input.shape, layer)
     # Call the module itself (not .forward) so registered hooks run as in real use.
     expected_shape = layer(self.input).shape
     self.assertEqual(output_shape, expected_shape)
Пример #5
0
 def test_output_type(self):
     """``get_conv2d_output_shape`` returns a ``torch.Size`` instance."""
     conv = nn.Conv2d(self.input_channels,
                      out_channels=2,
                      kernel_size=3,
                      padding=1)
     expected_type = torch.Size
     result = cnn_shape.get_conv2d_output_shape(self.input.shape, conv)
     failure_message = f"Output should be of type {expected_type}"
     self.assertIsInstance(result, expected_type, failure_message)
Пример #6
0
 def test_shape_for_complex_conv(self):
     """Predicted shape matches reality when padding, stride and dilation differ per axis.

     Each of ``padding``, ``stride`` and ``dilation`` is given as a different
     ``(height, width)`` pair. This checks that the shape helper handles each
     spatial axis independently.
     """
     layer = nn.Conv2d(self.input_channels,
                       out_channels=4,
                       kernel_size=3,
                       padding=(1, 0),
                       stride=(1, 2),
                       dilation=(1, 2))
     output_shape = cnn_shape.get_conv2d_output_shape(
         self.input.shape, layer)
     # Call the module itself (not .forward) so registered hooks run as in real use.
     expected_shape = layer(self.input).shape
     self.assertEqual(output_shape, expected_shape)
Пример #7
0
 def test_returns_correct_shapes_for_all_layers(self):
     """Each named layer of ``self.network`` is reported with its chained output shape.

     The layers are walked in order. Each expected shape is derived from the
     previous one with the shape helper that matches that layer's type.
     """
     shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                    self.network)
     # (attribute name on self.network, helper that predicts its output shape)
     pipeline = (
         ("encoder_conv1", cnn_shape.get_conv2d_output_shape),
         ("encoder_pool1", cnn_shape.get_maxpool2d_output_shape),
         ("decoder_upsample1", cnn_shape.get_upsample_output_shape),
         ("decoder_tconv1", cnn_shape.get_conv_transpose2d_output_shape),
     )
     running_shape = self.input.shape
     for layer_name, predict in pipeline:
         running_shape = predict(running_shape,
                                 getattr(self.network, layer_name))
         self.assertEqual(shape_dict[layer_name], running_shape)