# Example 1
    def test_prune_layer(self):
        """Prune conv2 of the MNIST model at comp ratio 0.5 and verify the
        pruned layer's channel counts (32 * 0.5 = 16 input channels)."""
        orig_model = mnist_torch_model.Net()
        orig_model.eval()
        # Create a layer database
        orig_layer_db = LayerDatabase(orig_model, input_shape=(1, 1, 28, 28))
        # Copy the db so the original stays untouched for comparison
        comp_layer_db = copy.deepcopy(orig_layer_db)

        dataset_size = 100
        batch_size = 10
        # max out number of batches
        number_of_batches = 10
        samples_per_image = 10
        num_reconstruction_samples = number_of_batches * batch_size * samples_per_image
        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size)

        input_channel_pruner = InputChannelPruner(
            data_loader=data_loader,
            input_shape=(1, 1, 28, 28),
            num_reconstruction_samples=num_reconstruction_samples,
            allow_custom_downsample_ops=True)

        conv2 = comp_layer_db.find_layer_by_name('conv2')
        input_channel_pruner._prune_layer(orig_layer_db, comp_layer_db, conv2,
                                          0.5, CostMetric.mac)

        # BUG FIX: the original used assertTrue(a, b); the second argument of
        # assertTrue is a failure *message*, so only truthiness of the first
        # argument was checked and the intended value comparison never ran.
        self.assertEqual(comp_layer_db.model.conv2.in_channels, 16)
        self.assertEqual(comp_layer_db.model.conv2.out_channels, 64)
# Example 2
    def test_get_quantized_weight(self):
        """Quantize-dequantize conv2's weights via QuantizationSimModel and
        verify the resulting tensor keeps the original weight shape."""
        model = mnist_model.Net()

        use_cuda = False
        params = qsim.QuantParams(weight_bw=4,
                                  act_bw=4,
                                  round_mode="nearest",
                                  quant_scheme=QuantScheme.post_training_tf)
        data_loader = create_fake_data_loader(dataset_size=2, batch_size=1)

        def forward_fn(model, early_stopping_iterations=None, use_cuda=False):
            # Push every batch of the fake loader through the model so the
            # quantizers can observe activation ranges.
            for images, _ in data_loader:
                model(images)

        quantsim = qsim.QuantizationSimModel(model=model,
                                             quant_scheme=params.quant_scheme,
                                             rounding_mode=params.round_mode,
                                             default_output_bw=params.act_bw,
                                             default_param_bw=params.weight_bw,
                                             in_place=False,
                                             dummy_input=torch.rand(1, 1, 28, 28))
        quantsim.compute_encodings(forward_fn, None)

        quant_dequant_weights = bias_correction.get_quantized_dequantized_weight(
            quantsim.model.conv2, use_cuda)
        # quantize-dequantize must not change the weight tensor's shape
        self.assertEqual(quant_dequant_weights.shape,
                         torch.Size([64, 32, 5, 5]))
def bias_correction_analytical_and_empirical():
    """Run hybrid bias correction on MobileNetV2: analytical correction for
    Conv+BN(+activation) pairs, empirical correction for the remaining convs,
    driven by a fake ImageNet-sized data loader."""
    loader = create_fake_data_loader(dataset_size=2000,
                                     batch_size=64,
                                     image_size=(3, 224, 224))

    net = MobileNetV2()
    net.eval()

    # Convs preceded by BN (+ activation) get analytical correction; every
    # conv not in this dict falls back to empirical correction.
    conv_bn_pairs = bias_correction.find_all_conv_bn_with_activation(
        net, input_shape=(1, 3, 224, 224))

    quant_params = QuantParams(weight_bw=4,
                               act_bw=4,
                               round_mode="nearest",
                               quant_scheme='tf_enhanced')

    # Perform bias correction on the CUDA copy of the model
    bias_correction.correct_bias(net.to(device="cuda"),
                                 quant_params,
                                 num_quant_samples=1000,
                                 data_loader=loader,
                                 num_bias_correct_samples=512,
                                 conv_bn_dict=conv_bn_pairs,
                                 perform_only_empirical_bias_corr=False)
# Example 4
    def test_get_output_of_layer(self):
        """Compare get_output_data() for conv2 against a manual forward pass.

        dataset_size == batch_size == 2, so the loader yields exactly one
        batch and conv2_output_data from the first loop covers all the data.
        """
        model = TestNet()
        dataset_size = 2
        batch_size = 2
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size)
        # Collect conv2 activations via the utility under test (single batch)
        for images_in_one_batch, _ in data_loader:
            conv2_output_data = bias_correction.get_output_data(
                model.conv2, model, images_in_one_batch)

        # max out number of batches
        number_of_batches = 1
        iterator = data_loader.__iter__()

        for batch in range(number_of_batches):

            images_in_one_batch, _ = iterator.__next__()
            # Manually replay the path up to conv2's output — presumably
            # mirrors TestNet.forward; the comparison below verifies it
            conv1_output = model.conv1(images_in_one_batch)
            conv2_input = conv1_output
            conv2_output = model.conv2(
                functional.relu(functional.max_pool2d(conv2_input, 2)))
            # compare the output from conv2 layer, batch-sized slice at a time
            self.assertTrue(
                np.allclose(
                    to_numpy(conv2_output),
                    np.asarray(conv2_output_data)[batch *
                                                  batch_size:(batch + 1) *
                                                  batch_size, :, :, :]))
# Example 5
    def test_bias_correction_empirical_with_config_file(self):
        """Empirical bias correction on the MNIST net with the default
        (config_file=None) quantsim configuration; the mocked empirical
        correction must be invoked four times and conv1's bias stay intact."""
        # Using a dummy extension of MNIST
        torch.manual_seed(10)
        model = mnist_model.Net().eval()

        reference_model = copy.deepcopy(model)

        data_loader = create_fake_data_loader(dataset_size=2,
                                              batch_size=1,
                                              image_size=(1, 28, 28))

        # config_file=None -> the default config file is used
        params = qsim.QuantParams(weight_bw=4,
                                  act_bw=4,
                                  round_mode="nearest",
                                  quant_scheme=QuantScheme.post_training_tf,
                                  config_file=None)

        patch_target = 'aimet_torch.bias_correction.call_empirical_mo_correct_bias'
        with unittest.mock.patch(patch_target) as empirical_mock:
            bias_correction.correct_bias(model, params, 2, data_loader, 2)

        # empirical correction routed to four layers
        self.assertEqual(empirical_mock.call_count, 4)
        # conv1's bias must be unchanged relative to the pre-correction copy
        self.assertTrue(
            np.allclose(model.conv1.bias.detach().cpu().numpy(),
                        reference_model.conv1.bias.detach().cpu().numpy()))

        self.assertTrue(model.conv2.bias.detach().cpu().numpy() is not None)
        self.assertTrue(model.fc1.bias.detach().cpu().numpy() is not None)
# Example 6
    def test_data_sub_sampling_and_reconstruction_without_bias(self):
        """Test end to end data sub sampling and reconstruction for MNIST conv2 layer (without bias).

        With conv2's bias removed, reconstruction must leave the bias None
        and produce weights close to the originals (the two models are
        identical deep copies, so it should approximately round-trip).
        """
        orig_model = mnist_model()
        # set bias to None
        orig_model.conv2.bias = None

        comp_model = copy.deepcopy(orig_model)

        dataset_size = 100
        batch_size = 10
        # max out number of batches
        number_of_batches = 10
        samples_per_image = 10
        num_reconstruction_samples = number_of_batches * batch_size * samples_per_image
        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size)

        conv2_pr_layer_name = get_layer_name(comp_model, comp_model.conv2)

        # Collect (input, output) activation samples for conv2 from both models
        sampled_inp_data, sampled_out_data = DataSubSampler.get_sub_sampled_data(
            orig_layer=orig_model.conv2,
            pruned_layer=comp_model.conv2,
            orig_model=orig_model,
            comp_model=comp_model,
            data_loader=data_loader,
            num_reconstruction_samples=num_reconstruction_samples)

        conv_layer = get_layer_by_name(model=comp_model,
                                       layer_name=conv2_pr_layer_name)

        # name-based lookup must resolve to the very same module object
        assert conv_layer == comp_model.conv2
        # original weight before reconstruction
        orig_weight = conv_layer.weight.data
        WeightReconstructor.reconstruct_params_for_conv2d(
            layer=conv_layer,
            input_data=sampled_inp_data,
            output_data=sampled_out_data)
        # new weight after reconstruction
        new_weight = conv_layer.weight.data
        new_bias = conv_layer.bias

        self.assertEqual(new_weight.shape, orig_weight.shape)
        # bias was removed above, so reconstruction must not invent one
        self.assertEqual(new_bias, None)
        # if you increase the data (data set size, number of batches or samples per image),
        # reduce the absolute tolerance
        self.assertTrue(
            np.allclose(to_numpy(new_weight), to_numpy(orig_weight),
                        atol=1e-0))
# Example 7
    def test_subsampled_input_data(self, np_choice_function):
        """ Test to collect activations (input from model_copy and output from model for conv2 layer) and compare
            with sub sampled input data

            np_choice_function is a mock injected by a patch decorator outside
            this view — presumably patching numpy.random.choice so the sampled
            (height, width) locations are deterministic (TODO confirm target).
        """
        # hardcoded mocked 10 sample locations
        # (0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (6, 6), (5, 5)
        np_choice_function.return_value = heights = widths = [
            0, 1, 2, 3, 4, 5, 6, 7, 6, 5
        ]

        orig_model = TestNet()
        comp_model = copy.deepcopy(orig_model)
        # only one image and from that 10 samples
        dataset_size = 1
        batch_size = 1
        num_reconstruction_samples = 10

        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size,
                                              image_size=(1, 28, 28))

        conv2_input_data, _ = DataSubSampler.get_sub_sampled_data(
            orig_layer=orig_model.conv2,
            pruned_layer=comp_model.conv2,
            orig_model=orig_model,
            comp_model=comp_model,
            data_loader=data_loader,
            num_reconstruction_samples=num_reconstruction_samples)

        # collect the input data of conv2 from compressed model using same data loader

        iterator = data_loader.__iter__()
        images_in_one_batch, _ = iterator.__next__()
        conv1_output = comp_model.conv1(images_in_one_batch)
        conv2_input = functional.relu(functional.max_pool2d(conv1_output, 2))

        kernel_size_h, kernel_size_w = comp_model.conv2.kernel_size

        # each collected sub-sample must equal the kernel-sized patch of
        # conv2's input at the corresponding mocked (height, width) location
        for sample in range(num_reconstruction_samples):

            self.assertTrue(
                np.array_equal(
                    conv2_input_data[sample, :, :, :],
                    conv2_input[0, :, heights[sample]:heights[sample] +
                                kernel_size_h, widths[sample]:widths[sample] +
                                kernel_size_w].detach().cpu().numpy()))
# Example 8
    def test_get_activation_data(self):
        """ Test to collect activations (input from model_copy and output from model for conv2 layer) and compare
        """
        # NOTE(review): model is moved to GPU — this test requires CUDA
        orig_model = TestNet().cuda()
        comp_model = copy.deepcopy(orig_model)
        dataset_size = 1000
        batch_size = 10
        # max out number of batches
        number_of_batches = 100

        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size)

        # Inputs are taken from comp_model, outputs from orig_model
        conv2_input_data, conv2_output_data = DataSubSampler.get_sub_sampled_data(
            orig_layer=orig_model.conv2,
            pruned_layer=comp_model.conv2,
            orig_model=orig_model,
            comp_model=comp_model,
            data_loader=data_loader,
            num_reconstruction_samples=number_of_batches)
        iterator = data_loader.__iter__()

        # Replay every batch manually and compare slice-by-slice against the
        # collected activations (each batch occupies batch_size rows)
        for batch in range(number_of_batches):

            images_in_one_batch, _ = iterator.__next__()
            conv1_output = orig_model.conv1(images_in_one_batch.cuda())
            conv2_input = conv1_output
            conv2_output = orig_model.conv2(
                functional.relu(functional.max_pool2d(conv2_input, 2)))
            # compare the output from conv2 layer
            self.assertTrue(
                np.array_equal(
                    to_numpy(conv2_output),
                    conv2_output_data[batch * batch_size:(batch + 1) *
                                      batch_size, :, :, :]))

            conv1_output_copy = comp_model.conv1(images_in_one_batch.cuda())
            conv2_input_copy = functional.relu(
                functional.max_pool2d(conv1_output_copy, 2))
            # compare the inputs of conv2 layer
            self.assertTrue(
                np.array_equal(
                    to_numpy(conv2_input_copy),
                    conv2_input_data[batch * batch_size:(batch + 1) *
                                     batch_size, :, :, :]))
# Example 9
    def test_data_sub_sampling_and_reconstruction(self):
        """Test end to end data sub sampling and reconstruction for MNIST conv2 layer"""
        orig_model = mnist_model()
        comp_model = copy.deepcopy(orig_model)

        batch_size = 10
        # 10 batches x 10 samples per image maxes out the 100-image dataset
        num_reconstruction_samples = 10 * batch_size * 10

        # fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=100,
                                              batch_size=batch_size)

        pruner = InputChannelPruner(
            data_loader=data_loader,
            input_shape=None,
            num_reconstruction_samples=num_reconstruction_samples,
            allow_custom_downsample_ops=True)

        pruner._data_subsample_and_reconstruction(orig_layer=orig_model.conv2,
                                                  pruned_layer=comp_model.conv2,
                                                  orig_model=orig_model,
                                                  comp_model=comp_model)

        # reconstruction must not change parameter shapes
        self.assertEqual(comp_model.conv2.weight.data.shape,
                         orig_model.conv2.weight.data.shape)
        self.assertEqual(comp_model.conv2.bias.data.shape,
                         orig_model.conv2.bias.data.shape)

        # if you increase the data (data set size, number of batches or
        # samples per image), reduce the absolute tolerance
        self.assertTrue(
            np.allclose(to_numpy(comp_model.conv2.weight.data),
                        to_numpy(orig_model.conv2.weight.data),
                        atol=1e-0))
        self.assertTrue(
            np.allclose(to_numpy(comp_model.conv2.bias.data),
                        to_numpy(orig_model.conv2.bias.data),
                        atol=1e-0))
# Example 10
    def test_subsampled_input_data_fc(self):
        """ Test to collect activations (input from model_copy and output from model for fc1 layer) and compare
            with sub sampled output data
        """
        orig_model = TestNet()
        comp_model = copy.deepcopy(orig_model)
        # 100 images in batches of 10; 5000 requested reconstruction samples
        # (the old "only one image" comment was stale copy-paste)
        dataset_size = 100
        batch_size = 10
        num_reconstruction_samples = 5000

        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size,
                                              image_size=(1, 28, 28))

        fc1_input_data, _ = DataSubSampler.get_sub_sampled_data(
            orig_layer=orig_model.fc1,
            pruned_layer=comp_model.fc1,
            orig_model=orig_model,
            comp_model=comp_model,
            data_loader=data_loader,
            num_reconstruction_samples=num_reconstruction_samples)

        # total collected samples must exceed the requested amount
        self.assertTrue(fc1_input_data.shape[0] *
                        fc1_input_data.shape[1] > num_reconstruction_samples)

        # collect the output data of fc1 from original model using same data loader
        iterator = data_loader.__iter__()
        images_in_one_batch, _ = iterator.__next__()

        # Manually replay the path up to fc1's input: conv1 -> conv2 ->
        # pool/relu -> flatten (presumably mirrors TestNet.forward)
        conv1_output = orig_model.conv1(images_in_one_batch)
        conv2_input = conv1_output
        conv2_output = orig_model.conv2(
            functional.relu(functional.max_pool2d(conv2_input, 2)))
        fc1_input = conv2_output
        fc1_input = functional.relu(functional.max_pool2d(fc1_input, 2))
        fc1_input = fc1_input.view(fc1_input.size(0),
                                   -1).detach().cpu().numpy()

        # compare data of first batch only (first 10 rows)
        self.assertTrue(np.array_equal(fc1_input_data[0:10], fc1_input))
# Example 11
    def test_bias_correction_analytical_and_empirical_ignore_layer(self):
        """Hybrid bias correction on MockMobileNetV1 with one layer ignored.

        Both correction paths are mocked, so the test only verifies how many
        layers get routed to analytical vs. empirical correction when the
        first conv is placed on the ignore list.
        """
        torch.manual_seed(10)
        model = MockMobileNetV1()
        model = model.eval()
        dataset_size = 2
        batch_size = 1

        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size,
                                              image_size=(3, 224, 224))
        params = qsim.QuantParams(weight_bw=8,
                                  act_bw=8,
                                  round_mode="nearest",
                                  quant_scheme=QuantScheme.post_training_tf)
        conv_bn_dict = find_all_conv_bn_with_activation(model,
                                                        input_shape=(1, 3, 224,
                                                                     224))

        # ignore the very first conv of the network
        layer = model.model[0][0]
        layers_to_ignore = [layer]

        with unittest.mock.patch(
                'aimet_torch.bias_correction.call_analytical_mo_correct_bias'
        ) as analytical_mock:
            with unittest.mock.patch(
                    'aimet_torch.bias_correction.call_empirical_mo_correct_bias'
            ) as empirical_mock:
                bias_correction.correct_bias(
                    model,
                    params,
                    2,
                    data_loader,
                    2,
                    conv_bn_dict,
                    perform_only_empirical_bias_corr=False,
                    layers_to_ignore=layers_to_ignore)

        self.assertEqual(analytical_mock.call_count, 8)  # one layer ignored
        self.assertEqual(empirical_mock.call_count, 9)
        # a corrected layer still exposes a bias tensor after correction
        self.assertTrue(
            model.model[1][0].bias.detach().cpu().numpy() is not None)
# Example 12
    def test_cached_dataset(self):
        """ Test cache data loader splitting into train and validation """
        dataset_size = 256
        batch_size = 16

        # create fake data loader with image size (1, 2, 2)
        data_loader = utils.create_fake_data_loader(dataset_size=dataset_size,
                                                    batch_size=batch_size,
                                                    image_size=(1, 2, 2))
        num_batches = 6
        path = '/tmp/test_cached_dataset/'
        try:
            cached_dataset = utils.CachedDataset(data_loader, num_batches, path)
            self.assertEqual(len(cached_dataset), num_batches)

            # Try creating cached data loader by more than possible batches
            # from data loader and expect ValueError
            possible_batches = dataset_size // batch_size
            with pytest.raises(ValueError):
                utils.CachedDataset(data_loader, possible_batches + 1, path)
        finally:
            # Always clean up the cache dir — the original only removed it on
            # success (leaking /tmp state across failed runs) and duplicated
            # the hard-coded path instead of reusing `path`.
            shutil.rmtree(path, ignore_errors=True)
# Example 13
    def test_data_sub_sample_and_reconstruction_with_zero_channels(self):
        """Test end to end data sub sampling and reconstruction for MNIST
        conv2 layer: zero out some input channels, reconstruct, and verify
        the weights actually changed while shapes are preserved."""
        orig_model = mnist_model()
        comp_model = copy.deepcopy(orig_model)

        dataset_size = 100
        batch_size = 10
        # max out number of batches
        number_of_batches = 10
        samples_per_image = 10
        num_reconstruction_samples = number_of_batches * batch_size * samples_per_image
        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size)

        cp = InputChannelPruner(
            data_loader=data_loader,
            input_shape=None,
            num_reconstruction_samples=num_reconstruction_samples,
            allow_custom_downsample_ops=True)

        input_channels_to_prune = [0, 1, 2, 3, 4, 5, 15, 29, 24, 28]
        zero_out_input_channels(comp_model.conv2, input_channels_to_prune)

        # BUG FIX: snapshot must be cloned. Without clone() this was only a
        # reference to the live weight tensor, so an in-place update during
        # reconstruction would make the final "not allclose" check compare
        # the tensor against itself and pass/fail vacuously.
        before_reconstruction = comp_model.conv2.weight.data.clone()

        cp._data_subsample_and_reconstruction(orig_layer=orig_model.conv2,
                                              pruned_layer=comp_model.conv2,
                                              orig_model=orig_model,
                                              comp_model=comp_model)

        after_reconstruction = comp_model.conv2.weight.data

        # shapes must survive reconstruction
        self.assertEqual(comp_model.conv2.weight.data.shape,
                         orig_model.conv2.weight.data.shape)
        self.assertEqual(comp_model.conv2.bias.data.shape,
                         orig_model.conv2.bias.data.shape)
        # make sure the reconstructed weights differ from the zeroed-out ones
        self.assertFalse(
            np.allclose(to_numpy(before_reconstruction),
                        to_numpy(after_reconstruction)))
def bias_correction_empirical():
    """Run empirical-only bias correction on MobileNetV2 using a fake
    ImageNet-sized data loader."""
    dataset_size = 2000
    batch_size = 64

    data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                          batch_size=batch_size,
                                          image_size=(3, 224, 224))

    model = MobileNetV2()
    model.eval()

    params = QuantParams(weight_bw=4,
                         act_bw=4,
                         round_mode="nearest",
                         quant_scheme='tf_enhanced')

    # Perform Bias Correction.
    # BUG FIX: create_fake_data_loader returns the loader itself — it is
    # iterated directly everywhere else in this file and passed directly in
    # bias_correction_analytical_and_empirical(); it has no .train_loader
    # attribute, so the original `data_loader.train_loader` raised
    # AttributeError.
    bias_correction.correct_bias(model.to(device="cuda"),
                                 params,
                                 num_quant_samples=1000,
                                 data_loader=data_loader,
                                 num_bias_correct_samples=512)
# Example 15
    def test_prune_layer_with_seq(self):
        """ Test end to end prune layer with resnet18"""
        batch_size = 2
        # 1 batch x 2 images x 10 samples per image
        num_reconstruction_samples = 1 * batch_size * 10

        resnet18_model = models.resnet18(pretrained=True)
        # Layer database for the original model, plus a working copy to prune
        orig_layer_db = LayerDatabase(resnet18_model,
                                      input_shape=(1, 3, 224, 224))
        comp_layer_db = copy.deepcopy(orig_layer_db)

        data_loader = create_fake_data_loader(dataset_size=1000,
                                              batch_size=batch_size,
                                              image_size=(3, 224, 224))

        pruner = InputChannelPruner(
            data_loader=data_loader,
            input_shape=(1, 3, 224, 224),
            num_reconstruction_samples=num_reconstruction_samples,
            allow_custom_downsample_ops=True)

        conv_below_split = comp_layer_db.find_layer_by_name('layer1.1.conv1')
        pruner._prune_layer(orig_layer_db, comp_layer_db,
                            conv_below_split, 0.25, CostMetric.mac)

        pruned_conv = comp_layer_db.model.layer1[1].conv1[1]
        # 64 * 0.25 = 16 remaining input channels; outputs untouched
        self.assertEqual(pruned_conv.in_channels, 16)
        self.assertEqual(pruned_conv.out_channels, 64)
        self.assertEqual(list(pruned_conv.weight.shape), [64, 16, 3, 3])
# Example 16
    def test_prune_model_with_seq(self):
        """Test end to end prune model with resnet18.

        Prunes one conv per residual stage at comp ratio 0.5 and checks:
        1) a conv NOT below a split shrinks in place, which also shrinks
           the producing conv's output channels, and
        2) convs below a split end up indexed as conv1[1] — apparently
           wrapped together with a custom downsample op
           (allow_custom_downsample_ops=True) — with halved input channels.
        """
        batch_size = 2
        dataset_size = 1000
        number_of_batches = 1
        samples_per_image = 10
        num_reconstruction_samples = number_of_batches * batch_size * samples_per_image

        resnet18_model = models.resnet18(pretrained=True)
        resnet18_model.eval()

        # Create a layer database
        orig_layer_db = LayerDatabase(resnet18_model,
                                      input_shape=(1, 3, 224, 224))

        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size,
                                              image_size=(3, 224, 224))

        input_channel_pruner = InputChannelPruner(
            data_loader=data_loader,
            input_shape=(1, 3, 224, 224),
            num_reconstruction_samples=num_reconstruction_samples,
            allow_custom_downsample_ops=True)

        # keeping compression ratio = 0.5 for all layers
        layer_comp_ratio_list = [
            LayerCompRatioPair(
                Layer(resnet18_model.layer4[1].conv1, 'layer4.1.conv1', None),
                0.5),
            LayerCompRatioPair(
                Layer(resnet18_model.layer3[1].conv1, 'layer3.1.conv1', None),
                0.5),
            LayerCompRatioPair(
                Layer(resnet18_model.layer2[1].conv1, 'layer2.1.conv1', None),
                0.5),
            LayerCompRatioPair(
                Layer(resnet18_model.layer1[1].conv1, 'layer1.1.conv1', None),
                0.5),
            LayerCompRatioPair(
                Layer(resnet18_model.layer1[0].conv2, 'layer1.0.conv2', None),
                0.5)
        ]

        comp_layer_db = input_channel_pruner.prune_model(orig_layer_db,
                                                         layer_comp_ratio_list,
                                                         CostMetric.mac,
                                                         trainer=None)

        # 1) not below split: conv2 input channels 64 -> 32
        self.assertEqual(comp_layer_db.model.layer1[0].conv2.in_channels, 32)
        self.assertEqual(comp_layer_db.model.layer1[0].conv2.out_channels, 64)
        self.assertEqual(
            list(comp_layer_db.model.layer1[0].conv2.weight.shape),
            [64, 32, 3, 3])
        # impacted: the producing conv1's output channels shrink to match
        self.assertEqual(comp_layer_db.model.layer1[0].conv1.in_channels, 64)
        self.assertEqual(comp_layer_db.model.layer1[0].conv1.out_channels, 32)
        self.assertEqual(
            list(comp_layer_db.model.layer1[0].conv1.weight.shape),
            [32, 64, 3, 3])

        # 2) below split: each pruned conv is reachable at index [1]

        # 64 * .5
        self.assertEqual(comp_layer_db.model.layer1[1].conv1[1].in_channels,
                         32)
        self.assertEqual(comp_layer_db.model.layer1[1].conv1[1].out_channels,
                         64)
        self.assertEqual(
            list(comp_layer_db.model.layer1[1].conv1[1].weight.shape),
            [64, 32, 3, 3])

        # 128 * .5
        self.assertEqual(comp_layer_db.model.layer2[1].conv1[1].in_channels,
                         64)
        self.assertEqual(comp_layer_db.model.layer2[1].conv1[1].out_channels,
                         128)
        self.assertEqual(
            list(comp_layer_db.model.layer2[1].conv1[1].weight.shape),
            [128, 64, 3, 3])

        # 256 * .5
        self.assertEqual(comp_layer_db.model.layer3[1].conv1[1].in_channels,
                         128)
        self.assertEqual(comp_layer_db.model.layer3[1].conv1[1].out_channels,
                         256)
        self.assertEqual(
            list(comp_layer_db.model.layer3[1].conv1[1].weight.shape),
            [256, 128, 3, 3])

        # 512 * .5
        self.assertEqual(comp_layer_db.model.layer4[1].conv1[1].in_channels,
                         256)
        self.assertEqual(comp_layer_db.model.layer4[1].conv1[1].out_channels,
                         512)
        self.assertEqual(
            list(comp_layer_db.model.layer4[1].conv1[1].weight.shape),
            [512, 256, 3, 3])
# Example 17
    def test_prune_model(self):
        """Test end to end prune model with Mnist.

        Prunes conv2/conv3/conv4 inputs at comp ratio 0.5; pruning a conv's
        input channels also shrinks the preceding conv's output channels
        (the assertions below encode exactly that propagation).
        """
        class Net(nn.Module):
            """Small conv stack (1->10->20->30->40 channels) + two FC layers."""
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
                self.max_pool2d = nn.MaxPool2d(2)
                self.relu1 = nn.ReLU()
                self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
                self.relu2 = nn.ReLU()
                self.conv3 = nn.Conv2d(20, 30, kernel_size=3)
                self.relu3 = nn.ReLU()
                self.conv4 = nn.Conv2d(30, 40, kernel_size=3)
                self.relu4 = nn.ReLU()
                self.fc1 = nn.Linear(7 * 7 * 40, 300)
                self.relu5 = nn.ReLU()
                self.fc2 = nn.Linear(300, 10)
                self.log_softmax = nn.LogSoftmax(dim=1)

            def forward(self, x):
                x = self.relu1(self.max_pool2d(self.conv1(x)))
                x = self.relu2(self.conv2(x))
                x = self.relu3(self.conv3(x))
                x = self.relu4(self.conv4(x))
                x = x.view(x.size(0), -1)
                x = self.relu5(self.fc1(x))
                x = self.fc2(x)
                return self.log_softmax(x)

        orig_model = Net()
        orig_model.eval()
        # Create a layer database
        orig_layer_db = LayerDatabase(orig_model, input_shape=(1, 1, 28, 28))
        dataset_size = 1000
        batch_size = 10
        # max out number of batches
        number_of_batches = 100
        samples_per_image = 10

        # create fake data loader with image size (1, 28, 28)
        data_loader = create_fake_data_loader(dataset_size=dataset_size,
                                              batch_size=batch_size)

        input_channel_pruner = InputChannelPruner(
            data_loader=data_loader,
            input_shape=(1, 1, 28, 28),
            num_reconstruction_samples=number_of_batches,
            allow_custom_downsample_ops=True)

        # keeping compression ratio = 0.5 for all layers
        layer_comp_ratio_list = [
            LayerCompRatioPair(Layer(orig_model.conv4, 'conv4', None), 0.5),
            LayerCompRatioPair(Layer(orig_model.conv3, 'conv3', None), 0.5),
            LayerCompRatioPair(Layer(orig_model.conv2, 'conv2', None), 0.5)
        ]

        comp_layer_db = input_channel_pruner.prune_model(orig_layer_db,
                                                         layer_comp_ratio_list,
                                                         CostMetric.mac,
                                                         trainer=None)

        # conv2: own input pruning 10 -> 5; output 20 -> 10 via conv3's pruning
        self.assertEqual(comp_layer_db.model.conv2.in_channels, 5)
        self.assertEqual(comp_layer_db.model.conv2.out_channels, 10)

        # conv3: own input pruning 20 -> 10; output 30 -> 15 via conv4's pruning
        self.assertEqual(comp_layer_db.model.conv3.in_channels, 10)
        self.assertEqual(comp_layer_db.model.conv3.out_channels, 15)

        # conv4: own input pruning 30 -> 15; output untouched (nothing after it
        # was pruned)
        self.assertEqual(comp_layer_db.model.conv4.in_channels, 15)
        self.assertEqual(comp_layer_db.model.conv4.out_channels, 40)