Ejemplo n.º 1
0
    def _get_layer_pairs(layer_db: LayerDatabase, module_comp_ratio_pairs: List[ModuleCompRatioPair]):
        """
        Translate (module, comp-ratio) pairs into (layer, comp-ratio) pairs.

        :param layer_db: Layer database queried via ``find_layer_by_module`` for each module
        :param module_comp_ratio_pairs: Pairs exposing ``module`` and ``comp_ratio`` attributes
        :return: List of LayerCompRatioPair objects, one per input pair, in input order
        """
        return [LayerCompRatioPair(layer_db.find_layer_by_module(entry.module), entry.comp_ratio)
                for entry in module_comp_ratio_pairs]
Ejemplo n.º 2
0
    def test_split_conv_layer_with_mo(self):
        """
        Split the MNIST model's conv2 layer with the deprecated SVD pruner at rank 28 and
        check the weight/bias shapes, strides, paddings and kernel sizes of the two
        convolutions returned by the pruner.
        """
        logger.debug(self.id())

        input_shape = (1, 1, 28, 28)
        rank = 28

        model = mnist_model.Net().to("cpu")
        layer_database = LayerDatabase(model=model, input_shape=input_shape)

        # LayerSelectorDeprecated is swapped for a mock while SvdImpl is being constructed
        with unittest.mock.patch('aimet_torch.svd.layer_selector_deprecated.LayerSelectorDeprecated'):
            intf_defs = aimet_torch.svd.svd_intf_defs_deprecated
            svd = s.SvdImpl(
                model=model,
                run_model=mnist_model.evaluate,
                run_model_iterations=1,
                input_shape=input_shape,
                compression_type=intf_defs.CompressionTechnique.svd,
                cost_metric=intf_defs.CostMetric.memory,
                layer_selection_scheme=intf_defs.LayerSelectionScheme.top_n_layers,
                num_layers=2,
            )

        # Register conv2 with the SVD library reference before asking the pruner to split it
        conv2_layer = layer_database.find_layer_by_module(model.conv2)
        pymo_utils.PymoSvdUtils.configure_layers_in_pymo_svd(
            [conv2_layer], aimet_common.defs.CostMetric.mac, svd._svd_lib_ref)

        pruner = svd_pruner_deprecated.DeprecatedSvdPruner
        _seq, first_conv, second_conv = pruner.prune_layer(conv2_layer, rank, svd._svd_lib_ref)

        # Dump the first ten weights of the first conv for manual inspection
        module_a = first_conv.module
        print('\n')
        print(module_a.weight.detach().numpy().flatten()[:10])

        module_b = second_conv.module
        original = model.conv2

        # Weight and bias shapes: first conv is 1x1 down to `rank` channels,
        # second conv is 5x5 from `rank` channels to the original output channels
        self.assertEqual((rank, original.in_channels, 1, 1), module_a.weight.shape)
        self.assertEqual([rank], list(module_a.bias.shape))
        self.assertEqual((original.out_channels, rank, 5, 5), module_b.weight.shape)
        self.assertEqual([original.out_channels], list(module_b.bias.shape))

        # Both convs are expected to carry the original stride
        self.assertEqual(original.stride, module_a.stride)
        self.assertEqual(original.stride, module_b.stride)

        # First conv has no padding; second conv keeps the original padding
        self.assertEqual((0, 0), module_a.padding)
        self.assertEqual(original.padding, module_b.padding)

        # First conv is 1x1; second conv keeps the original kernel size
        self.assertEqual((1, 1), module_a.kernel_size)
        self.assertEqual(original.kernel_size, module_b.kernel_size)