def test_concat_get_op_product_graph(self):
        """ Test connected graph construction on a graph with concat op """

        tf.compat.v1.reset_default_graph()

        _ = concat_model()
        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), ['input_1'], ['concat_model/Softmax'])
        self.assertTrue(validate_branch_ops(conn_graph))
        self.assertTrue(validate_product_tensor_lists(conn_graph))
        self.assertEqual(2, conn_graph.branch_count)
        self.assertEqual(13, len(conn_graph.get_all_ops()))
        self.assertEqual(12 + len(tf.compat.v1.get_default_graph().get_collection('variables')),
                         len(conn_graph.get_all_products()))

        # Check that the order of input products to the concat op matches the order of input tensors in the tf graph
        concat_tf_op = tf.compat.v1.get_default_graph().get_operation_by_name("concatenate/concat")
        concat_op = conn_graph.get_all_ops()['concatenate/concat']
        for index, product in enumerate(concat_op.get_input_products()):
            self.assertTrue(len(product.consumers) == 1)
            self.assertEqual(product.tensor_dict[product.consumers[0]], concat_tf_op.inputs[index])

    def test_model_with_PReLU(self):
        """ Test connected graph construction on a model with a PReLU op """

        tf.compat.v1.reset_default_graph()
        inputs = tf.keras.Input(shape=(1, 10, 10), name="inputs")

        x = tf.keras.layers.PReLU()(inputs)
        x = tf.keras.layers.ReLU()(x)

        init = tf.compat.v1.global_variables_initializer()
        sess = tf.compat.v1.Session()
        sess.run(init)

        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['inputs'],
                                    output_op_names=['re_lu/Relu'])

        self.assertEqual(3, len(conn_graph.get_all_ops()))
        self.assertEqual(5, len(conn_graph.get_all_ops()['p_re_lu/Relu'].internal_ops))
    def test_model_with_global_max_pool2d(self):
        """ Test connected graph construction on model with leaky relu op """
        tf.compat.v1.reset_default_graph()
        _ = model_with_global_max_pool2d()
        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['input_1'],
                                    output_op_names=['model_with_global_max_pool2d/Softmax'])
        self.assertTrue(validate_branch_ops(conn_graph))
        self.assertTrue(validate_product_tensor_lists(conn_graph))
        self.assertEqual(0, conn_graph.branch_count)
        self.assertEqual(6, len(conn_graph.get_all_ops()))

        # 5 products from inter module connections
        # 4 products from parameters
        self.assertEqual(9, len(conn_graph.get_all_products()))
        found_global_max_pool2d = False
        for op in conn_graph.get_all_ops().values():
            if op.type == 'GlobalMaxPool2D':
                found_global_max_pool2d = True
        self.assertTrue(found_global_max_pool2d)

        tf.compat.v1.reset_default_graph()

    def test_model_with_leaky_relu(self):
        """ Test connected graph construction on a model with a leaky relu op """
        tf.compat.v1.reset_default_graph()
        _ = model_with_leaky_relu()
        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['input_1'],
                                    output_op_names=['model_with_leaky_relu/Softmax'])
        self.assertTrue(validate_branch_ops(conn_graph))
        self.assertTrue(validate_product_tensor_lists(conn_graph))
        self.assertEqual(0, conn_graph.branch_count)
        self.assertEqual(7, len(conn_graph.get_all_ops()))

        # 6 products from inter module connections
        # 6 products from parameters
        self.assertEqual(12, len(conn_graph.get_all_products()))
        found_leaky_relu = False
        for op in conn_graph.get_all_ops().values():
            if op.type == 'LeakyRelu':
                found_leaky_relu = True
        self.assertTrue(found_leaky_relu)

        tf.compat.v1.reset_default_graph()

    def test_model_with_lstm_layer_deepspeech_time_major_true_sigmoid(self):
        """ Test connected graph construction on a model with stacked LSTM ops as in the DeepSpeech model """

        tf.compat.v1.reset_default_graph()
        sess = tf.compat.v1.Session()
        with sess.graph.as_default():
            inputs = tf.keras.Input(shape=(3, 100))

            # Defaults
            # return_state=False, unit_forget_bias=True
            # return_sequences=False, time_major=False
            x = tf.keras.layers.LSTM(12,
                                     recurrent_activation='sigmoid',
                                     unit_forget_bias=False,
                                     time_major=True,
                                     return_sequences=True,
                                     name='lstm_stacked')(inputs)

            x2 = tf.keras.layers.LSTM(12, name='last_lstm')(x)

            _ = tf.keras.layers.Dense(12, activation=tf.nn.softmax,
                                      name="matmul0")(x2)

            init = tf.compat.v1.global_variables_initializer()
            sess.run(init)
            # _ = tf.compat.v1.summary.FileWriter('./lstm_deepspeech', sess.graph)

        # construct a connected graph
        conn_graph = ConnectedGraph(sess.graph, ['input_1'], ['matmul0/Softmax'])

        self.assertEqual(5, len(conn_graph.get_all_ops()))
        lstm_detected = False
        for op in conn_graph.get_all_ops().values():
            if op.type == 'LSTM' and op.name == 'lstm_stacked':
                self.assertEqual(op.pattern_type, 'LSTM_Stacked_TimeMajor_True_Sigmoid')
                lstm_detected = True
                inner_list = op.internal_ops
                self.assertEqual(75, len(inner_list))
                self.assertEqual(op.get_module(), sess.graph.get_operation_by_name('lstm_stacked/while/MatMul'))
        self.assertTrue(lstm_detected)
    def test_prune_model_tf_slim(self):
        """ Punning a model with tf slim api """

        # create tf.compat.v1.Session and initialize the weights and biases with zeros
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        # create session with graph
        sess = tf.compat.v1.Session(graph=tf.Graph(), config=config)

        with sess.graph.as_default():
            # by default, model will be constructed in default graph
            x = tf.compat.v1.placeholder(tf.float32, [1, 32, 32, 3])
            _ = tf_slim_basic_model(x)
            sess.run(tf.compat.v1.global_variables_initializer())

        conn_graph_orig = ConnectedGraph(sess.graph, ['Placeholder'],
                                         ['tf_slim_model/Softmax'])
        num_ops_orig = len(conn_graph_orig.get_all_ops())

        # Create a layer database
        orig_layer_db = LayerDatabase(model=sess,
                                      input_shape=(1, 32, 32, 3),
                                      working_dir=None)
        conv1 = orig_layer_db.find_layer_by_name('Conv_1/Conv2D')
        conv1_bias = BiasUtils.get_bias_as_numpy_data(orig_layer_db.model,
                                                      conv1.module)

        layer_comp_ratio_list = [LayerCompRatioPair(conv1, Decimal(0.5))]

        spatial_svd_pruner = SpatialSvdPruner()
        comp_layer_db = spatial_svd_pruner.prune_model(orig_layer_db,
                                                       layer_comp_ratio_list,
                                                       CostMetric.mac,
                                                       trainer=None)
        # Check that svd added these ops
        _ = comp_layer_db.model.graph.get_operation_by_name('Conv_1_a/Conv2D')
        _ = comp_layer_db.model.graph.get_operation_by_name('Conv_1_b/Conv2D')

        conn_graph_new = ConnectedGraph(comp_layer_db.model.graph,
                                        ['Placeholder'],
                                        ['tf_slim_model/Softmax'])
        num_ops_new = len(conn_graph_new.get_all_ops())
        self.assertEqual(num_ops_orig + 1, num_ops_new)
        bias_add_op = comp_layer_db.model.graph.get_operation_by_name(
            'Conv_1_b/BiasAdd')
        conv_1_b_op = comp_layer_db.model.graph.get_operation_by_name(
            'Conv_1_b/Conv2D')
        self.assertEqual(
            conn_graph_new._module_identifier.get_op_info(bias_add_op),
            conn_graph_new._module_identifier.get_op_info(conv_1_b_op))
        self.assertTrue(
            np.array_equal(
                conv1_bias,
                BiasUtils.get_bias_as_numpy_data(comp_layer_db.model,
                                                 conv_1_b_op)))

    def test_model_zoo_videnn_pose_estimation_model_with_input_split(self):
        """
        Create a smaller network with connections as in the pose estimation and ViDeNN models.
        Test when the input is split and fed into two different ops.
        """

        tf.compat.v1.reset_default_graph()
        inputs = tf.keras.Input(shape=(None, None, 2), name="inputs")

        x = tf.keras.layers.Conv2D(2, kernel_size=3, padding='same')(inputs)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.nn.relu(x)
        x = tf.keras.layers.Conv2D(2, kernel_size=3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        z = tf.keras.layers.Add()([inputs, x])
        x = tf.nn.relu(z)

        init = tf.compat.v1.global_variables_initializer()
        sess = tf.compat.v1.Session()
        sess.run(init)

        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['inputs'],
                                    output_op_names=['Relu_1'])

        # Find the input node(s) in the graph
        input_nodes = []
        for op in conn_graph.get_all_ops().values():
            if op.inputs and op.inputs[0].is_model_input:
                input_nodes.append(op)

        # there should be two ops marked as inputs in this model
        self.assertEqual(len(input_nodes), 2)
        self.assertEqual(input_nodes[0].name, 'add/add')
        self.assertEqual(input_nodes[1].name, "conv2d/Conv2D")
    def test_bn_fold_with_no_bias(self):
        tf.compat.v1.reset_default_graph()
        inputs = tf.keras.Input(shape=(
            32,
            32,
            3,
        ))
        conv_op = tf.keras.layers.Conv2D(32, (3, 3), use_bias=False)(inputs)
        bn_op = tf.keras.layers.BatchNormalization(fused=True)(conv_op,
                                                               training=False)
        _ = tf.nn.relu(bn_op)

        init = tf.compat.v1.global_variables_initializer()
        sess = tf.compat.v1.Session()
        sess.run(init)

        conv_op = sess.graph.get_operation_by_name('conv2d/Conv2D')
        np.random.seed(0)
        w_shape = conv_op.inputs[0].shape
        numpy_data = np.random.rand(1, w_shape[1], w_shape[2], w_shape[3])

        relu_op = sess.graph.get_operation_by_name('Relu')
        baseline_output = sess.run(relu_op.outputs[0],
                                   feed_dict={conv_op.inputs[0]: numpy_data})
        old_conn_graph = ConnectedGraph(sess.graph,
                                        starting_op_names=['input_1'],
                                        output_op_names=['Relu'])

        new_sess, pairs = fold_all_batch_norms(sess, "input_1", 'Relu')

        new_conv_op = new_sess.graph.get_operation_by_name('conv2d/Conv2D')
        w2 = new_conv_op.inputs[0]
        feed_dict = {w2: numpy_data}

        new_relu_op = new_sess.graph.get_operation_by_name('Relu')
        output_after_fold = new_sess.run(new_relu_op.outputs[0],
                                         feed_dict=feed_dict)
        new_conn_graph = ConnectedGraph(new_sess.graph,
                                        starting_op_names=['input_1'],
                                        output_op_names=['Relu'])

        self.assertTrue(
            np.allclose(baseline_output, output_after_fold, atol=1.e-4))
        # New connected graph should have one less op since bn was removed
        self.assertEqual(len(old_conn_graph.get_all_ops()),
                         len(new_conn_graph.get_all_ops()) + 1)
    def find_all_convs_bn_with_activation(model,
                                          start_op_names: Union[List[str],
                                                                str],
                                          output_op_names: Union[List[str],
                                                                 str]):
        """
        Uses the graph searcher to select conv/linear ops along with their BN and activation info.
        :param model: tf.compat.v1.Session type
        :param start_op_names: list of strings with names of starting ops in the model
        :param output_op_names: List of output op names of the model, used to help ConnectedGraph determine valid ops
        (to ignore training ops for example).
        :return: dictionary of conv/linear layers with associated bn op / activation info
        """

        if isinstance(start_op_names, str):
            start_op_names = [start_op_names]

        if isinstance(output_op_names, str):
            output_op_names = [output_op_names]

        conn_graph = ConnectedGraph(model.graph, start_op_names,
                                    output_op_names)

        # create a list of patterns and corresponding handlers or actions to be applied for selecting
        # layers for bias correction.
        # layer_select_handler is an instance of custom handler created for bias correction.
        patterns_with_callback, layer_select_handler = BiasCorrection._conv_bn_select_custom_pattern_init(
        )

        # graph searcher looks for patterns and applies actions when matching patterns are found
        graph_searcher = GraphSearcher(conn_graph, patterns_with_callback)
        graph_searcher.find_all_patterns_in_graph_apply_actions()

        # use custom handler instance and fetch the selected layer info for bias correction
        convs_bn_activation_info_dict = layer_select_handler.get_conv_linear_bn_info_dict(
        )

        return convs_bn_activation_info_dict
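
    # A minimal usage sketch (an illustration, not from the original source).
    # Assumptions: this function is exposed as a static method on BiasCorrection,
    # as the body above suggests; `sess` is a live tf.compat.v1.Session holding a
    # trained model; the op names 'input_1' and 'model/Softmax' are hypothetical
    # placeholders.
    #
    #   convs_bn_info = BiasCorrection.find_all_convs_bn_with_activation(
    #       sess, start_op_names='input_1', output_op_names='model/Softmax')
    #   for conv_op, bn_activation_info in convs_bn_info.items():
    #       print(conv_op.name, bn_activation_info)
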
    def test_keras_model_functional_with_non_fused_batchnorms_get_op_product_graph(self):
        """ Test connected graph construction on keras model functional with non fused batchnorms """
        tf.compat.v1.reset_default_graph()

        _ = keras_model_functional_with_non_fused_batchnorms()
        conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), ['input_1'],
                                    ['keras_model_functional_with_non_fused_batchnorms/Softmax'])
        self.assertTrue(validate_branch_ops(conn_graph))
        self.assertTrue(validate_product_tensor_lists(conn_graph))
        _ = conn_graph.get_all_ops()['batch_normalization']
        _ = conn_graph.get_all_ops()['scope_1/batch_normalization_1']
        _ = conn_graph.get_all_ops()['scope_1/batch_normalization_2']
        self.assertEqual(0, conn_graph.branch_count)
        self.assertEqual(14, len(conn_graph.get_all_ops()))

        # 13 products from inter module connections
        # 22 products from parameters
        self.assertEqual(35, len(conn_graph.get_all_products()))

class GraphSearchUtils:

    """ Implements graph search utils required by CLE feature"""

    def __init__(self, model: tf.Graph, start_op_names: Union[str, List[str]], output_op_names: Union[str, List[str]]):
        if isinstance(start_op_names, str):
            start_op_names = [start_op_names]

        if isinstance(output_op_names, str):
            output_op_names = [output_op_names]

        self._connected_graph = ConnectedGraph(model, start_op_names, output_op_names)

    def find_and_replace_relu6_with_relu(self, sess: tf.compat.v1.Session) -> tf.compat.v1.Session:
        """
        finds and replaces Relu6 ops with Relu
        :return: updated session
        """
        for op in self._connected_graph.get_all_ops().values():
            if op.type in ['Relu6']:
                # send the session here, so we make the update on sess.graph (active graph)
                ReluUtils.replace_relu6_with_relu(sess, op.get_module())

        # in the end update the session
        after_relu_replace_sess = save_and_load_graph('./replace_relu6_with_relu', sess)

        return after_relu_replace_sess

    @staticmethod
    def find_downstream_layer_groups_to_scale(op, layer_groups, visited_nodes, current_group=None):
        """
        Populates all the layer groups eligible for cross layer scaling
        :param op: starting op
        :param layer_groups: layer_groups as empty list
        :param visited_nodes: all the ops that have been visited
        :param current_group: op groups
        :return: None. Updates layer_groups[] if groups are found.
        """

        if not current_group:
            current_group = []

        if op in visited_nodes:
            return

        visited_nodes.append(op)
        logger.debug("Visiting node: {%s}", op.dotted_name)

        # If current node is Conv2D, add to the current group
        if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
            current_group.append(op)

        # Terminating condition for current group
        if not (op.type in ['Conv2D', 'DepthwiseConv2dNative', 'Relu', 'PReLU', 'Pad', 'Identity']):
            if (len(current_group) > 1) and (current_group not in layer_groups):
                layer_groups.append(current_group)
                node_set = [op.dotted_name for op in current_group]
                logger.debug("Added new set of nodes: {%s}", node_set)
            current_group = []

        if op.output:
            for consumer in op.output.consumers:
                GraphSearchUtils.find_downstream_layer_groups_to_scale(consumer, layer_groups, visited_nodes,
                                                                       current_group)

        # Reached a leaf. See if the current group has something to grab
        if (len(current_group) > 1) and (current_group not in layer_groups):
            layer_groups.append(current_group)
            node_set = [op.dotted_name for op in current_group]
            logger.debug("Added new set of nodes: {%s}", node_set)

    def find_layer_groups_to_scale_as_conn_ops(self) -> List[List[Op]]:
        """
        :return: List of groups of layers. Each group can be independently equalized
        """

        # Find the input node(s) in the graph
        input_nodes = []
        for op in self._connected_graph.get_all_ops().values():
            if op.inputs and op.inputs[0].is_model_input:
                input_nodes.append(op)

        layer_groups = []
        visited_nodes = []

        for op in input_nodes:
            self.find_downstream_layer_groups_to_scale(op=op, layer_groups=layer_groups,
                                                       visited_nodes=visited_nodes)

        return layer_groups

    def find_layer_groups_to_scale(self):
        """
        Find layer groups for scaling as tf ops
        :return: groups for scaling as tf ops
        """

        layer_groups_as_conn_graph_ops = self.find_layer_groups_to_scale_as_conn_ops()
        layer_groups_as_tf_ops, tf_op_to_conn_graph_op_map = self.convert_conn_graph_ops_to_tf_op(layer_groups_as_conn_graph_ops)

        return tf_op_to_conn_graph_op_map, layer_groups_as_tf_ops

    @staticmethod
    def convert_conn_graph_ops_to_tf_op(op_groups: List[List[Op]]) -> \
            List[List[tf.Operation]]:
        """
         Helper function to get op list as tf.Operation type to be usable for updating/scaling weights and biases
         using generic apis for tensor updates.
        :param op_groups: list of op groups as TfOperation type of used by Connected Graph
        :return: lis of op groups as tf.Operation  (standard TF op type)
        """
        tf_op_to_conn_graph_op_map = {}
        layer_groups_as_tf_ops = []
        for ops in op_groups:
            curr_group = []
            for op in ops:
                tf_op_to_conn_graph_op_map[op.get_module()] = op
                curr_group.append(op.get_module())
            layer_groups_as_tf_ops.append(curr_group)

        return layer_groups_as_tf_ops, tf_op_to_conn_graph_op_map

    @staticmethod
    def convert_layer_group_to_cls_sets(layer_group):
        """
        Helper function to convert a layer group to a list of cls sets
        :param layer_group: Given layer group to convert
        :return: List of cls sets
        """
        cls_sets = []

        prev_layer_to_scale = layer_group.pop(0)
        while layer_group:
            next_layer_to_scale = layer_group.pop(0)

            if next_layer_to_scale.type in ['DepthwiseConv2dNative']:
                next_non_depthwise_conv_layer = layer_group.pop(0)
                cls_sets.append((prev_layer_to_scale, next_layer_to_scale, next_non_depthwise_conv_layer))
                prev_layer_to_scale = next_non_depthwise_conv_layer

            else:
                cls_sets.append((prev_layer_to_scale, next_layer_to_scale))
                prev_layer_to_scale = next_layer_to_scale

        return cls_sets
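
    # Worked example (hypothetical): given a layer group [conv1, dw, pw, conv2]
    # where dw.type == 'DepthwiseConv2dNative', this yields
    # cls_sets == [(conv1, dw, pw), (pw, conv2)]; the depthwise conv and the
    # pointwise conv that follows it are scaled together as a triplet.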

    @staticmethod
    def is_relu_activation_present_in_cls_sets(cls_sets: List[ClsSet],
                                               tf_op_to_conn_graph_op_map: Dict) -> List[bool]:
        """
        check if there is Relu activations between cls sets
        :param cls_sets: cls conv op pairs
        :param tf_op_to_conn_graph_op_map: Map of tf-op => connected graph op
        :return: list of relu activation preset flags(True or False)
        corresponding to input cls_sets list
        """
        is_relu_activation_in_cls_sets = []
        for cls_set in cls_sets:
            # We need to check activation functions for all layers but the last one in the set
            # Because we are only interested in checking activation functions between the layers we will scale
            cls_set = cls_set[:-1]

            is_relu_activation_in_cls_set = ()
            for conv_op in cls_set:
                conn_graph_conv_op = tf_op_to_conn_graph_op_map[conv_op]
                is_relu_activation_in_cls_set += (ReluUtils.does_conv_have_relu_activation(conn_graph_conv_op), )

            if len(is_relu_activation_in_cls_set) == 1:
                is_relu_activation_in_cls_set = is_relu_activation_in_cls_set[0]

            is_relu_activation_in_cls_sets.append(is_relu_activation_in_cls_set)

        return is_relu_activation_in_cls_sets
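
    # Example of the return shape (hypothetical): for cls_sets like
    # [(c1, c2), (c1, dw, pw)], the result could be [True, (True, False)]:
    # a single bool for a pair (one activation checked), and a tuple of
    # bools for a depthwise triplet (two activations checked).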

    @staticmethod
    def map_op_names_to_ops(sess: tf.compat.v1.Session) -> Dict[str, tf.Operation]:
        """
        After the fold and cls , the graph is updated, so are the ops
        So, we need a way to map ops we stored on graph we began with, to perform
        high bias fold operation on latest ops in the updated graph.
        :param sess: active tf.compat.v1.Session (tf.compat.v1.Session type)
        :return: a dictionary of op names mapped to ops in the given new session.
        Note : only stores infor pertaining to bn and conv ops required by high bias fold.
        """

        tf_names_op_dict = {}
        with sess.graph.as_default():
            op_list = sess.graph.get_operations()
            for op in op_list:
                if op.type in ['Conv2D', 'DepthwiseConv2dNative', 'FusedBatchNormV3']:
                    tf_names_op_dict[op.name] = op

        return tf_names_op_dict
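
# A minimal usage sketch for GraphSearchUtils (an illustration, not from the
# original source). Assumptions: `sess` is a tf.compat.v1.Session holding a
# trained model; 'input_1' and 'model/Softmax' are hypothetical op names.
#
#   search_utils = GraphSearchUtils(sess.graph, 'input_1', 'model/Softmax')
#   sess = search_utils.find_and_replace_relu6_with_relu(sess)
#   tf_op_map, layer_groups = search_utils.find_layer_groups_to_scale()
#   for group in layer_groups:
#       cls_sets = GraphSearchUtils.convert_layer_group_to_cls_sets(group)
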
class MaskPropagationWinnower(aimetCommonMaskPropagationWinnower):
    """
    The MaskPropagationWinnower class implements winnowing based on propagating masks corresponding to
    each module's input channels identified to be winnowed.
    """
    def __init__(self,
                 sess: tf.compat.v1.Session,
                 input_op_names: List[str],
                 output_op_names: List[str],
                 list_of_modules_to_winnow: List[Tuple[tf.Operation,
                                                       List]] = None,
                 reshape=True,
                 in_place=False,
                 verbose=False):
        """
        MaskPropagationWinnower object initialization.

        :param sess: The model to be winnowed.
        :param input_op_names: Input operations to the model.
        :param output_op_names: List of output op names of the model, used to help ConnectedGraph determine valid ops
        (to ignore training ops for example).
        :param list_of_modules_to_winnow: A list of Tuples with each Tuple containing a module and a list of
                    channels to be winnowed for that module.
        :param reshape: If set to True, a Down Sample Layer is added between modules to match the number of channels.
                    If set to False, the modules that need a Down Sample Layer will not be winnowed.
        :param in_place: If set to True, the model will be winnowed in place.
                     If set to False, a copy of the model will be winnowed.
        :param verbose: If set to True, logs detailed winnowing log messages.
        """

        super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)

        debug_level = logger.getEffectiveLevel()
        logger.debug("Current log level: %s", debug_level)

        self._conn_graph = ConnectedGraph(sess.graph, input_op_names,
                                          output_op_names)
        self._modules_by_name = None
        self._mask_propagator = MaskPropagator(self._conn_graph,
                                               model_api=ModelApi.tensorflow)
        self._module_reducer = ModuleReducer(
            self._conn_graph,
            sess,
            using_cuda=False,
            reshape=False,
            op_to_mask_dict=self._mask_propagator.op_to_mask_dict)

    def propagate_masks_and_winnow(self):
        """  For the modules to be winnowed, create and propagate the masks.
        Once mask propagation is completed, winnow the model. """

        # Propagate the masks
        self._propagate_masks()

        modified_op_list = self._mask_propagator.get_ops_with_non_default_ip_op_masks(
        )
        for name in modified_op_list:
            logger.info("Modified Op: %s", name)

        new_sess, modified_modules_dict = self._module_reducer.reduce_modules()

        if modified_modules_dict:
            ordered_module_list = self._create_modified_modules_list(
                modified_modules_dict)
        else:
            ordered_module_list = None
            logger.info(
                "No modules were winnowed. Original model is returned.")

        return new_sess, ordered_module_list

    def _propagate_masks(self):
        """  For the modules to be winnowed, set the channels to winnow and propagate the masks."""

        for tf_op, list_of_channels_to_winnow in self._list_of_modules_to_winnow:
            self.validate_winnow_api_parameters(tf_op,
                                                list_of_channels_to_winnow)

            input_channels_to_winnow = list_of_channels_to_winnow
            output_channels_to_winnow = None
            if tf_op.type in ['Conv2D', 'Dense']:
                self._mask_propagator.update_channels_to_winnow(
                    tf_op.name, self._reshape, input_channels_to_winnow,
                    output_channels_to_winnow)

        # The channels to winnow have been updated
        # Propagate the masks.
        self._mask_propagator.propagate_masks()

    @staticmethod
    def _create_modified_modules_list(
            modified_modules: Dict[str, Tuple[tf.Operation, Mask]]):
        """
        Creates and returns a list of tuples, each containing the original module's name, its replacement
        module, and the replacement's input and output channel masks
        :param modified_modules: Dictionary mapping names of ops before winnowing to a tuple of
        (new module, op mask)

        """

        modified_module_list = []

        for orig_module_name, (new_module,
                               op_mask) in modified_modules.items():
            # Remove prefix of the model name
            # E.g. the module_name maybe Net.layer1.conv1, we only want layer1.conv1
            first_dot_position = orig_module_name.find('.')
            if first_dot_position != -1:
                orig_module_name = orig_module_name[first_dot_position + 1:]
            modified_module_list.append(
                (orig_module_name, new_module, op_mask.input_channel_masks,
                 op_mask.output_channel_masks))

        return modified_module_list

    def validate_winnow_api_parameters(self, tf_op: tf.Operation,
                                       list_of_channels_to_winnow: List[int]):
        """

        For a given module, validate Winnow API parameters.

        :param tf_op: tf operation whose channel numbers are being validated.
        :param list_of_channels_to_winnow: list of channels to be winnowed.
        :return:
        """

        if not tf_op.type == "Conv2D":
            logger.critical(
                "Winnowing is currently only supported for Conv2d modules. Attempting to winnow "
                "module of type %s", tf_op.type)
            raise NotImplementedError(tf_op.type)

        # Validate the list of channels.
        num_channels_to_winnow = len(list_of_channels_to_winnow)
        if num_channels_to_winnow == 0:
            raise ValueError(
                "The list of channels to winnow is empty for the module: %s" %
                tf_op.name)

        module_op = self._conn_graph.get_op_from_module_name(tf_op.name)

        max_channel_num = max(list_of_channels_to_winnow)
        max_in_channel_index = module_op.num_in_channels - 1
        if max_channel_num > max_in_channel_index:
            raise ValueError(
                "Channel number: %s exceeds module's max channel number index: %s for module: %s"
                % (max_channel_num, max_in_channel_index, tf_op.name))

        if num_channels_to_winnow == module_op.num_in_channels:
            raise ValueError(
                "Winnowing all the input channels is not allowed, module: %s" %
                tf_op.name)

        input_index = 0  # Using op index 0 to examine input to op
        if module_op.inputs[input_index].is_model_input:
            logger.critical(
                "Winnowing the first module of a model is NOT supported. Please ignore the first "
                "module and try again. First module: %s, shape %s, channels to winnow: %s",
                module_op.dotted_name, module_op.inputs[input_index].shape,
                list_of_channels_to_winnow)
            raise NotImplementedError(module_op.dotted_name)
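
# A minimal usage sketch for MaskPropagationWinnower (an illustration, not from
# the original source). Assumptions: `sess` is a tf.compat.v1.Session with a
# trained model; 'conv2d_1/Conv2D', 'input_1' and 'model/Softmax' are
# hypothetical op names, and the channel indices [2, 3] are arbitrary.
#
#   conv_op = sess.graph.get_operation_by_name('conv2d_1/Conv2D')
#   winnower = MaskPropagationWinnower(sess, ['input_1'], ['model/Softmax'],
#                                      list_of_modules_to_winnow=[(conv_op, [2, 3])],
#                                      reshape=True, in_place=False, verbose=True)
#   new_sess, ordered_modules = winnower.propagate_masks_and_winnow()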