def test_prune_model_2_layers(self):
        """ Pruning two layers with 0.5 comp-ratio in MNIST"""

        # create tf.compat.v1.Session and run the global variables initializer
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)

        with sess.graph.as_default():
            # by default, model will be constructed in default graph
            _ = mnist_tf_model.create_model(data_format='channels_last')
            sess.run(tf.compat.v1.global_variables_initializer())

        # Create a layer database
        orig_layer_db = LayerDatabase(model=sess,
                                      input_shape=(1, 28, 28, 1),
                                      working_dir=None)
        conv1 = orig_layer_db.find_layer_by_name('conv2d/Conv2D')
        conv2 = orig_layer_db.find_layer_by_name('conv2d_1/Conv2D')

        layer_comp_ratio_list = [
            LayerCompRatioPair(conv1, Decimal(0.5)),
            LayerCompRatioPair(conv2, Decimal(0.5))
        ]

        spatial_svd_pruner = SpatialSvdPruner()
        comp_layer_db = spatial_svd_pruner.prune_model(orig_layer_db,
                                                       layer_comp_ratio_list,
                                                       CostMetric.mac,
                                                       trainer=None)

        conv1_a = comp_layer_db.find_layer_by_name('conv2d_a/Conv2D')
        conv1_b = comp_layer_db.find_layer_by_name('conv2d_b/Conv2D')

        # Weights shape [kh, kw, Nic, Noc]
        self.assertEqual([5, 1, 1, 2],
                         conv1_a.module.inputs[1].get_shape().as_list())
        self.assertEqual([1, 5, 2, 32],
                         conv1_b.module.inputs[1].get_shape().as_list())

        conv2_a = comp_layer_db.find_layer_by_name('conv2d_1_a/Conv2D')
        conv2_b = comp_layer_db.find_layer_by_name('conv2d_1_b/Conv2D')

        self.assertEqual([5, 1, 32, 53],
                         conv2_a.module.inputs[1].get_shape().as_list())
        self.assertEqual([1, 5, 53, 64],
                         conv2_b.module.inputs[1].get_shape().as_list())

        for layer in comp_layer_db:
            print("Layer: " + layer.name)
            print("   Module: " + str(layer.module.name))

        tf.compat.v1.reset_default_graph()
        sess.close()
        # The compressed database exposes its own session (presumably created by the
        # pruner); close it as well so it does not leak into later tests.
        comp_layer_db.model.close()
        # delete temp directory
        shutil.rmtree('./temp_meta/')
Example #2
0
    def test_layer_database_with_dynamic_shape(self):
        """ test layer database creation with different input shapes

        One graph with a fully dynamic NHWC input placeholder is built once; a LayerDatabase is
        then created twice over it (fresh session each time) with different ``input_shape``
        values, and the reported output shapes of both convs must follow that shape.
        """
        session_config = tf.compat.v1.ConfigProto()
        session_config.gpu_options.allow_growth = True

        dynamic_graph = tf.Graph()
        with dynamic_graph.as_default():
            # batch, height and width are all left unspecified (None)
            inp = tf.compat.v1.placeholder(tf.float32, [None, None, None, 3], 'input')
            net = tf.keras.layers.Conv2D(8, (2, 2), padding='SAME')(inp)
            net = tf.keras.layers.BatchNormalization(momentum=.3, epsilon=.65)(net)
            net = tf.keras.layers.Conv2D(8, (1, 1), padding='SAME', activation=tf.nn.tanh)(net)
            net = tf.keras.layers.BatchNormalization(momentum=.4, epsilon=.25)(net)
            init_op = tf.compat.v1.global_variables_initializer()

        batch_size = 32
        # (input_shape handed to LayerDatabase, expected conv output_shape)
        scenarios = (
            ((1, 224, 224, 3), [1, 8, 224, 224]),
            ((batch_size, 28, 28, 3), [32, 8, 28, 28]),
        )

        for input_shape, expected_shape in scenarios:
            session = tf.compat.v1.Session(graph=dynamic_graph, config=session_config)
            session.run(init_op)

            db = LayerDatabase(model=session, input_shape=input_shape, working_dir=None,
                               starting_ops=['input'],
                               ending_ops=['batch_normalization_1/cond/Merge'])

            first_conv = db.find_layer_by_name('conv2d/Conv2D')
            second_conv = db.find_layer_by_name('conv2d_1/Conv2D')

            self.assertEqual(first_conv.output_shape, expected_shape)
            self.assertEqual(second_conv.output_shape, expected_shape)

            db.destroy()

    def test_prune_model_tf_slim(self):
        """ Pruning a model with tf slim api

        Spatial-SVD splits Conv_1 into Conv_1_a / Conv_1_b. The connected graph must grow by
        exactly one op, the BiasAdd must be grouped with Conv_1_b, and Conv_1_b must carry the
        original Conv_1 bias unchanged.
        """

        # create tf.compat.v1.Session and run the global variables initializer
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)

        with sess.graph.as_default():
            # by default, model will be constructed in default graph
            x = tf.compat.v1.placeholder(tf.float32, [1, 32, 32, 3])
            _ = tf_slim_basic_model(x)
            sess.run(tf.compat.v1.global_variables_initializer())

        conn_graph_orig = ConnectedGraph(sess.graph, ['Placeholder'],
                                         ['tf_slim_model/Softmax'])
        num_ops_orig = len(conn_graph_orig.get_all_ops())

        # Create a layer database
        orig_layer_db = LayerDatabase(model=sess,
                                      input_shape=(1, 32, 32, 3),
                                      working_dir=None)
        conv1 = orig_layer_db.find_layer_by_name('Conv_1/Conv2D')
        conv1_bias = BiasUtils.get_bias_as_numpy_data(orig_layer_db.model,
                                                      conv1.module)

        layer_comp_ratio_list = [LayerCompRatioPair(conv1, Decimal(0.5))]

        spatial_svd_pruner = SpatialSvdPruner()
        comp_layer_db = spatial_svd_pruner.prune_model(orig_layer_db,
                                                       layer_comp_ratio_list,
                                                       CostMetric.mac,
                                                       trainer=None)
        # Check that svd added these ops
        _ = comp_layer_db.model.graph.get_operation_by_name('Conv_1_a/Conv2D')
        _ = comp_layer_db.model.graph.get_operation_by_name('Conv_1_b/Conv2D')

        conn_graph_new = ConnectedGraph(comp_layer_db.model.graph,
                                        ['Placeholder'],
                                        ['tf_slim_model/Softmax'])
        num_ops_new = len(conn_graph_new.get_all_ops())
        self.assertEqual(num_ops_orig + 1, num_ops_new)
        bias_add_op = comp_layer_db.model.graph.get_operation_by_name(
            'Conv_1_b/BiasAdd')
        conv_1_b_op = comp_layer_db.model.graph.get_operation_by_name(
            'Conv_1_b/Conv2D')
        self.assertEqual(
            conn_graph_new._module_identifier.get_op_info(bias_add_op),
            conn_graph_new._module_identifier.get_op_info(conv_1_b_op))
        self.assertTrue(
            np.array_equal(
                conv1_bias,
                BiasUtils.get_bias_as_numpy_data(comp_layer_db.model,
                                                 conv_1_b_op)))

        # release both the original and the compressed model sessions
        comp_layer_db.model.close()
        sess.close()
    def test_prune_conv_no_bias(self):
        """ Test spatial svd on a conv layer with no bias

        After pruning, the Flatten reshape must be fed directly by the new conv2d_b conv.
        """
        # create tf.compat.v1.Session and run the global variables initializer
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)

        with sess.graph.as_default():
            # by default, model will be constructed in default graph
            inputs = tf.keras.Input(shape=(
                32,
                32,
                3,
            ))
            x = tf.keras.layers.Conv2D(32, (3, 3), use_bias=False)(inputs)
            _ = tf.keras.layers.Flatten()(x)
            sess.run(tf.compat.v1.global_variables_initializer())

        # Create a layer database
        orig_layer_db = LayerDatabase(model=sess,
                                      input_shape=(1, 32, 32, 3),
                                      working_dir=None)
        conv_op = orig_layer_db.find_layer_by_name('conv2d/Conv2D')

        layer_comp_ratio_list = [LayerCompRatioPair(conv_op, Decimal(0.5))]

        spatial_svd_pruner = SpatialSvdPruner()
        comp_layer_db = spatial_svd_pruner.prune_model(orig_layer_db,
                                                       layer_comp_ratio_list,
                                                       CostMetric.mac,
                                                       trainer=None)
        # Check that svd added these ops
        _ = comp_layer_db.model.graph.get_operation_by_name('conv2d_a/Conv2D')
        conv2d_b_op = comp_layer_db.model.graph.get_operation_by_name(
            'conv2d_b/Conv2D')
        reshape_op = comp_layer_db.model.graph.get_operation_by_name(
            'flatten/Reshape')
        self.assertEqual(conv2d_b_op, reshape_op.inputs[0].op)

        # release both the original and the compressed model sessions
        comp_layer_db.model.close()
        sess.close()
Example #5
0
    def _reconstruct_layers(self, layers_to_reconstruct: List[Layer],
                            orig_layer_name_to_pruned_name_and_mask_dict: Dict[
                                str, Tuple[str, List[int]]],
                            layer_db: LayerDatabase,
                            comp_layer_db: LayerDatabase):
        """
        Re-fit the weights and biases of every layer in ``layers_to_reconstruct``.

        Each original layer is looked up by name in the mapping to find the name of its latest
        pruned counterpart and its accumulated output mask. The pruned layer is resolved in the
        compressed database, and the pair is handed to
        ``self._data_subsample_and_reconstruction``.

        :param layers_to_reconstruct: Original-model layers whose weights and biases are re-fitted
        :param orig_layer_name_to_pruned_name_and_mask_dict: Original layer name ->
            (most recent pruned op name, most recent output mask)
        :param layer_db: Original layer database
        :param comp_layer_db: Compressed layer database
        """
        name_to_pruned_info = orig_layer_name_to_pruned_name_and_mask_dict
        for orig_layer in layers_to_reconstruct:
            # the mask accumulates every channel winnowed since the first pruning round
            pruned_name, accumulated_mask = name_to_pruned_info.get(orig_layer.name, (None, None))
            assert pruned_name is not None

            self._data_subsample_and_reconstruction(
                orig_layer,
                comp_layer_db.find_layer_by_name(pruned_name),
                accumulated_mask,
                layer_db,
                comp_layer_db)
    def test_per_layer_eval_scores(self):
        """Check the per-layer eval scores GreedyCompRatioSelectAlgo computes for one picked layer.

        The pruner is mocked to return the original layer database, and the eval function is
        mocked to return 90, 80, ..., 10 on successive calls, so the 0.1 candidate must score 90.
        Cleanup (TF session, bokeh session, bokeh server process on port 8006) runs even when an
        assertion fails, so an orphaned server cannot block the port for later tests.
        """

        pruner = unittest.mock.MagicMock()
        eval_func = unittest.mock.MagicMock()

        # create tf.compat.v1.Session and run the global variables initializer
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)
        bokeh_session = None
        process = None

        try:
            with sess.graph.as_default():
                # by default, model will be constructed in default graph
                _ = mnist_tf_model.create_model(data_format='channels_last')
                sess.run(tf.compat.v1.global_variables_initializer())

            # Create a layer database
            layer_db = LayerDatabase(model=sess,
                                     input_shape=(1, 28, 28, 1),
                                     working_dir=None)
            layer1 = layer_db.find_layer_by_name('conv2d/Conv2D')

            layer_db.mark_picked_layers([layer1])
            eval_func.side_effect = [90, 80, 70, 60, 50, 40, 30, 20, 10]

            url, process = start_bokeh_server_session(8006)
            bokeh_session = BokehServerSession(url=url, session_id="compression")

            # Instantiate child
            greedy_algo = comp_ratio_select.GreedyCompRatioSelectAlgo(
                layer_db=layer_db,
                pruner=pruner,
                cost_calculator=SpatialSvdCostCalculator(),
                eval_func=eval_func,
                eval_iterations=20,
                cost_metric=CostMetric.mac,
                target_comp_ratio=0.5,
                num_candidates=10,
                use_monotonic_fit=True,
                saved_eval_scores_dict=None,
                comp_ratio_rounding_algo=None,
                use_cuda=False,
                bokeh_session=bokeh_session)
            progress_bar = ProgressBar(1,
                                       "eval scores",
                                       "green",
                                       bokeh_session=bokeh_session)
            # NOTE(review): num_columns=3 while nine column_names are passed; confirm which one
            # DataTable actually relies on.
            data_table = DataTable(num_columns=3,
                                   num_rows=1,
                                   column_names=[
                                       '0.1', '0.2', '0.3', '0.4', '0.5', '0.6',
                                       '0.7', '0.8', '0.9'
                                   ],
                                   row_index_names=[layer1.name],
                                   bokeh_session=bokeh_session)

            pruner.prune_model.return_value = layer_db
            eval_dict = greedy_algo._compute_layerwise_eval_score_per_comp_ratio_candidate(
                data_table, progress_bar, layer1)

            self.assertEqual(90, eval_dict[Decimal('0.1')])
        finally:
            tf.compat.v1.reset_default_graph()
            sess.close()
            try:
                if bokeh_session is not None:
                    bokeh_session.server_session.close("test complete")
            finally:
                # always terminate the server's process group, even if closing the session failed
                if process is not None:
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    def test_select_per_layer_comp_ratios(self):
        """Run greedy per-layer comp-ratio selection twice and check the achieved MAC ratio.

        The first run starts without a saved eval-score pickle and must create
        './data/greedy_selection_eval_scores_dict.pkl'. The second run loads that pickle. Both runs
        must land within 0.05 of the 0.6 target. Cleanup (TF session, bokeh session, bokeh server
        process on port 8006) runs even when an assertion fails.
        """

        pruner = unittest.mock.MagicMock()
        eval_func = unittest.mock.MagicMock()
        rounding_algo = unittest.mock.MagicMock()
        rounding_algo.round.side_effect = [
            0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3, 0.4,
            0.5, 0.6, 0.7, 0.8, 0.9
        ]

        eval_func.side_effect = [
            10, 20, 30, 40, 50, 60, 70, 80, 90, 11, 21, 31, 35, 40, 45, 50, 55,
            60
        ]

        # create tf.compat.v1.Session and run the global variables initializer
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)
        bokeh_session = None
        process = None

        try:
            with sess.graph.as_default():
                # by default, model will be constructed in default graph
                _ = mnist_tf_model.create_model(data_format='channels_last')
                sess.run(tf.compat.v1.global_variables_initializer())

            # Create a layer database
            layer_db = LayerDatabase(model=sess,
                                     input_shape=(1, 28, 28, 1),
                                     working_dir=None)
            layer1 = layer_db.find_layer_by_name('conv2d/Conv2D')
            layer2 = layer_db.find_layer_by_name('conv2d_1/Conv2D')

            selected_layers = [layer1, layer2]
            layer_db.mark_picked_layers([layer1, layer2])

            # start without a saved eval-score pickle so the first run has to produce one
            try:
                os.remove('./data/greedy_selection_eval_scores_dict.pkl')
            except OSError:
                pass

            url, process = start_bokeh_server_session(8006)
            bokeh_session = BokehServerSession(url=url, session_id="compression")

            # Instantiate child
            greedy_algo = comp_ratio_select.GreedyCompRatioSelectAlgo(
                layer_db=layer_db,
                pruner=pruner,
                cost_calculator=SpatialSvdCostCalculator(),
                eval_func=eval_func,
                eval_iterations=20,
                cost_metric=CostMetric.mac,
                target_comp_ratio=Decimal(0.6),
                num_candidates=10,
                use_monotonic_fit=True,
                saved_eval_scores_dict=None,
                comp_ratio_rounding_algo=rounding_algo,
                use_cuda=False,
                bokeh_session=bokeh_session)

            layer_comp_ratio_list, stats = greedy_algo.select_per_layer_comp_ratios(
            )

            original_cost = SpatialSvdCostCalculator.compute_model_cost(layer_db)

            for layer in layer_db:
                if layer not in selected_layers:
                    layer_comp_ratio_list.append(LayerCompRatioPair(layer, None))
            compressed_cost = SpatialSvdCostCalculator.calculate_compressed_cost(
                layer_db, layer_comp_ratio_list, CostMetric.mac)
            # re-arm the rounding mock; presumably consumed again by the second run below
            rounding_algo.round.side_effect = [
                0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.2, 0.3, 0.4,
                0.5, 0.6, 0.7, 0.8, 0.9
            ]
            actual_compression_ratio = compressed_cost.mac / original_cost.mac
            self.assertTrue(
                math.isclose(Decimal(0.6), actual_compression_ratio, abs_tol=0.05))
            self.assertTrue(
                os.path.isfile('./data/greedy_selection_eval_scores_dict.pkl'))

            print('\n')
            for pair in layer_comp_ratio_list:
                print(pair)

            # let's repeat with a saved eval_dict
            greedy_algo = comp_ratio_select.GreedyCompRatioSelectAlgo(
                layer_db=layer_db,
                pruner=pruner,
                cost_calculator=SpatialSvdCostCalculator(),
                eval_func=eval_func,
                eval_iterations=20,
                cost_metric=CostMetric.mac,
                target_comp_ratio=Decimal(0.6),
                num_candidates=10,
                use_monotonic_fit=True,
                saved_eval_scores_dict=
                './data/greedy_selection_eval_scores_dict.pkl',
                comp_ratio_rounding_algo=rounding_algo,
                use_cuda=False,
                bokeh_session=bokeh_session)

            layer_comp_ratio_list, stats = greedy_algo.select_per_layer_comp_ratios(
            )

            original_cost = SpatialSvdCostCalculator.compute_model_cost(layer_db)

            for layer in layer_db:
                if layer not in selected_layers:
                    layer_comp_ratio_list.append(LayerCompRatioPair(layer, None))
            compressed_cost = SpatialSvdCostCalculator.calculate_compressed_cost(
                layer_db, layer_comp_ratio_list, CostMetric.mac)

            actual_compression_ratio = compressed_cost.mac / original_cost.mac
            self.assertTrue(
                math.isclose(Decimal(0.6), actual_compression_ratio, abs_tol=0.05))

            print('\n')
            for pair in layer_comp_ratio_list:
                print(pair)
        finally:
            tf.compat.v1.reset_default_graph()
            sess.close()
            try:
                if bokeh_session is not None:
                    bokeh_session.server_session.close("test complete")
            finally:
                # always terminate the server's process group, even if closing the session failed
                if process is not None:
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    def test_eval_scores_with_spatial_svd_pruner(self):
        """Compute eval scores for all comp-ratio candidates of two MNIST convs with a real pruner.

        Uses a real SpatialSvdPruner and a mocked eval function, then spot-checks entries of the
        returned per-layer score dictionaries. Cleanup (TF session, bokeh session, bokeh server
        process on port 8006) runs even when an assertion fails.
        """

        pruner = SpatialSvdPruner()
        eval_func = unittest.mock.MagicMock()
        eval_func.side_effect = [
            90, 80, 70, 60, 50, 40, 30, 20, 10, 91, 81, 71, 61, 51, 41, 31, 21,
            11
        ]

        # create tf.compat.v1.Session and run the global variables initializer
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)
        bokeh_session = None
        process = None

        try:
            with sess.graph.as_default():
                # by default, model will be constructed in default graph
                _ = mnist_tf_model.create_model(data_format='channels_last')
                sess.run(tf.compat.v1.global_variables_initializer())

            # Create a layer database
            layer_db = LayerDatabase(model=sess,
                                     input_shape=(1, 28, 28, 1),
                                     working_dir=None)
            layer1 = layer_db.find_layer_by_name('conv2d/Conv2D')
            layer2 = layer_db.find_layer_by_name('conv2d_1/Conv2D')

            layer_db.mark_picked_layers([layer1, layer2])

            url, process = start_bokeh_server_session(8006)
            bokeh_session = BokehServerSession(url=url, session_id="compression")

            # Instantiate child
            greedy_algo = comp_ratio_select.GreedyCompRatioSelectAlgo(
                layer_db=layer_db,
                pruner=pruner,
                cost_calculator=SpatialSvdCostCalculator(),
                eval_func=eval_func,
                eval_iterations=20,
                cost_metric=CostMetric.mac,
                target_comp_ratio=0.5,
                num_candidates=10,
                use_monotonic_fit=True,
                saved_eval_scores_dict=None,
                comp_ratio_rounding_algo=None,
                use_cuda=False,
                bokeh_session=bokeh_session)

            # layer name -> {comp-ratio candidate: eval score}; named to avoid shadowing `dict`
            eval_scores = greedy_algo._compute_eval_scores_for_all_comp_ratio_candidates()

            print()
            print(eval_scores)
            self.assertEqual(90, eval_scores['conv2d/Conv2D'][Decimal('0.1')])

            self.assertEqual(51, eval_scores['conv2d_1/Conv2D'][Decimal('0.5')])
            self.assertEqual(21, eval_scores['conv2d_1/Conv2D'][Decimal('0.8')])
        finally:
            tf.compat.v1.reset_default_graph()
            sess.close()
            try:
                if bokeh_session is not None:
                    bokeh_session.server_session.close("test complete")
            finally:
                # always terminate the server's process group, even if closing the session failed
                if process is not None:
                    os.killpg(os.getpgid(process.pid), signal.SIGTERM)
Example #9
0
    def prune_model(self, layer_db: LayerDatabase,
                    layer_comp_ratio_list: List[LayerCompRatioPair],
                    cost_metric: CostMetric, trainer):
        """
        Prune input channels of the requested layers and return the compressed layer database.

        Steps:
        1. Layers are handled in order of occurrence in the model (``self._sort_on_occurrence``).
        2. For every pair whose comp ratio is set and below 1.0:
           a. input channels to drop are chosen (``self._select_inp_channels``);
           b. a deep copy of the model is winnowed in place;
           c. layers that winnowing actually modified are queued for reconstruction.
        3. Once all layers are winnowed, uninitialized variables are initialized.
        4. The graph is saved and reloaded, and the compressed database is updated with the
           detached ops excluded.
        5. The queued layers' weights are reconstructed.

        :param layer_db: Layer database of the original model; winnowing is applied to a deep copy
        :param layer_comp_ratio_list: (layer, comp ratio) pairs; pairs whose comp ratio is None or
            not below 1.0 are skipped
        :param cost_metric: Not used in this method body
        :param trainer: Not used in this method body
        :return: Compressed layer database
        """

        # sort all the layers in layer_comp_ratio_list based on occurrence
        layer_comp_ratio_list = self._sort_on_occurrence(
            layer_db.model, layer_comp_ratio_list)

        # Copy the db
        comp_layer_db = copy.deepcopy(layer_db)
        current_sess = comp_layer_db.model

        # Dictionary mapping original layer name to a (most recent pruned layer name, output mask) tuple.
        # Masks remain at the original length and specify channels winnowed after each round of winnower.
        orig_layer_name_to_pruned_name_and_mask_dict = {}
        # Dictionary to map most recent pruned layer name to the original layer name
        pruned_name_to_orig_name_dict = {}
        # List to hold original layers to reconstruct
        layers_to_reconstruct = []
        # Names of ops detached by winnowing; excluded when the compressed database is updated
        detached_op_names = set()

        # Prune layers which have comp ratios less than 1
        for layer_comp_ratio in layer_comp_ratio_list:
            # look the layer up in the ORIGINAL database: selection is based on original weights
            orig_layer = layer_db.find_layer_by_name(
                layer_comp_ratio.layer.name)
            if layer_comp_ratio.comp_ratio is not None and layer_comp_ratio.comp_ratio < 1.0:
                # 1) channel selection
                prune_indices = self._select_inp_channels(
                    orig_layer, layer_comp_ratio.comp_ratio)
                if not prune_indices:
                    continue

                # 2) Winnowing the model
                current_sess, ordered_modules_list = winnow.winnow_tf_model(
                    current_sess,
                    self._input_op_names,
                    self._output_op_names,
                    [(orig_layer.module, prune_indices)],
                    reshape=self._allow_custom_downsample_ops,
                    in_place=True,
                    verbose=False)
                # nothing was winnowed for this layer, so there is nothing to reconstruct
                if not ordered_modules_list:
                    continue

                layers_to_reconstruct.append(orig_layer)
                # Update dictionaries with new info about pruned ops and new masks
                self._update_pruned_ops_and_masks_info(
                    ordered_modules_list,
                    orig_layer_name_to_pruned_name_and_mask_dict,
                    pruned_name_to_orig_name_dict, detached_op_names)

        # Save and reload modified graph to allow changes to take effect
        # Need to initialize uninitialized variables first since only newly winnowed conv ops are initialized during
        # winnow_tf_model, and all other newly winnowed ops are not.
        with current_sess.graph.as_default():
            initialize_uninitialized_vars(current_sess)
        # NOTE(review): save_and_load_graph is given the relative path './saver'; presumably it writes
        # files under the current working directory — confirm.
        current_sess = save_and_load_graph('./saver', current_sess)
        comp_layer_db.update_database(current_sess,
                                      detached_op_names,
                                      update_model=True)

        # Perform reconstruction
        self._reconstruct_layers(layers_to_reconstruct,
                                 orig_layer_name_to_pruned_name_and_mask_dict,
                                 layer_db, comp_layer_db)

        return comp_layer_db
Example #10
0
    def calculate_compressed_cost(
            self, layer_db: LayerDatabase,
            layer_comp_ratio_list: List[LayerCompRatioPair]) -> Cost:
        """
        Calculate cost of a compressed model given a set of layers and corresponding comp-ratios.

        The winnowing is applied to a deep copy of ``layer_db``. Every layer whose comp-ratio is
        set and below 1.0 gets input channels selected and winnowed in place in the copy. The
        copy's database is then updated with the detached ops excluded, without replacing its
        model, and its cost is computed. The copy's session is closed before returning, including
        when an exception is raised part-way through.

        :param layer_db: Layer database for original model
        :param layer_comp_ratio_list: List of (layer + comp-ratio) pairs; pairs whose comp-ratio is
            None or not below 1.0 are skipped
        :return: Estimated cost of the compressed model
        """

        # sort all the layers in layer_comp_ratio_list based on occurrence
        layer_comp_ratio_list = self._sort_on_occurrence(
            layer_db.model, layer_comp_ratio_list)

        detached_op_names = set()

        # Copy the db
        comp_layer_db = copy.deepcopy(layer_db)
        try:
            current_sess = comp_layer_db.model

            for layer_comp_ratio in layer_comp_ratio_list:

                orig_layer = layer_db.find_layer_by_name(
                    layer_comp_ratio.layer.name)
                comp_ratio = layer_comp_ratio.comp_ratio

                if comp_ratio is not None and comp_ratio < 1.0:

                    # select input channels of conv2d op to winnow
                    prune_indices = self._select_inp_channels(
                        orig_layer, comp_ratio)
                    if not prune_indices:
                        continue

                    # Winnow the selected op and modify its upstream affected ops
                    current_sess, ordered_modules_list = winnow.winnow_tf_model(
                        current_sess,
                        self._input_op_names,
                        self._output_op_names,
                        [(orig_layer.module, prune_indices)],
                        reshape=self._allow_custom_downsample_ops,
                        in_place=True,
                        verbose=False)
                    if not ordered_modules_list:
                        continue

                    # Get all the detached op names from updated session graph
                    for orig_op_name, _, _, _ in ordered_modules_list:
                        detached_op_names.add(orig_op_name)

            # update layer database by excluding the detached ops
            comp_layer_db.update_database(current_sess,
                                          detached_op_names,
                                          update_model=False)

            # calculate the cost of this model
            compressed_model_cost = CostCalculator.compute_model_cost(
                comp_layer_db)
        finally:
            # close the session associated with the compressed layer database, even on error,
            # so a failed cost estimate does not leak a session
            comp_layer_db.model.close()

        return compressed_model_cost
    def test_datasampling_and_reconstruction(self):
        """
        Check that data sub-sampling plus reconstruction rewrites a conv layer's weights.

        ``block1_conv1`` of a freshly initialized VGG16 has identical weights in the original
        and the deep-copied database. After
        ``InputChannelPruner._data_subsample_and_reconstruction`` runs on the copy with an
        all-ones output mask, the copy's weights must no longer match the original ones.
        """
        tf.compat.v1.reset_default_graph()
        batch_size = 1
        random_images = np.random.rand(100, 224, 224, 3)
        dataset = tf.data.Dataset.from_tensor_slices(random_images).batch(batch_size=batch_size)

        vgg_graph = tf.Graph()
        with vgg_graph.as_default():
            _ = VGG16(weights=None,
                      input_shape=(224, 224, 3),
                      include_top=False)
            init_op = tf.compat.v1.global_variables_initializer()

        # session over the VGG16 graph with all variables initialized
        vgg_sess = tf.compat.v1.Session(graph=vgg_graph)
        vgg_sess.run(init_op)

        target_conv = 'block1_conv1/Conv2D'
        layer_db = LayerDatabase(model=vgg_sess,
                                 input_shape=(1, 224, 224, 3),
                                 working_dir=None)
        orig_conv = layer_db.find_layer_by_name(target_conv)

        comp_layer_db = copy.deepcopy(layer_db)
        copied_conv = comp_layer_db.find_layer_by_name(target_conv)

        def fetch_weights(db, layer):
            # input 1 of the Conv2D op is its weight tensor
            return db.model.run(layer.module.inputs[1])

        orig_weights = fetch_weights(layer_db, orig_conv)
        copied_weights_before = fetch_weights(comp_layer_db, copied_conv)

        # before reconstruction both databases must hold exactly the same weights
        self.assertTrue(np.array_equal(orig_weights, copied_weights_before))

        channel_pruner = InputChannelPruner(input_op_names=['input_1'],
                                            output_op_names=['block5_pool/MaxPool'],
                                            data_set=dataset,
                                            batch_size=batch_size,
                                            num_reconstruction_samples=50,
                                            allow_custom_downsample_ops=True)

        # NOTE(review): mask length is taken from weight_shape[0]; confirm which channel axis
        # Layer.weight_shape puts first.
        all_kept_mask = [1] * copied_conv.weight_shape[0]
        channel_pruner._data_subsample_and_reconstruction(orig_layer=orig_conv,
                                                          pruned_layer=copied_conv,
                                                          output_mask=all_kept_mask,
                                                          orig_layer_db=layer_db,
                                                          comp_layer_db=comp_layer_db)

        copied_weights_after = fetch_weights(comp_layer_db, copied_conv)

        # reconstruction must have changed the copied layer's weights
        self.assertFalse(np.array_equal(orig_weights, copied_weights_after))

        for db in (layer_db, comp_layer_db):
            db.model.close()
        # delete temp directory
        shutil.rmtree('./temp_meta/')