def test_nested_list(self):
    """Check shapes reported for a network that nests an ``nn.ModuleList``.

    The result is expected to mirror the module hierarchy: entries of the
    nested list are looked up as ``shape_dict["list"][<index as str>]`` and
    the sibling pooling layer as ``shape_dict["pool"]``.
    """

    class Net(nn.Module):
        """Two convolutions and a pooling layer in a list, plus one pool."""

        def __init__(self):
            super().__init__()
            self.list = nn.ModuleList()
            self.list.append(nn.Conv2d(3, 32, kernel_size=3, padding=1))
            self.list.append(nn.Conv2d(32, 64, kernel_size=3, padding=1))
            self.list.append(nn.MaxPool2d(2, 2))
            self.pool = nn.MaxPool2d(2, 2)

    self.network = Net()
    shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                   self.network)

    expected_shape = cnn_shape.get_conv2d_output_shape(
        self.input.shape, self.network.list[0])
    self.assertEqual(shape_dict["list"]["0"], expected_shape)

    # The second convolution consumes the output of the first one, so its
    # expected shape is derived from the previous expected shape rather
    # than from the raw network input (whose channel count does not match
    # this layer's in_channels).
    expected_shape = cnn_shape.get_conv2d_output_shape(
        expected_shape, self.network.list[1])
    self.assertEqual(shape_dict["list"]["1"], expected_shape)

    expected_shape = cnn_shape.get_maxpool2d_output_shape(
        expected_shape, self.network.list[2])
    self.assertEqual(shape_dict["list"]["2"], expected_shape)

    expected_shape = cnn_shape.get_maxpool2d_output_shape(
        expected_shape, self.network.pool)
    self.assertEqual(shape_dict["pool"], expected_shape)
def test_returns_correct_shapes_for_all_layers(self):
    """Check the shape reported for every entry of ``self.module_list``.

    Results are looked up by the layer's index rendered as a string. Each
    expected shape is computed with the matching per-layer helper, feeding
    it the expected output of the preceding layer.
    """
    shape_dict = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                   self.module_list)

    expected_shape = cnn_shape.get_conv2d_output_shape(
        self.input.shape, self.module_list[0])
    self.assertEqual(shape_dict["0"], expected_shape)

    # Layer 1 receives layer 0's output, not the raw input, so chain from
    # the previous expected shape (the pooling step below already does).
    expected_shape = cnn_shape.get_conv2d_output_shape(
        expected_shape, self.module_list[1])
    self.assertEqual(shape_dict["1"], expected_shape)

    expected_shape = cnn_shape.get_maxpool2d_output_shape(
        expected_shape, self.module_list[2])
    self.assertEqual(shape_dict["2"], expected_shape)
def test_returns_correct_shapes_for_all_layers(self):
    """Check the shape reported for each named layer of ``self.network``.

    The layers are walked in order; the expected output of one layer is
    used as the input shape when computing the expectation for the next.
    """
    reported = cnn_shape.get_layer_output_shapes(self.input.shape,
                                                 self.network)

    # (attribute name on the network, helper computing that layer's output)
    stages = (
        ("encoder_conv1", cnn_shape.get_conv2d_output_shape),
        ("encoder_pool1", cnn_shape.get_maxpool2d_output_shape),
        ("decoder_upsample1", cnn_shape.get_upsample_output_shape),
        ("decoder_tconv1", cnn_shape.get_conv_transpose2d_output_shape),
    )

    current_shape = self.input.shape
    for layer_name, compute_shape in stages:
        layer = getattr(self.network, layer_name)
        current_shape = compute_shape(current_shape, layer)
        self.assertEqual(reported[layer_name], current_shape)
def test_ignores_modules_not_being_layers(self):
    """Appending an ``nn.Dropout`` to the module list must not raise."""
    dropout = nn.Dropout()
    self.module_list.append(dropout)

    # No assertion: the test passes as long as the call completes.
    cnn_shape.get_layer_output_shapes(
        self.input.shape,
        self.module_list,
    )
def test_throws_when_unsupported_layer_type_encountered(self):
    """An ``nn.AdaptiveAvgPool2d`` in the module list raises ``KeyError``."""
    unsupported_layer = nn.AdaptiveAvgPool2d(16)
    self.module_list.append(unsupported_layer)

    self.assertRaises(
        KeyError,
        cnn_shape.get_layer_output_shapes,
        self.input.shape,
        self.module_list,
    )
def test_result_dict_has_len_equal_to_layer_count(self):
    """The result holds exactly one entry per item of ``self.module_list``."""
    # Taken before the call so the count reflects the list as passed in.
    expected_entries = len(self.module_list)

    result = cnn_shape.get_layer_output_shapes(
        self.input.shape,
        self.module_list,
    )

    self.assertEqual(len(result), expected_entries)
def test_return_type(self):
    """``get_layer_output_shapes`` returns a ``dict``."""
    self.assertIsInstance(
        cnn_shape.get_layer_output_shapes(self.input.shape,
                                          self.module_list),
        dict,
    )
def test_ignores_modules_not_being_layers(self):
    """Registering an ``nn.Dropout`` on the network must not raise."""
    dropout = nn.Dropout()
    self.network.add_module("dropout", dropout)

    # No assertion: the test passes as long as the call completes.
    cnn_shape.get_layer_output_shapes(
        self.input.shape,
        self.network,
    )
def test_throws_when_unsupported_layer_type_encountered(self):
    """An ``nn.AdaptiveAvgPool2d`` on the network raises ``KeyError``."""
    unsupported_layer = nn.AdaptiveAvgPool2d(16)
    self.network.add_module("unsupported", unsupported_layer)

    self.assertRaises(
        KeyError,
        cnn_shape.get_layer_output_shapes,
        self.input.shape,
        self.network,
    )
def test_result_dict_has_len_equal_to_layer_count(self):
    """The result has one entry per module of the network, root excluded."""
    # The first item yielded by named_modules() is dropped from the count;
    # only the remaining entries are counted as layers.
    _first, *remaining_modules = self.network.named_modules()
    expected_entries = len(remaining_modules)

    result = cnn_shape.get_layer_output_shapes(
        self.input.shape,
        self.network,
    )

    self.assertEqual(len(result), expected_entries)