def _delete_channels(self, node, inputs, input_masks, channels=None, layer_name=None):
    """Delete selected channels of node.outbound_layer. Add it to the graph.

    Arguments:
        node: The node whose outbound layer has channels to delete.
        inputs: Input tensor(s) used to call the rebuilt layer.
        input_masks: Delete mask(s) propagated from upstream layers
            (presumably boolean masks over upstream channels — confirm
            against `_apply_delete_mask`).
        channels: Indices of the channels to delete from the layer.
        layer_name: Optional name to give the rebuilt layer.
    """
    old_layer = node.outbound_layer
    old_layer_output = utils.single_element(node.output_tensors)
    # Create a mask to propagate the deleted channels to downstream layers
    new_delete_mask = self._make_delete_mask(old_layer, channels)

    # If every channel is being deleted, the layer disappears entirely:
    # register no replacement output, only the delete mask.
    if len(set(channels)) == getattr(old_layer, utils.get_channels_attr(old_layer)):
        self._replace_tensors[old_layer_output] = (None, new_delete_mask)
        return None

    # If this layer has already been operated on, use the cached copy of
    # the new layer. Otherwise, apply the inbound delete mask and
    # delete channels to obtain the new layer
    if old_layer in self._new_layers_map:  # membership test; `.keys()` was redundant
        new_layer = self._new_layers_map[old_layer]
    else:
        temp_layer, new_mask = self._apply_delete_mask(node, input_masks)
        # This call is needed to initialise input_shape and output_shape
        temp_layer(utils.single_element(inputs))
        new_layer = self._delete_channel_weights(temp_layer, channels)
        if layer_name:
            new_layer.name = layer_name
        # Cache so shared layers are only rebuilt once.
        self._new_layers_map[old_layer] = new_layer
    new_output = new_layer(utils.single_element(inputs))
    # Replace the original layer's output with the modified layer's output
    self._replace_tensors[old_layer_output] = (new_output, new_delete_mask)
def layer_test_helper_flatten_1d(layer, channel_index):
    """Check pruning propagates through `layer` and Flatten into a Dense layer.

    The layer under test sits between a Conv1D (whose channels are deleted)
    and a Flatten/Dense pair; correctness is judged on the Dense weights, so
    the output shape is exercised as well.
    """
    # Build a small model around the layer under test.
    inp = Input(shape=list(random.randint(10, 20, size=2)))
    net = Conv1D(3, 3)(inp)
    net = layer(net)
    net = Flatten()(net)
    out = Dense(5)(net)
    model = Model(inputs=inp, outputs=out)

    # Remove the requested channels from the Conv1D layer.
    conv = model.layers[1]
    dense_pos = 4
    surgeon = Surgeon(model)
    surgeon.add_job("delete_channels", conv, channels=channel_index)
    pruned = surgeon.operate()
    observed = pruned.layers[dense_pos].get_weights()

    # Work out which rows of the original Dense kernel should be dropped:
    # each deleted channel repeats once per flattened spatial position.
    flat_sz = np.prod(layer.get_output_shape_at(0)[1:])
    n_channels = getattr(conv, utils.get_channels_attr(conv))
    wrapped = [i % n_channels for i in channel_index]
    rows_to_drop = [
        c + offset for offset in range(0, flat_sz, n_channels) for c in wrapped
    ]
    expected = model.layers[dense_pos].get_weights()
    expected[0] = np.delete(expected[0], rows_to_drop, axis=0)
    assert weights_equal(expected, observed)
def layer_test_helper_2d_global(layer, channel_index, data_format):
    """Check pruning propagates through a global 2D `layer` into a Dense layer.

    The global layer collapses spatial dimensions, so the Dense kernel loses
    exactly one input row per deleted Conv2D channel.
    """
    # Build a small model around the layer under test.
    inp = Input(shape=list(random.randint(10, 20, size=3)))
    net = Conv2D(3, [3, 3], data_format=data_format)(inp)
    net = layer(net)
    out = Dense(5)(net)
    model = Model(inputs=inp, outputs=out)

    # Prune the Conv2D layer's channels.
    conv = model.layers[1]
    dense_pos = 3
    pruned = operations.delete_channels(model, conv, channel_index)
    observed = pruned.layers[dense_pos].get_weights()

    # The Dense kernel should simply lose the corresponding input rows.
    n_channels = getattr(conv, utils.get_channels_attr(conv))
    wrapped = [i % n_channels for i in channel_index]
    expected = model.layers[dense_pos].get_weights()
    expected[0] = np.delete(expected[0], wrapped, axis=0)
    assert weights_equal(expected, observed)
def test_delete_channels_flatten(channel_index, data_format):
    """Deleting Conv2D channels before Flatten must prune the right Dense rows."""
    # Build a small Conv2D -> Flatten -> Dense model.
    inp = Input(shape=list(random.randint(4, 10, size=3)))
    net = Conv2D(3, [3, 3], data_format=data_format)(inp)
    net = Flatten()(net)
    out = Dense(5)(net)
    model = Model(inputs=inp, outputs=out)

    # Prune the Conv2D layer's channels.
    conv = model.layers[1]
    dense_pos = 3
    pruned = operations.delete_channels(model, conv, channel_index)
    observed = pruned.layers[dense_pos].get_weights()

    # Map deleted channels to flattened positions; the layout depends on
    # whether channels vary slowest ('channels_first') or fastest
    # ('channels_last') in the flattened vector.
    flat_sz = np.prod(conv.output_shape[1:])
    n_channels = getattr(conv, utils.get_channels_attr(conv))
    wrapped = [i % n_channels for i in channel_index]
    if data_format == 'channels_first':
        rows_to_drop = [
            c * flat_sz // n_channels + offset
            for c in wrapped
            for offset in range(0, flat_sz // n_channels)
        ]
    elif data_format == 'channels_last':
        rows_to_drop = [
            c + offset
            for offset in range(0, flat_sz, n_channels)
            for c in wrapped
        ]
    else:
        raise ValueError
    expected = model.layers[dense_pos].get_weights()
    expected[0] = np.delete(expected[0], rows_to_drop, axis=0)
    assert weights_equal(expected, observed)
def test_delete_channels_merge_concatenate(channel_index, data_format):
    """Deleting channels from one branch of a Concatenate prunes only that
    branch's flattened positions in the downstream Dense kernel.

    The output shape is exercised by feeding the concatenated features into a
    Dense layer rather than a Conv layer.
    """
    if data_format == "channels_first":
        axis = 1
    elif data_format == "channels_last":
        axis = -1
    else:
        raise ValueError
    # Two parallel conv branches merged by Concatenate.
    input_shape = list(random.randint(10, 20, size=3))
    input_1 = Input(shape=input_shape)
    input_2 = Input(shape=input_shape)
    x = Conv2D(3, [3, 3], data_format=data_format, name="conv_1")(input_1)
    y = Conv2D(3, [3, 3], data_format=data_format, name="conv_2")(input_2)
    x = Concatenate(axis=axis, name="cat_1")([x, y])
    x = Flatten()(x)
    main_output = Dense(5, name="dense_1")(x)
    model = Model(inputs=[input_1, input_2], outputs=main_output)
    # Delete channels from conv_1 only; conv_2's channels must survive.
    # (A dead `old_w = model.get_layer("dense_1").get_weights()` line was
    # removed here; its value was never used.)
    layer = model.get_layer("cat_1")
    del_layer = model.get_layer("conv_1")
    surgeon = Surgeon(model, copy=True)
    surgeon.add_job("delete_channels", del_layer, channels=channel_index)
    new_model = surgeon.operate()
    new_w = new_model.get_layer("dense_1").get_weights()
    # Calculate next layer's correct weights. conv_1 contributes half of the
    # concatenated tensor, hence the `// 2` / `* 2` factors below.
    flat_sz = np.prod(layer.get_output_shape_at(0)[1:])
    channel_count = getattr(del_layer, utils.get_channels_attr(del_layer))
    channel_index = [i % channel_count for i in channel_index]
    if data_format == "channels_first":
        delete_indices = [
            x * flat_sz // 2 // channel_count + i
            for x in channel_index
            for i in range(0, flat_sz // 2 // channel_count)
        ]
    elif data_format == "channels_last":
        delete_indices = [
            x + i
            for i in range(0, flat_sz, channel_count * 2)
            for x in channel_index
        ]
    else:
        raise ValueError
    correct_w = model.get_layer("dense_1").get_weights()
    correct_w[0] = np.delete(correct_w[0], delete_indices, axis=0)
    assert weights_equal(correct_w, new_w)
def recursive_test_helper(layer, channel_index):
    """Verify channel deletion in `layer` propagates into the following GRU."""
    # Build: input -> layer under test -> GRU -> Dense.
    inp = Input(shape=[32, 10])
    net = layer(inp)
    net = GRU(4, return_sequences=False)(net)
    out = Dense(5)(net)
    model = Model(inputs=inp, outputs=out)

    # Delete channels from the layer at index 1; the GRU at index 2 is the
    # downstream layer whose kernel must shrink accordingly.
    target = model.layers[1]
    gru_pos = 2
    pruned = operations.delete_channels(model, target, channel_index)
    observed = pruned.layers[gru_pos].get_weights()

    # The GRU's input kernel should lose one row per deleted channel.
    n_channels = getattr(target, utils.get_channels_attr(target))
    wrapped = [i % n_channels for i in channel_index]
    expected = model.layers[gru_pos].get_weights()
    expected[0] = np.delete(expected[0], wrapped, axis=0)
    assert weights_equal(expected, observed)
def _delete_channel_weights(self, layer, channel_indices):
    """Delete channels from layer and remove the corresponding weights.

    Arguments:
        layer: A layer whose channels are to be deleted
        channel_indices: The indices of the channels to be deleted.

    Returns:
        A new layer with the channels and corresponding weights deleted.

    Raises:
        ValueError: If any index is outside [-channel_count, channel_count).
    """
    layer_config = layer.get_config()
    channels_attr = utils.get_channels_attr(layer)
    channel_count = layer_config[channels_attr]
    # Check inputs. Indices below -channel_count previously slipped past the
    # check and were silently re-targeted at a different channel by the
    # modulo wrap below; reject them explicitly.
    if any(i >= channel_count or i < -channel_count for i in channel_indices):
        raise ValueError(
            'Channels_index value(s) out of range. '
            'This layer only has {0} channels.'.format(channel_count))
    print('Deleting {0}/{1} channels from layer: {2}'.format(
        len(channel_indices), channel_count, layer.name))
    # numpy.delete ignores negative indices in lists: wrap indices
    channel_indices = [i % channel_count for i in channel_indices]

    # Reduce layer channel count in config.
    layer_config[channels_attr] -= len(channel_indices)

    # Delete weights corresponding to deleted channels from config.
    # Except for recurrent layers, the weights' channels dimension is last.
    # Recurrent kernels concatenate one weight block per internal gate along
    # the channels axis, so the channel indices are repeated once per gate.
    gates_per_type = {'SimpleRNN': 1, 'GRU': 3, 'LSTM': 4}
    gate_count = gates_per_type.get(layer.__class__.__name__)
    if gate_count is not None:
        gated_indices = [
            layer.units * m + i
            for m in range(gate_count)
            for i in channel_indices
        ]
        weights = [
            np.delete(w, gated_indices, axis=-1)
            for w in layer.get_weights()
        ]
        # The recurrent kernel's input dimension (axis 0) also shrinks.
        weights[1] = np.delete(weights[1], channel_indices, axis=0)
    else:
        weights = [
            np.delete(w, channel_indices, axis=-1)
            for w in layer.get_weights()
        ]
    layer_config['weights'] = weights

    # Create new layer from the modified configuration and return it.
    return type(layer).from_config(layer_config)