def test_concat_get_op_product_graph(self):
    """ Verify ConnectedGraph construction for a model containing a concat op """
    tf.compat.v1.reset_default_graph()
    _ = concat_model()
    graph = tf.compat.v1.get_default_graph()
    conn_graph = ConnectedGraph(graph, ['input_1'], ['concat_model/Softmax'])

    self.assertTrue(validate_branch_ops(conn_graph))
    self.assertTrue(validate_product_tensor_lists(conn_graph))
    self.assertEqual(2, conn_graph.branch_count)
    self.assertEqual(13, len(conn_graph.get_all_ops()))

    # 12 inter-op products plus one product per variable in the graph
    num_variables = len(graph.get_collection('variables'))
    self.assertEqual(12 + num_variables, len(conn_graph.get_all_products()))

    # The input products of the connected-graph concat op must be ordered exactly like the
    # input tensors of the underlying tf concat op
    concat_tf_op = graph.get_operation_by_name("concatenate/concat")
    concat_conn_op = conn_graph.get_all_ops()['concatenate/concat']
    for input_idx, in_product in enumerate(concat_conn_op.get_input_products()):
        self.assertTrue(len(in_product.consumers) == 1)
        only_consumer = in_product.consumers[0]
        self.assertEqual(in_product.tensor_dict[only_consumer], concat_tf_op.inputs[input_idx])
def test_model_with_PReLU(self):
    """
    Test connected graph construction on a model with a PReLU layer followed by a ReLU layer.

    The connected graph is expected to contain 3 ops, and the PReLU op ('p_re_lu/Relu') is
    expected to group 5 internal tf ops.
    """
    tf.compat.v1.reset_default_graph()
    inputs = tf.keras.Input(shape=(1, 10, 10), name="inputs")
    x = tf.keras.layers.PReLU()(inputs)
    _ = tf.keras.layers.ReLU()(x)
    init = tf.compat.v1.global_variables_initializer()
    # The session is only needed to run the initializer; close it so it does not leak.
    # ConnectedGraph below works on the default graph, not on the session.
    with tf.compat.v1.Session() as sess:
        sess.run(init)

    conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['inputs'],
                                output_op_names=['re_lu/Relu'])
    self.assertEqual(3, len(conn_graph.get_all_ops()))
    self.assertEqual(5, len(conn_graph.get_all_ops()['p_re_lu/Relu'].internal_ops))
def __init__(self, sess: tf.compat.v1.Session, input_op_names: List[str], output_op_names: List[str],
             list_of_modules_to_winnow: List[Tuple[tf.Operation, List]] = None, reshape=True,
             in_place=False, verbose=False):
    """
    Set up a mask-propagation based winnower for the graph held by a TensorFlow session.

    :param sess: The model to be winnowed.
    :param input_op_names: Input operations to the model.
    :param output_op_names: List of output op names of the model, used to help ConnectedGraph determine valid ops
        (to ignore training ops for example).
    :param list_of_modules_to_winnow: A list of Tuples with each Tuple containing a module and a list of
        channels to be winnowed for that module.
    :param reshape: If set to True a Down Sample Layer is added between modules to match the number of channels.
        If set to False, the modules that need a Down Sample Layer will not be winnowed.
    :param in_place: If set to True, the model will be winnowed in place.
        If set to False, a copy of the model will be winnowed.
    :param verbose: If set to True, logs detailed winnowing log messages.
    """
    super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)
    logger.debug("Current log level: %s", logger.getEffectiveLevel())

    self._modules_by_name = None
    # Connected graph of the session's graph, bounded by the given input/output ops
    self._conn_graph = ConnectedGraph(sess.graph, input_op_names, output_op_names)

    propagator = MaskPropagator(self._conn_graph, model_api=ModelApi.tensorflow)
    self._mask_propagator = propagator
    # The reducer shares the propagator's op -> mask mapping. Note: reshape is always passed as
    # False here, independent of the `reshape` argument above.
    self._module_reducer = ModuleReducer(self._conn_graph, sess, using_cuda=False, reshape=False,
                                         op_to_mask_dict=propagator.op_to_mask_dict)
def test_model_with_global_max_pool2d(self):
    """ Test connected graph construction on model with a GlobalMaxPool2D op """
    tf.compat.v1.reset_default_graph()
    _ = model_with_global_max_pool2d()
    conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['input_1'],
                                output_op_names=['model_with_global_max_pool2d/Softmax'])
    self.assertTrue(validate_branch_ops(conn_graph))
    self.assertTrue(validate_product_tensor_lists(conn_graph))
    self.assertEqual(0, conn_graph.branch_count)
    self.assertEqual(6, len(conn_graph.get_all_ops()))

    # 5 products from inter module connections
    # 4 products from parameters
    self.assertEqual(9, len(conn_graph.get_all_products()))

    # The global max pool must be recognized as its own op type
    self.assertTrue(any(op.type == 'GlobalMaxPool2D' for op in conn_graph.get_all_ops().values()))
    tf.compat.v1.reset_default_graph()
def test_model_with_leaky_relu(self):
    """ Build a model containing a LeakyRelu op and check the resulting connected graph """
    tf.compat.v1.reset_default_graph()
    _ = model_with_leaky_relu()
    conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(),
                                starting_op_names=['input_1'],
                                output_op_names=['model_with_leaky_relu/Softmax'])

    self.assertTrue(validate_branch_ops(conn_graph))
    self.assertTrue(validate_product_tensor_lists(conn_graph))
    self.assertEqual(0, conn_graph.branch_count)
    self.assertEqual(7, len(conn_graph.get_all_ops()))
    # 6 inter-module products + 6 parameter products
    self.assertEqual(12, len(conn_graph.get_all_products()))

    has_leaky_relu = any(conn_op.type == 'LeakyRelu' for conn_op in conn_graph.get_all_ops().values())
    self.assertTrue(has_leaky_relu)
    tf.compat.v1.reset_default_graph()
def test_model_with_lstm_layer_deepspeech_time_major_true_sigmoid(self):
    """ Check that a time-major, sigmoid-activated stacked LSTM (as in DeepSpeech) is detected by ConnectedGraph """
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    with sess.graph.as_default():
        model_input = tf.keras.Input(shape=(3, 100))
        # Keras LSTM defaults are return_state=False, unit_forget_bias=True,
        # return_sequences=False, time_major=False
        stacked_out = tf.keras.layers.LSTM(12, recurrent_activation='sigmoid', unit_forget_bias=False,
                                           time_major=True, return_sequences=True,
                                           name='lstm_stacked')(model_input)
        last_out = tf.keras.layers.LSTM(12, name='last_lstm')(stacked_out)
        _ = tf.keras.layers.Dense(12, activation=tf.nn.softmax, name="matmul0")(last_out)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

    conn_graph = ConnectedGraph(sess.graph, ['input_1'], ['matmul0/Softmax'])
    self.assertEqual(5, len(conn_graph.get_all_ops()))

    stacked_lstm_ops = [conn_op for conn_op in conn_graph.get_all_ops().values()
                        if conn_op.type == 'LSTM' and conn_op.name == 'lstm_stacked']
    for lstm_op in stacked_lstm_ops:
        self.assertEqual(lstm_op.pattern_type, 'LSTM_Stacked_TimeMajor_True_Sigmoid')
        self.assertEqual(75, len(lstm_op.internal_ops))
        self.assertEqual(lstm_op.get_module(),
                         sess.graph.get_operation_by_name('lstm_stacked/while/MatMul'))
    # At least one matching stacked LSTM op must have been found
    self.assertTrue(stacked_lstm_ops)
def test_prune_model_tf_slim(self):
    """ Pruning a model built with the tf slim api """
    # Session over a fresh graph, with GPU memory growth enabled
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(graph=tf.Graph(), config=session_config)

    # Build the model inside the session's graph and initialize its variables
    with sess.graph.as_default():
        model_input = tf.compat.v1.placeholder(tf.float32, [1, 32, 32, 3])
        _ = tf_slim_basic_model(model_input)
        sess.run(tf.compat.v1.global_variables_initializer())

    graph_before = ConnectedGraph(sess.graph, ['Placeholder'], ['tf_slim_model/Softmax'])
    num_ops_before = len(graph_before.get_all_ops())

    # Create a layer database and remember Conv_1's bias before compression
    orig_layer_db = LayerDatabase(model=sess, input_shape=(1, 32, 32, 3), working_dir=None)
    conv1_layer = orig_layer_db.find_layer_by_name('Conv_1/Conv2D')
    conv1_bias = BiasUtils.get_bias_as_numpy_data(orig_layer_db.model, conv1_layer.module)

    pruner = SpatialSvdPruner()
    ratio_pairs = [LayerCompRatioPair(conv1_layer, Decimal(0.5))]
    comp_layer_db = pruner.prune_model(orig_layer_db, ratio_pairs, CostMetric.mac, trainer=None)
    pruned_graph = comp_layer_db.model.graph

    # Spatial SVD should have added these ops; the lookup raises if either is missing
    _ = pruned_graph.get_operation_by_name('Conv_1_a/Conv2D')
    _ = pruned_graph.get_operation_by_name('Conv_1_b/Conv2D')

    graph_after = ConnectedGraph(pruned_graph, ['Placeholder'], ['tf_slim_model/Softmax'])
    self.assertEqual(num_ops_before + 1, len(graph_after.get_all_ops()))

    # The BiasAdd must belong to the same connected-graph op as Conv_1_b, and that op must
    # carry the original Conv_1 bias
    bias_add_op = pruned_graph.get_operation_by_name('Conv_1_b/BiasAdd')
    conv_1_b_op = pruned_graph.get_operation_by_name('Conv_1_b/Conv2D')
    module_identifier = graph_after._module_identifier
    self.assertEqual(module_identifier.get_op_info(bias_add_op), module_identifier.get_op_info(conv_1_b_op))
    pruned_bias = BiasUtils.get_bias_as_numpy_data(comp_layer_db.model, conv_1_b_op)
    self.assertTrue(np.array_equal(conv1_bias, pruned_bias))
def test_model_zoo_videnn_pose_estimation_model_with_input_split(self):
    """
    Create a smaller network with connections as in the pose estimation and ViDeNN models.
    Test when the model input is split and fed into two different ops.
    """
    tf.compat.v1.reset_default_graph()
    inputs = tf.keras.Input(shape=(None, None, 2), name="inputs")
    x = tf.keras.layers.Conv2D(2, kernel_size=3, padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.nn.relu(x)
    x = tf.keras.layers.Conv2D(2, kernel_size=3, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    # Residual connection: the model input also feeds the Add directly
    z = tf.keras.layers.Add()([inputs, x])
    _ = tf.nn.relu(z)
    init = tf.compat.v1.global_variables_initializer()
    # Only needed to run the initializer; close it so it does not leak
    with tf.compat.v1.Session() as sess:
        sess.run(init)

    conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), starting_op_names=['inputs'],
                                output_op_names=['Relu_1'])

    # Ops whose first input is a model input
    input_node_names = sorted(op.name for op in conn_graph.get_all_ops().values()
                              if op.inputs and op.inputs[0].is_model_input)

    # Exactly two ops consume the model input. Names are compared sorted so the check does
    # not depend on the iteration order of get_all_ops().
    self.assertEqual(['add/add', 'conv2d/Conv2D'], input_node_names)
def test_bn_fold_with_no_bias(self):
    """
    Fold a BatchNorm into a preceding Conv2D that has no bias.

    Checks that the model output is unchanged by the fold and that the connected graph loses
    exactly one op (the BatchNorm).
    """
    tf.compat.v1.reset_default_graph()
    inputs = tf.keras.Input(shape=(32, 32, 3,))
    conv_out = tf.keras.layers.Conv2D(32, (3, 3), use_bias=False)(inputs)
    bn_out = tf.keras.layers.BatchNormalization(fused=True)(conv_out, training=False)
    _ = tf.nn.relu(bn_out)
    init = tf.compat.v1.global_variables_initializer()
    sess = tf.compat.v1.Session()
    sess.run(init)

    conv_op = sess.graph.get_operation_by_name('conv2d/Conv2D')
    np.random.seed(0)
    # inputs[0] of a tf Conv2D op is its data input (not the kernel)
    inp_shape = conv_op.inputs[0].shape
    numpy_data = np.random.rand(1, inp_shape[1], inp_shape[2], inp_shape[3])
    relu_op = sess.graph.get_operation_by_name('Relu')
    baseline_output = sess.run(relu_op.outputs[0], feed_dict={conv_op.inputs[0]: numpy_data})
    old_conn_graph = ConnectedGraph(sess.graph, starting_op_names=['input_1'], output_op_names=['Relu'])

    new_sess, _ = fold_all_batch_norms(sess, "input_1", 'Relu')

    new_conv_op = new_sess.graph.get_operation_by_name('conv2d/Conv2D')
    new_relu_op = new_sess.graph.get_operation_by_name('Relu')
    output_after_fold = new_sess.run(new_relu_op.outputs[0],
                                     feed_dict={new_conv_op.inputs[0]: numpy_data})
    new_conn_graph = ConnectedGraph(new_sess.graph, starting_op_names=['input_1'], output_op_names=['Relu'])

    self.assertTrue(np.allclose(baseline_output, output_after_fold, atol=1.e-4))
    # New connected graph should have one less op since bn was removed.
    # (assertTrue(a, b) would treat b as the failure message and never compare the counts.)
    self.assertEqual(len(old_conn_graph.get_all_ops()) - 1, len(new_conn_graph.get_all_ops()))
def find_all_convs_bn_with_activation(model, start_op_names: Union[List[str], str],
                                      output_op_names: Union[List[str], str]):
    """
    Select conv/linear layers together with their associated bn op and activation info,
    using a graph searcher over the model's connected graph.

    :param model: tf.compat.v1.Session type
    :param start_op_names: name, or list of names, of the starting ops in the model
    :param output_op_names: name, or list of names, of the output ops of the model, used to help ConnectedGraph
        determine valid ops (to ignore training ops for example).
    :return: dictionary of conv/linear layers with associated bn op / activation info
    """
    # Accept a single op name as a convenience and normalize to lists
    starting_ops = [start_op_names] if isinstance(start_op_names, str) else start_op_names
    output_ops = [output_op_names] if isinstance(output_op_names, str) else output_op_names

    conn_graph = ConnectedGraph(model.graph, starting_ops, output_ops)

    # Pattern -> callback pairs, plus the custom handler that records selected layers for bias correction
    patterns_with_callback, layer_select_handler = BiasCorrection._conv_bn_select_custom_pattern_init()

    # Walk the graph; matching patterns invoke their callbacks, which populate the handler
    searcher = GraphSearcher(conn_graph, patterns_with_callback)
    searcher.find_all_patterns_in_graph_apply_actions()

    return layer_select_handler.get_conv_linear_bn_info_dict()
def test_keras_model_functional_with_non_fused_batchnorms_get_op_product_graph(self):
    """ Check the ConnectedGraph built from a functional keras model whose batchnorms are not fused """
    tf.compat.v1.reset_default_graph()
    _ = keras_model_functional_with_non_fused_batchnorms()
    conn_graph = ConnectedGraph(tf.compat.v1.get_default_graph(), ['input_1'],
                                ['keras_model_functional_with_non_fused_batchnorms/Softmax'])
    self.assertTrue(validate_branch_ops(conn_graph))
    self.assertTrue(validate_product_tensor_lists(conn_graph))

    # Every non-fused batchnorm must be present among the connected-graph ops; the lookup fails otherwise
    expected_bn_names = ('batch_normalization',
                         'scope_1/batch_normalization_1',
                         'scope_1/batch_normalization_2')
    for bn_name in expected_bn_names:
        _ = conn_graph.get_all_ops()[bn_name]

    self.assertEqual(0, conn_graph.branch_count)
    self.assertEqual(14, len(conn_graph.get_all_ops()))
    # 13 inter-module products plus 22 parameter products
    self.assertEqual(35, len(conn_graph.get_all_products()))
class GraphSearchUtils:
    """ Implements graph search utils required by CLE feature"""

    def __init__(self, model: tf.Graph, start_op_names: Union[str, List[str]],
                 output_op_names: Union[str, List[str]]):
        """
        Build the connected graph that all searches in this class operate on.

        :param model: tf graph to search
        :param start_op_names: name, or list of names, of the starting op(s) of the model
        :param output_op_names: name, or list of names, of the output op(s) of the model
        """
        if isinstance(start_op_names, str):
            start_op_names = [start_op_names]
        if isinstance(output_op_names, str):
            output_op_names = [output_op_names]
        self._connected_graph = ConnectedGraph(model, start_op_names, output_op_names)

    def find_and_replace_relu6_with_relu(self, sess: tf.compat.v1.Session) -> tf.compat.v1.Session:
        """
        Finds and replaces Relu6 ops with Relu.

        :param sess: session whose graph contains the Relu6 ops to replace
        :return: updated session
        """
        for op in self._connected_graph.get_all_ops().values():
            if op.type == 'Relu6':
                # send the session here, so we make the update on sess.graph (active graph)
                ReluUtils.replace_relu6_with_relu(sess, op.get_module())

        # in the end update the session
        after_relu_replace_sess = save_and_load_graph('./replace_relu6_with_relu', sess)
        return after_relu_replace_sess

    @staticmethod
    def find_downstream_layer_groups_to_scale(op, layer_groups, visited_nodes, current_group=None):
        """
        Populates all the layer groups eligible for cross layer scaling.

        :param op: starting op
        :param layer_groups: list that collects the groups found (pass an empty list)
        :param visited_nodes: all the ops that have been visited
        :param current_group: group of conv ops collected so far along the current path
        :return: None. Updates layer_groups[] if groups are found.
        """
        # A missing or empty group is replaced by a fresh list. A group is therefore only shared
        # with the recursive calls below once it already holds at least one op.
        if not current_group:
            current_group = []

        if op in visited_nodes:
            return
        visited_nodes.append(op)
        logger.debug("Visiting node: {%s}", op.dotted_name)

        # If current node is a (depthwise) conv, add it to the current group
        if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
            current_group.append(op)

        # Terminating condition for current group: any op other than a conv or one of the
        # pass-through op types listed here ends the group
        if op.type not in ['Conv2D', 'DepthwiseConv2dNative', 'Relu', 'PReLU', 'Pad', 'Identity']:
            if (len(current_group) > 1) and (current_group not in layer_groups):
                layer_groups.append(current_group)
                node_set = [group_op.dotted_name for group_op in current_group]
                logger.debug("Added new set of nodes: {%s}", node_set)
            current_group = []

        if op.output:
            for consumer in op.output.consumers:
                GraphSearchUtils.find_downstream_layer_groups_to_scale(consumer, layer_groups,
                                                                       visited_nodes, current_group)

        # Reached a leaf.. See if the current group has something to grab
        if (len(current_group) > 1) and (current_group not in layer_groups):
            layer_groups.append(current_group)
            node_set = [group_op.dotted_name for group_op in current_group]
            logger.debug("Added new set of nodes: {%s}", node_set)

    def find_layer_groups_to_scale_as_conn_ops(self) -> List[List[Op]]:
        """
        :return: List of groups of layers. Each group can be independently equalized
        """
        # Find the input node(s) in the graph: ops whose first input is a model input
        input_nodes = []
        for op in self._connected_graph.get_all_ops().values():
            if op.inputs and op.inputs[0].is_model_input:
                input_nodes.append(op)

        layer_groups = []
        visited_nodes = []
        for op in input_nodes:
            self.find_downstream_layer_groups_to_scale(op=op, layer_groups=layer_groups,
                                                       visited_nodes=visited_nodes)

        return layer_groups

    def find_layer_groups_to_scale(self):
        """
        Find layer groups for scaling as tf ops.

        :return: tuple of (map of tf.Operation -> connected graph op, layer groups as lists of tf ops)
        """
        layer_groups_as_conn_graph_ops = self.find_layer_groups_to_scale_as_conn_ops()
        layer_groups_as_tf_ops, tf_op_to_conn_graph_op_map = \
            self.convert_conn_graph_ops_to_tf_op(layer_groups_as_conn_graph_ops)

        return tf_op_to_conn_graph_op_map, layer_groups_as_tf_ops

    @staticmethod
    def convert_conn_graph_ops_to_tf_op(op_groups: List[List[Op]]) -> tuple:
        """
        Helper function to get op list as tf.Operation type to be usable for updating/scaling weights and biases
        using generic apis for tensor updates.

        :param op_groups: list of op groups as the Op type used by Connected Graph
        :return: tuple of (list of op groups as tf.Operation (standard TF op type),
            dict mapping each tf.Operation to its connected graph op)
        """
        tf_op_to_conn_graph_op_map = {}
        layer_groups_as_tf_ops = []
        for ops in op_groups:
            curr_group = []
            for op in ops:
                tf_op_to_conn_graph_op_map[op.get_module()] = op
                curr_group.append(op.get_module())
            layer_groups_as_tf_ops.append(curr_group)

        return layer_groups_as_tf_ops, tf_op_to_conn_graph_op_map

    @staticmethod
    def convert_layer_group_to_cls_sets(layer_group):
        """
        Helper function to convert a layer group to a list of cls sets.

        Consecutive layers form (prev, next) pairs. A depthwise conv is grouped with the layer on
        each side of it as a (prev, depthwise, next) triplet. A depthwise conv at the very end of
        the group has no following layer and is skipped.

        :param layer_group: Given layer group to convert. The list is not modified.
        :return: List of cls sets (empty for an empty group)
        """
        cls_sets = []
        if not layer_group:
            return cls_sets

        no_more_layers = object()
        remaining_layers = iter(layer_group)
        prev_layer_to_scale = next(remaining_layers)
        for next_layer_to_scale in remaining_layers:
            if next_layer_to_scale.type in ['DepthwiseConv2dNative']:
                # Consume the layer after the depthwise conv from the same iterator
                next_non_depthwise_conv_layer = next(remaining_layers, no_more_layers)
                if next_non_depthwise_conv_layer is no_more_layers:
                    # Trailing depthwise conv: nothing to scale it against
                    break
                cls_sets.append((prev_layer_to_scale, next_layer_to_scale, next_non_depthwise_conv_layer))
                prev_layer_to_scale = next_non_depthwise_conv_layer
            else:
                cls_sets.append((prev_layer_to_scale, next_layer_to_scale))
                prev_layer_to_scale = next_layer_to_scale

        return cls_sets

    @staticmethod
    def is_relu_activation_present_in_cls_sets(cls_sets: List[ClsSet], tf_op_to_conn_graph_op_map: Dict) \
            -> List[bool]:
        """
        Check if there are Relu activations between cls sets.

        :param cls_sets: cls conv op pairs (or triplets)
        :param tf_op_to_conn_graph_op_map: Map of tf-op => connected graph op
        :return: one entry per cls set. The entry is a single bool when the set has one layer before
            the last, otherwise a tuple of bools (one per layer except the last).
        """
        is_relu_activation_in_cls_sets = []
        for cls_set in cls_sets:
            # We need to check activation functions for all layers but the last one in the set,
            # because we are only interested in activation functions between the layers we will scale
            is_relu_activation_in_cls_set = tuple(
                ReluUtils.does_conv_have_relu_activation(tf_op_to_conn_graph_op_map[conv_op])
                for conv_op in cls_set[:-1])

            if len(is_relu_activation_in_cls_set) == 1:
                is_relu_activation_in_cls_set = is_relu_activation_in_cls_set[0]

            is_relu_activation_in_cls_sets.append(is_relu_activation_in_cls_set)

        return is_relu_activation_in_cls_sets

    @staticmethod
    def map_op_names_to_ops(sess: tf.compat.v1.Session) -> Dict[str, tf.Operation]:
        """
        After the fold and cls, the graph is updated, and so are the ops.
        So we need a way to map the ops stored for the graph we began with to the ops in the
        updated graph, in order to perform the high bias fold operation on the latest ops.

        :param sess: active tf.compat.v1.Session (tf.compat.v1.Session type)
        :return: a dictionary of op names mapped to ops in the given new session.
            Note: only stores info pertaining to bn and conv ops required by high bias fold.
        """
        with sess.graph.as_default():
            tf_names_op_dict = {op.name: op for op in sess.graph.get_operations()
                                if op.type in ['Conv2D', 'DepthwiseConv2dNative', 'FusedBatchNormV3']}

        return tf_names_op_dict
class MaskPropagationWinnower(aimetCommonMaskPropagationWinnower):
    """ The MaskPropagationWinnower class implements winnowing based on propagating masks corresponding to each
    module's input channels identified to be winnowed. """

    def __init__(self, sess: tf.compat.v1.Session, input_op_names: List[str], output_op_names: List[str],
                 list_of_modules_to_winnow: List[Tuple[tf.Operation, List]] = None, reshape=True,
                 in_place=False, verbose=False):
        """
        MaskPropagationWinnower object initialization.

        :param sess: The model to be winnowed.
        :param input_op_names: Input operations to the model.
        :param output_op_names: List of output op names of the model, used to help ConnectedGraph determine valid ops
            (to ignore training ops for example).
        :param list_of_modules_to_winnow: A list of Tuples with each Tuple containing a module and a list of
            channels to be winnowed for that module.
        :param reshape: If set to True a Down Sample Layer is added between modules to match the number of channels.
            If set to False, the modules that need a Down Sample Layer will not be winnowed.
        :param in_place: If set to True, the model will be winnowed in place.
            If set to False, a copy of the model will be winnowed.
        :param verbose: If set to True, logs detailed winnowing log messages.
        """
        super().__init__(list_of_modules_to_winnow, reshape, in_place, verbose)
        debug_level = logger.getEffectiveLevel()
        logger.debug("Current log level: %s", debug_level)

        self._conn_graph = ConnectedGraph(sess.graph, input_op_names, output_op_names)
        self._modules_by_name = None
        self._mask_propagator = MaskPropagator(self._conn_graph, model_api=ModelApi.tensorflow)
        # Note: reshape is always passed as False to the reducer, independent of `reshape` above
        self._module_reducer = ModuleReducer(self._conn_graph, sess, using_cuda=False, reshape=False,
                                             op_to_mask_dict=self._mask_propagator.op_to_mask_dict)

    def propagate_masks_and_winnow(self):
        """
        For the modules to be winnowed, create and propagate the masks.
        Once mask propagation is completed, winnow the model.

        :return: Tuple of (session returned by the module reducer, list of modified-module tuples as built by
            _create_modified_modules_list, or None when no modules were winnowed)
        """
        # Propagate the masks
        self._propagate_masks()

        modified_op_list = self._mask_propagator.get_ops_with_non_default_ip_op_masks()
        for name in modified_op_list:
            logger.info("Modified Op: %s", name)

        new_sess, modified_modules_dict = self._module_reducer.reduce_modules()

        if modified_modules_dict:
            ordered_module_list = self._create_modified_modules_list(modified_modules_dict)
        else:
            ordered_module_list = None
            logger.info("No modules were winnowed. Original model is returned.")

        return new_sess, ordered_module_list

    def _propagate_masks(self):
        """ For the modules to be winnowed, set the channels to winnow and propagate the masks."""
        for tf_op, list_of_channels_to_winnow in self._list_of_modules_to_winnow:
            self.validate_winnow_api_parameters(tf_op, list_of_channels_to_winnow)

            input_channels_to_winnow = list_of_channels_to_winnow
            output_channels_to_winnow = None
            if tf_op.type in ['Conv2D', 'Dense']:
                self._mask_propagator.update_channels_to_winnow(tf_op.name, self._reshape, input_channels_to_winnow,
                                                                output_channels_to_winnow)

        # The channels to winnow have been updated
        # Propagate the masks.
        self._mask_propagator.propagate_masks()

    @staticmethod
    def _create_modified_modules_list(modified_modules: Dict[str, Tuple[tf.Operation, Mask]]):
        """
        Creates and returns a list of tuples with each tuple containing the original module and its replacement module.

        :param modified_modules: Dictionary mapping names of ops before winnow to a tuple of
            (name after winnow, op mask)
        :return: list of (original module name without its leading model prefix, new module,
            input channel masks, output channel masks)
        """
        modified_module_list = []

        for orig_module_name, (new_module, op_mask) in modified_modules.items():
            # Remove prefix of the model name
            # E.g. the module_name maybe Net.layer1.conv1, we only want layer1.conv1
            first_dot_position = orig_module_name.find('.')
            if first_dot_position != -1:
                orig_module_name = orig_module_name[first_dot_position + 1:]
            modified_module_list.append((orig_module_name, new_module, op_mask.input_channel_masks,
                                         op_mask.output_channel_masks))

        return modified_module_list

    def validate_winnow_api_parameters(self, tf_op: tf.Operation, list_of_channels_to_winnow: List[int]):
        """
        For a given module, validate Winnow API parameters.

        :param tf_op: tf operation whose channel numbers are being validated.
        :param list_of_channels_to_winnow: list of channels to be winnowed.
        :raises NotImplementedError: if the op is not a Conv2D, or if it is the first module of the model
        :raises ValueError: if the channel list is empty, holds a negative or out-of-range channel index,
            or covers every input channel of the module
        :return:
        """
        if tf_op.type != "Conv2D":
            logger.critical("Winnowing is currently only supported for Conv2d modules. Attempting to winnow "
                            "module of type %s", tf_op.type)
            raise NotImplementedError(tf_op.type)

        # Validate the list of channels.
        num_channels_to_winnow = len(list_of_channels_to_winnow)
        if num_channels_to_winnow == 0:
            raise ValueError("The list of channels to winnow is empty for the module: %s" % tf_op.name)

        module_op = self._conn_graph.get_op_from_module_name(tf_op.name)

        min_channel_num = min(list_of_channels_to_winnow)
        if min_channel_num < 0:
            raise ValueError("Channel number: %s is negative for module: %s" % (min_channel_num, tf_op.name))

        max_channel_num = max(list_of_channels_to_winnow)
        max_in_channel_index = module_op.num_in_channels - 1
        if max_channel_num > max_in_channel_index:
            raise ValueError("Channel number: %s exceeds module's max channel number index: %s for module: %s" %
                             (max_channel_num, max_in_channel_index, tf_op.name))

        # Count distinct channels: duplicate indices must neither hide an attempt to winnow every
        # channel nor trigger this error when some channels are actually kept
        if len(set(list_of_channels_to_winnow)) == module_op.num_in_channels:
            raise ValueError("Winnowing all the input channels is not allowed, module: %s" % tf_op.name)

        input_index = 0     # Using op index 0 to examine input to op
        if module_op.inputs[input_index].is_model_input:
            logger.critical("Winnowing the first module of a model is NOT supported. Please ignore the first "
                            "module and try again. First module: %s, shape %s, channels to winnow: %s",
                            module_op.dotted_name, module_op.inputs[input_index].shape,
                            list_of_channels_to_winnow)
            raise NotImplementedError(module_op.dotted_name)