def build(self):
    fd_method = self._get_method(cfg_to_method[self.fd_name])
    new_model = copy.deepcopy(self.origin_model)
    current_idx = 1
    for name, layer in new_model.named_modules():
        layer_type = get_layer_type(layer)
        if _exclude_layer(layer):
            continue
        if layer_type == 'Conv':
            # Decompose only convs past the start index whose output width
            # exceeds the input width.
            if self.start_idx <= current_idx and layer.out_channels > layer.in_channels:
                module, last_name = get_module_of_layer(new_model, name)
                module._modules[str(last_name)], ranks = fd_method(layer, self.rank, self.device)
                in_channels, out_channels = layer.in_channels, layer.out_channels
                logger.info(
                    f'In, Out channels are decomposed: [{in_channels}, {out_channels}]'
                    f' -> [{ranks[1]}, {ranks[0]}] at "{name}" layer')
            current_idx += 1
    return new_model
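# A minimal sketch (not from the repo) of the parent-lookup-and-replace pattern
# that build() relies on; the real get_module_of_layer helper is assumed to
# behave roughly like this hypothetical version.
import torch.nn as nn

def _get_module_of_layer_sketch(model: nn.Module, full_name: str):
    """Return (parent_module, last_attribute_name) for a dotted module path,
    so that parent._modules[last_attribute_name] = new_layer swaps the layer."""
    *parent_names, last_name = full_name.split('.')
    module = model
    for p in parent_names:
        module = module._modules[p]
    return module, last_name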
def get_prune_idx(self, i_node, pruning_ratio=0.0):
    is_check = False
    for idx, (name, layer) in enumerate(self.model.named_modules()):
        if idx <= self.check_point or _exclude_layer(layer):
            continue
        if i_node['id'] == hash(name):
            # Found the conv node; the BN layer right after it carries the
            # scale factors used as the pruning criterion.
            self.check_point = idx
            is_check = True
        if self.check_point + 1 == idx and get_layer_type(layer) == 'BN' and is_check:
            assert layer.num_features == i_node['layer'].out_channels
            gamma = layer.weight.clone()
            gamma_norm = torch.abs(gamma)
            n_to_prune = int(pruning_ratio * len(gamma))
            if n_to_prune == 0:
                return []
            # Prune the n_to_prune channels with the smallest |gamma|.
            threshold = torch.kthvalue(gamma_norm, k=n_to_prune).values
            indices = torch.nonzero(gamma_norm <= threshold).view(-1).tolist()
            return indices
    return []  # no trailing BN found for this node
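# Standalone sketch of the BN-gamma magnitude criterion used by get_prune_idx
# above, on a toy BatchNorm2d with hypothetical values: the pruning_ratio
# smallest-|gamma| channels are selected for removal.
import torch
import torch.nn as nn

def smallest_gamma_indices(bn: nn.BatchNorm2d, pruning_ratio: float) -> list:
    gamma_norm = torch.abs(bn.weight.detach())
    n_to_prune = int(pruning_ratio * len(gamma_norm))
    if n_to_prune == 0:
        return []
    threshold = torch.kthvalue(gamma_norm, k=n_to_prune).values
    return torch.nonzero(gamma_norm <= threshold).view(-1).tolist()

bn = nn.BatchNorm2d(8)
bn.weight.data = torch.tensor([0.9, 0.1, 0.8, 0.05, 0.7, 0.2, 0.6, 0.3])
assert smallest_gamma_indices(bn, pruning_ratio=0.25) == [1, 3]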
def _get_layer_info(self, torch_model):
    layer_info = OrderedDict()
    i = 0
    for idx, (name, layer) in enumerate(torch_model.named_modules()):
        if idx == 0 or _exclude_layer(layer):
            continue
        layer_info[i] = {'layer': layer, 'torch_name': name}
        i += 1
    return layer_info
def _remove_bn_from_model(self, model):
    """Remove BN layers from the base model, because the weights and names of
    conv-bn pairs are combined when converting to an ONNX model."""
    for name, layer in model.named_modules():
        layer_type = get_layer_type(layer)
        if _exclude_layer(layer):
            continue
        if layer_type == 'BN':
            # An empty nn.Sequential acts as an identity replacement.
            module, last_name = get_module_of_layer(model, name)
            module._modules[str(last_name)] = nn.Sequential()
    return model
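# Quick sanity sketch: an empty nn.Sequential is an identity mapping, which is
# why _remove_bn_from_model can substitute it for each BN layer before export.
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)
assert torch.equal(nn.Sequential()(x), x)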
def get_pruned_layers(self):
    origin_t_layers = OrderedDict()
    new_t_layers = OrderedDict()
    check_idx = -100
    hook_layers = []
    # Assumes the origin and new models have the same number of layers.
    for idx, (origin_data, new_data) in enumerate(
            zip(self.origin_model.named_modules(),
                self.new_model.named_modules())):
        origin_name, origin_layer = origin_data
        new_name, new_layer = new_data
        if _exclude_layer(origin_layer):
            continue
        layer_type = get_layer_type(origin_layer)
        if layer_type == 'Conv' and origin_layer.out_channels != new_layer.out_channels:
            # Remember the pruned conv; the hook goes on the BN right after
            # it, or on the conv itself if no BN follows.
            prev_origin_layer = origin_layer
            prev_origin_name = origin_name
            prev_new_layer = new_layer
            prev_new_name = new_name
            check_idx = idx
        if layer_type == 'BN' and check_idx + 1 == idx:
            origin_t_layers[origin_name] = origin_layer
            hook_1 = origin_layer.register_forward_hook(self.origin_forward_hook)
            new_t_layers[new_name] = new_layer
            hook_2 = new_layer.register_forward_hook(self.new_forward_hook)
            hook_layers.extend([hook_1, hook_2])
        elif check_idx + 1 == idx:
            origin_t_layers[prev_origin_name] = prev_origin_layer
            hook_1 = prev_origin_layer.register_forward_hook(self.origin_forward_hook)
            new_t_layers[prev_new_name] = prev_new_layer
            hook_2 = prev_new_layer.register_forward_hook(self.new_forward_hook)
            hook_layers.extend([hook_1, hook_2])
    return origin_t_layers, new_t_layers, hook_layers
def get_fd_layers(self):
    origin_t_layers = OrderedDict()
    new_t_layers = OrderedDict()
    new_tmp_layers = []
    # Collect the conv/BN layers of the decomposed model in order.
    for name, layer in self.new_model.named_modules():
        if _exclude_layer(layer):
            continue
        layer_type = get_layer_type(layer)
        if layer_type == 'Conv' or layer_type == 'BN':
            new_tmp_layers.append([name, layer])
    new_tmp_len = len(new_tmp_layers)

    new_layers_idx = 0
    check_idx = -100
    hook_layers = []
    for idx, (name, layer) in enumerate(self.origin_model.named_modules()):
        if _exclude_layer(layer):
            continue
        if new_layers_idx >= new_tmp_len:
            break
        layer_type = get_layer_type(layer)
        if layer_type == 'BN' and check_idx + 1 == idx:
            # Hook the BN following a decomposed conv, paired with the
            # corresponding layer of the new model.
            origin_t_layers[name] = layer
            hook_1 = layer.register_forward_hook(self.origin_forward_hook)
            new_t_layers[new_tmp_layers[new_layers_idx][0]] = \
                new_tmp_layers[new_layers_idx][1]
            hook_2 = new_tmp_layers[new_layers_idx][1].register_forward_hook(
                self.new_forward_hook)
            hook_layers.extend([hook_1, hook_2])
        elif check_idx + 1 == idx:
            # No BN after the decomposed conv; hook the conv pair instead.
            origin_t_layers[prev_origin_name] = prev_origin_layer
            hook_1 = prev_origin_layer.register_forward_hook(
                self.origin_forward_hook)
            new_t_layers[prev_new_name] = prev_new_layer
            hook_2 = prev_new_layer.register_forward_hook(self.new_forward_hook)
            hook_layers.extend([hook_1, hook_2])
        if layer_type == 'Conv' and name != new_tmp_layers[new_layers_idx][0]:
            # A decomposed conv: skip past the factorized sub-layers whose
            # names contain the original conv's name.
            new_layers_idx += 1
            while (new_layers_idx < new_tmp_len
                   and name in new_tmp_layers[new_layers_idx][0]):
                new_layers_idx += 1
            prev_origin_layer = layer
            prev_origin_name = name
            prev_new_layer = new_tmp_layers[new_layers_idx - 1][1]
            prev_new_name = new_tmp_layers[new_layers_idx - 1][0]
            check_idx = idx
        elif (layer_type == 'Conv' and name == new_tmp_layers[new_layers_idx][0]) or \
             (layer_type == 'BN' and name == new_tmp_layers[new_layers_idx][0]):
            new_layers_idx += 1
    return origin_t_layers, new_t_layers, hook_layers
def set_hooking(self):
    def save_fmaps(key):
        def forward_hook(module, inputs, outputs):
            if key not in self.activations:
                self.activations[key] = inputs[0]
            else:
                self.activations[key] = torch.cat(
                    (self.activations[key], inputs[0]), dim=0)
        return forward_hook

    prev_names = []
    prev_layers = []
    last_conv = 0
    # Walk the model back to front: the first conv encountered is the last
    # conv of the network, and the layer that follows it (skipping BN) is
    # hooked as the target layer feeding the classifier.
    for name, layer in reversed(list(self.model.named_modules())):
        if _exclude_layer(layer):
            continue
        layer_type = get_layer_type(layer)
        if layer_type == 'Conv':
            if last_conv == 0:
                prev_i_layer = prev_layers[last_conv]
                prev_i_name = prev_names[last_conv]
                while get_layer_type(prev_i_layer) == 'BN':
                    last_conv += 1
                    prev_i_layer = prev_layers[last_conv]
                    prev_i_name = prev_names[last_conv]
                target_layer = prev_layers[last_conv - 1]
                target_name = prev_names[last_conv - 1]
                self.hook_layers.append(
                    target_layer.register_forward_hook(save_fmaps(target_name)))
                last_conv = True
            self.hook_layers.append(
                layer.register_forward_hook(save_fmaps(name)))
        if last_conv == 0:
            prev_names.append(name)
            prev_layers.append(layer)

    # Run a few batches to populate the hooked activations.
    batches = 0
    batch_size = self.train_loader.batch_size
    with torch.no_grad():
        for images, labels in self.train_loader:
            batches += batch_size
            images = images.to(self.device)
            out = self.model(images)
            if batches >= self.cluster_args.NUM_SAMPLES:
                break

    for key, val in self.node_graph.items():
        if 'group' in val:
            if val['input_convs'] is not None:
                for i_input in val['input_convs']:
                    self.conv2target_conv[i_input] = val['torch_name']
        if get_layer_type(val['layer']) == 'Linear' and 'input_convs' in val:
            if 'Conv' in val['input_convs'][0]:
                # Map the convs feeding the classifier to the target layer
                # hooked right after the last conv.
                for i_input in val['input_convs']:
                    self.conv2target_conv[i_input] = target_name
            else:
                for i_input in val['input_convs']:
                    self.conv2target_conv[i_input] = val['torch_name']

    for i_hook in self.hook_layers:
        i_hook.remove()
def cluster(self, node_graph):
    node_graph = self.set_cluster_idx(node_graph)
    new_model = copy.deepcopy(self.model)
    i = 0
    clustering_info = []
    for idx, (name, layer) in enumerate(new_model.named_modules()):
        if idx == 0 or _exclude_layer(layer):
            continue
        layer_type = get_layer_type(layer)
        if layer_type == 'Conv':
            prev_cluster_idx = []
            if 'input_convs' in node_graph[i]:
                prev_cluster_idx = self.get_prev_cluster_idx(
                    node_graph=node_graph, index=i)
            cluster_idx = node_graph[i]['cluster_idx']
            keep_prev_idx = list(
                set(range(layer.in_channels)) - set(prev_cluster_idx))
            keep_idx = list(set(range(layer.out_channels)) - set(cluster_idx))
            # Slice input channels first, then output channels.
            w = layer.weight.data[:, keep_prev_idx, :, :].clone()
            layer.weight.data = w[keep_idx, :, :, :].clone()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data[keep_idx].clone()
            clustering_info.append(
                f'Out channels are clustered: [{layer.out_channels:4d}] -> '
                f'[{len(keep_idx):4d}] at "{name}" layer')
            layer.out_channels = len(keep_idx)
            layer.in_channels = len(keep_prev_idx)
        elif layer_type == 'GroupConv':
            cluster_idx = node_graph[i]['cluster_idx']
            keep_idx = list(set(range(layer.out_channels)) - set(cluster_idx))
            layer.weight.data = layer.weight.data[keep_idx, :, :, :].clone()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data[keep_idx].clone()
            clustering_info.append(
                f'Out channels are clustered: [{layer.out_channels:4d}] -> '
                f'[{len(keep_idx):4d}] at "{name}" layer')
            # Depthwise/group conv: in, out, and groups shrink together.
            layer.out_channels = len(keep_idx)
            layer.in_channels = len(keep_idx)
            layer.groups = len(keep_idx)
        elif layer_type == 'BN':
            prev_cluster_idx = self.get_prev_cluster_idx(
                node_graph=node_graph, index=i)
            keep_idx = list(
                set(range(layer.num_features)) - set(prev_cluster_idx))
            layer.running_mean.data = layer.running_mean.data[keep_idx].clone()
            layer.running_var.data = layer.running_var.data[keep_idx].clone()
            if layer.affine:
                layer.weight.data = layer.weight.data[keep_idx].clone()
                layer.bias.data = layer.bias.data[keep_idx].clone()
            clustering_info.append(
                f'Out channels are clustered: [{layer.num_features:4d}] -> '
                f'[{len(keep_idx):4d}] at "{name}" layer')
            layer.num_features = len(keep_idx)
        elif layer_type == 'Linear':
            if 'input_convs' in node_graph[i]:
                prev_cluster_idx = self.get_prev_cluster_idx(
                    node_graph=node_graph, index=i)
                keep_idx = list(
                    set(range(layer.in_features)) - set(prev_cluster_idx))
                layer.weight.data = layer.weight.data[:, keep_idx].clone()
                layer.in_features = len(keep_idx)
        i += 1
    return new_model, clustering_info, node_graph
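# Minimal sketch (toy shapes, not repo code) of the channel-slicing pattern
# shared by cluster() and prune(): drop input channels first, then output
# channels, then update the layer's metadata to match.
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3)
keep_prev_idx = [0, 2, 3]    # surviving input channels
keep_idx = [1, 2, 4, 5]      # surviving output channels
w = conv.weight.data[:, keep_prev_idx, :, :].clone()
conv.weight.data = w[keep_idx, :, :, :].clone()
if conv.bias is not None:
    conv.bias.data = conv.bias.data[keep_idx].clone()
conv.in_channels, conv.out_channels = len(keep_prev_idx), len(keep_idx)
assert conv(torch.randn(1, 3, 8, 8)).shape == (1, 4, 6, 6)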
def set_hooking(self):
    def save_fmaps(key):
        def forward_hook(module, inputs, outputs):
            if key not in self.activations:
                self.activations[key] = outputs
            else:
                self.activations[key] = torch.cat(
                    (self.activations[key], outputs), dim=0)
        return forward_hook

    group_to_name = defaultdict(list)
    bn_to_conv = {}
    for key, val in reversed(list(self.node_graph.items())):
        layer_type = get_layer_type(val['layer'])
        if layer_type == 'BN':
            bn_to_conv[val['torch_name']] = val['input_convs']
        elif layer_type == 'Conv':
            group_to_name[val['group']].append(val['name'])

    for name, layer in self.model.named_modules():
        if _exclude_layer(layer):
            continue
        layer_type = get_layer_type(layer)
        if layer_type == 'BN':
            # Register one hook per conv feeding this BN, keyed by conv name.
            link_convs = bn_to_conv[name]
            for i_link_conv in link_convs:
                self.hook_layers.append(
                    layer.register_forward_hook(save_fmaps(i_link_conv)))

    # Run a few batches to populate the hooked activations.
    batches = 0
    batch_size = self.train_loader.batch_size
    with torch.no_grad():
        for images, labels in self.train_loader:
            batches += batch_size
            images = images.to(self.device)
            out = self.model(images)
            if batches >= self.cluster_args.NUM_SAMPLES:
                break

    # Layers in the same group share the sum of their activations.
    for group, names in group_to_name.items():
        for idx, i_name in enumerate(names):
            if idx == 0:
                group_output = self.activations[i_name]
            else:
                group_output = group_output + self.activations[i_name]
        for i_name in names:
            self.activations[i_name] = group_output

    for i_hook in self.hook_layers:
        i_hook.remove()
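# Standalone sketch of the forward-hook accumulation both set_hooking variants
# rely on: batch outputs are concatenated under a key (toy layer and key name).
import torch
import torch.nn as nn

activations = {}

def save_fmaps(key):
    def forward_hook(module, inputs, outputs):
        if key not in activations:
            activations[key] = outputs
        else:
            activations[key] = torch.cat((activations[key], outputs), dim=0)
    return forward_hook

bn = nn.BatchNorm2d(8)
handle = bn.register_forward_hook(save_fmaps('demo_key'))
with torch.no_grad():
    for _ in range(2):
        bn(torch.randn(4, 8, 16, 16))
handle.remove()
assert activations['demo_key'].shape == (8, 8, 16, 16)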
def prune(self, node_graph):
    group_to_ratio = load_strategy(
        strategy_name=self.pruning_cfg.STRATEGY,
        group_set=self.group_set,
        pruning_ratio=self.pruning_cfg.STRATEGY_ARGS.PRUNING_RATIO).build()
    node_graph = self.set_prune_idx(group_to_ratio, node_graph)
    new_model = copy.deepcopy(self.model)
    i = 0
    pruning_info = []
    for idx, (name, layer) in enumerate(new_model.named_modules()):
        if idx == 0 or _exclude_layer(layer):
            continue
        layer_type = get_layer_type(layer)
        if layer_type == 'Conv':
            prev_prune_idx = []
            if 'input_convs' in node_graph[i]:
                prev_prune_idx = self.get_prev_prune_idx(
                    node_graph=node_graph, index=i)
            prune_idx = node_graph[i]['prune_idx']
            keep_prev_idx = list(
                set(range(layer.in_channels)) - set(prev_prune_idx))
            keep_idx = list(set(range(layer.out_channels)) - set(prune_idx))
            # Slice input channels first, then output channels.
            w = layer.weight.data[:, keep_prev_idx, :, :].clone()
            layer.weight.data = w[keep_idx, :, :, :].clone()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data[keep_idx].clone()
            pruning_info.append(
                f'Out channels are pruned: [{layer.out_channels:4d}] -> '
                f'[{len(keep_idx):4d}] at "{name}" layer')
            layer.out_channels = len(keep_idx)
            layer.in_channels = len(keep_prev_idx)
        elif layer_type == 'GroupConv':
            prune_idx = node_graph[i]['prune_idx']
            keep_idx = list(set(range(layer.out_channels)) - set(prune_idx))
            layer.weight.data = layer.weight.data[keep_idx, :, :, :].clone()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data[keep_idx].clone()
            pruning_info.append(
                f'Out channels are pruned: [{layer.out_channels:4d}] -> '
                f'[{len(keep_idx):4d}] at "{name}" layer')
            # Depthwise/group conv: in, out, and groups shrink together.
            layer.out_channels = len(keep_idx)
            layer.in_channels = len(keep_idx)
            layer.groups = len(keep_idx)
        elif layer_type == 'BN':
            prev_prune_idx = self.get_prev_prune_idx(
                node_graph=node_graph, index=i)
            keep_idx = list(
                set(range(layer.num_features)) - set(prev_prune_idx))
            layer.running_mean.data = layer.running_mean.data[keep_idx].clone()
            layer.running_var.data = layer.running_var.data[keep_idx].clone()
            if layer.affine:
                layer.weight.data = layer.weight.data[keep_idx].clone()
                layer.bias.data = layer.bias.data[keep_idx].clone()
            pruning_info.append(
                f'Out channels are pruned: [{layer.num_features:4d}] -> '
                f'[{len(keep_idx):4d}] at "{name}" layer')
            layer.num_features = len(keep_idx)
        elif layer_type == 'Linear':
            if 'input_convs' in node_graph[i]:
                prev_prune_idx = self.get_prev_prune_idx(
                    node_graph=node_graph, index=i)
                keep_idx = list(
                    set(range(layer.in_features)) - set(prev_prune_idx))
                layer.weight.data = layer.weight.data[:, keep_idx].clone()
                layer.in_features = len(keep_idx)
        i += 1
    return new_model, pruning_info, node_graph