def test_nested_list(self):
    """Check shapes reported for layers nested inside an ``nn.ModuleList``.

    Builds a small network whose ``list`` attribute holds two convolutions
    and a max-pool, followed by a top-level ``pool`` layer. The result of
    ``get_layer_output_shapes`` is expected to be nested: entries of the
    module list are looked up under ``shape_dict["list"][<index as str>]``,
    the top-level layer under ``shape_dict["pool"]``.

    Each expected shape is derived from the previous layer's expected
    output, so the check still holds if a layer changes the spatial size.
    """

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.list = nn.ModuleList()
            self.list.append(nn.Conv2d(3, 32, kernel_size=3, padding=1))
            self.list.append(nn.Conv2d(32, 64, kernel_size=3, padding=1))
            self.list.append(nn.MaxPool2d(2, 2))
            self.pool = nn.MaxPool2d(2, 2)

    self.network = Net()
    shape_dict = cnn_shape.get_layer_output_shapes(
        self.input.shape, self.network)

    # First convolution sees the raw input.
    expected_shape = cnn_shape.get_conv2d_output_shape(
        self.input.shape, self.network.list[0])
    self.assertEqual(shape_dict["list"]["0"], expected_shape)

    # Second convolution consumes the first convolution's output, not the
    # raw input (its in_channels is 32, matching the first conv's output).
    expected_shape = cnn_shape.get_conv2d_output_shape(
        expected_shape, self.network.list[1])
    self.assertEqual(shape_dict["list"]["1"], expected_shape)

    expected_shape = cnn_shape.get_maxpool2d_output_shape(
        expected_shape, self.network.list[2])
    self.assertEqual(shape_dict["list"]["2"], expected_shape)

    # The top-level pool follows the last entry of the nested list.
    expected_shape = cnn_shape.get_maxpool2d_output_shape(
        expected_shape, self.network.pool)
    self.assertEqual(shape_dict["pool"], expected_shape)
def test_returns_correct_shapes_for_all_layers(self):
    """Check the shape reported for every entry of ``self.module_list``.

    The result of ``get_layer_output_shapes`` is indexed by the layer's
    position as a string (``"0"``, ``"1"``, ``"2"``). Entries 0 and 1 are
    checked with the conv2d helper and entry 2 with the max-pool helper.

    Each expected shape is derived from the previous layer's expected
    output rather than from the raw input.
    NOTE(review): assumes the layers are applied in index order, as the
    max-pool check already implies; ``self.module_list`` is built outside
    this method, so confirm against the fixture.
    """
    shape_dict = cnn_shape.get_layer_output_shapes(
        self.input.shape, self.module_list)

    # First layer sees the raw input.
    expected_shape = cnn_shape.get_conv2d_output_shape(
        self.input.shape, self.module_list[0])
    self.assertEqual(shape_dict["0"], expected_shape)

    # Second layer consumes the first layer's output, not the raw input.
    expected_shape = cnn_shape.get_conv2d_output_shape(
        expected_shape, self.module_list[1])
    self.assertEqual(shape_dict["1"], expected_shape)

    expected_shape = cnn_shape.get_maxpool2d_output_shape(
        expected_shape, self.module_list[2])
    self.assertEqual(shape_dict["2"], expected_shape)
def test_returns_correct_shapes_for_all_layers(self):
    """Compare each named layer's reported shape with its per-layer helper.

    Walks the four layers of ``self.network`` in order (conv, max-pool,
    upsample, transposed conv). The shape fed to each helper is the
    expected output of the previous layer, starting from
    ``self.input.shape``; the result must equal the entry stored under the
    layer's attribute name in the dict from ``get_layer_output_shapes``.
    """
    reported = cnn_shape.get_layer_output_shapes(
        self.input.shape, self.network)

    # (attribute name on the network, helper that computes its output shape)
    stages = (
        ("encoder_conv1", cnn_shape.get_conv2d_output_shape),
        ("encoder_pool1", cnn_shape.get_maxpool2d_output_shape),
        ("decoder_upsample1", cnn_shape.get_upsample_output_shape),
        ("decoder_tconv1", cnn_shape.get_conv_transpose2d_output_shape),
    )

    running_shape = self.input.shape
    for layer_name, shape_fn in stages:
        running_shape = shape_fn(
            running_shape, getattr(self.network, layer_name))
        self.assertEqual(reported[layer_name], running_shape)
def test_shape_for_complex_pool(self):
    """Max-pool shape helper matches PyTorch for per-axis parameters.

    Uses different kernel size, stride and padding for the two spatial
    axes and compares the helper's result with the shape of an actual
    pass of ``self.input`` through the layer.
    """
    layer = nn.MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=(1, 0))
    output_shape = cnn_shape.get_maxpool2d_output_shape(
        self.input.shape, layer)
    # Call the module itself rather than ``forward`` directly, as the
    # PyTorch documentation recommends (``__call__`` also runs hooks).
    expected_shape = layer(self.input).shape
    self.assertEqual(output_shape, expected_shape)
def test_shape_for_dilated_pool(self):
    """Max-pool shape helper matches PyTorch when dilation is used.

    Compares the helper's result for a dilated, strided, padded pool with
    the shape of an actual pass of ``self.input`` through the layer.
    """
    layer = nn.MaxPool2d(kernel_size=3, stride=2, dilation=2, padding=1)
    output_shape = cnn_shape.get_maxpool2d_output_shape(
        self.input.shape, layer)
    # Call the module itself rather than ``forward`` directly, as the
    # PyTorch documentation recommends (``__call__`` also runs hooks).
    expected_shape = layer(self.input).shape
    self.assertEqual(output_shape, expected_shape)
def test_shape_for_basic_pool(self):
    """Max-pool shape helper matches PyTorch for a plain 2x2, stride-2 pool.

    Compares the helper's result with the shape of an actual pass of
    ``self.input`` through the layer.
    """
    layer = nn.MaxPool2d(kernel_size=2, stride=2)
    output_shape = cnn_shape.get_maxpool2d_output_shape(
        self.input.shape, layer)
    # Call the module itself rather than ``forward`` directly, as the
    # PyTorch documentation recommends (``__call__`` also runs hooks).
    expected_shape = layer(self.input).shape
    self.assertEqual(output_shape, expected_shape)
def test_output_type(self):
    """The max-pool shape helper must return a ``torch.Size`` instance."""
    pool = nn.MaxPool2d(kernel_size=2, padding=1)
    result = cnn_shape.get_maxpool2d_output_shape(self.input.shape, pool)

    failure_note = f"Output should be of type {torch.Size}"
    self.assertIsInstance(result, torch.Size, failure_note)