Example #1
    def build(self):
        fd_method = self._get_method(cfg_to_method[self.fd_name])
        
        new_model = copy.deepcopy(self.origin_model)
        current_idx = 1

        for name, layer in new_model.named_modules():
            layer_type = get_layer_type(layer)

            if _exclude_layer(layer):
                continue

            if layer_type == 'Conv':
                if self.start_idx <= current_idx and layer.out_channels > layer.in_channels:

                    module, last_name = get_module_of_layer(new_model, name)
                    module._modules[str(last_name)], ranks = fd_method(layer, self.rank, self.device)
                    
                    in_channels, out_channels = layer.in_channels, layer.out_channels
                    
                    logger.info(f'In, Out channels are decomposed: [{in_channels}, {out_channels}] -> [{ranks[1]}, {ranks[0]}] at "{name}" layer')
                    
                current_idx += 1

        return new_model
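
The decomposition callable resolved by self._get_method(cfg_to_method[self.fd_name]) is not shown here. Below is a minimal, hypothetical sketch of what such an fd_method(layer, rank, device) might return, assuming the (module, ranks) contract used above, where ranks[1] is logged as the input rank and ranks[0] as the output rank; the real method presumably performs an actual low-rank factorization of the weights rather than re-initializing them.

import torch.nn as nn

def naive_fd_method(layer: nn.Conv2d, rank: int, device):
    # Hypothetical stand-in: replace one Conv2d with a 1x1 "squeeze" conv
    # followed by the original spatial conv, with `rank` channels in between.
    in_rank = min(rank, layer.in_channels)
    out_rank = layer.out_channels
    decomposed = nn.Sequential(
        nn.Conv2d(layer.in_channels, in_rank, kernel_size=1, bias=False),
        nn.Conv2d(in_rank, out_rank,
                  kernel_size=layer.kernel_size,
                  stride=layer.stride,
                  padding=layer.padding,
                  bias=layer.bias is not None),
    ).to(device)
    # build() logs ranks as "[ranks[1], ranks[0]]" = [input rank, output rank]
    return decomposed, (out_rank, in_rank)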
Example #2

    def get_prune_idx(self, i_node, pruning_ratio=0.0):
        is_check = False
        for idx, (name, layer) in enumerate(self.model.named_modules()):

            if idx <= self.check_point or _exclude_layer(layer):
                continue

            if i_node['id'] == hash(name):
                self.check_point = idx
                is_check = True
            
            if self.check_point + 1 == idx and get_layer_type(layer) == 'BN' and is_check:

                assert layer.num_features == i_node['layer'].out_channels
                gamma = layer.weight.clone()

                gamma_norm = torch.abs(gamma)

                n_to_prune = int(pruning_ratio*len(gamma))
                if n_to_prune == 0:
                    return []
                threshold = torch.kthvalue(gamma_norm, k=n_to_prune).values 

                indices = torch.nonzero(gamma_norm <= threshold).view(-1).tolist()        

                return indices    
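
As a quick standalone check of the magnitude criterion above (not part of the original code): the k-th smallest |gamma| is taken as the threshold, and every channel at or below it is selected for pruning.

import torch

gamma = torch.tensor([0.9, -0.05, 0.4, 0.02, 0.7, -0.3])
gamma_norm = torch.abs(gamma)
n_to_prune = int(0.5 * len(gamma))                    # pruning_ratio = 0.5 -> 3 channels
threshold = torch.kthvalue(gamma_norm, k=n_to_prune).values
indices = torch.nonzero(gamma_norm <= threshold).view(-1).tolist()
print(indices)                                        # [1, 3, 5]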
Example #3
 def _get_layer_info(self, torch_model):
     layer_info = OrderedDict()
     i = 0
     for idx, (name, layer) in enumerate(torch_model.named_modules()):
         if idx == 0 or _exclude_layer(layer):
             continue
         layer_info[i] = {'layer': layer, 'torch_name': name}
         i += 1
     return layer_info
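
The helpers get_layer_type and _exclude_layer are used by every example here but are never shown. A plausible minimal sketch, assuming only the type strings that appear in these snippets ('Conv', 'GroupConv', 'BN', 'Linear'); the actual helpers may cover more layer types.

import torch.nn as nn

def get_layer_type(layer):
    # Map a module to the type strings used throughout these examples.
    if isinstance(layer, nn.Conv2d):
        return 'GroupConv' if layer.groups > 1 else 'Conv'
    if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d)):
        return 'BN'
    if isinstance(layer, nn.Linear):
        return 'Linear'
    return type(layer).__name__

def _exclude_layer(layer):
    # Skip containers and any module that is not a prunable leaf layer.
    return not isinstance(layer, (nn.Conv2d, nn.BatchNorm1d, nn.BatchNorm2d, nn.Linear))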
Example #4
    def _remove_bn_from_model(self, model):
        """
        It emoves bn layer from the base model because the weights and 
        its name of conv-bn layers are conbimed when converting to onnx model 
        """
        for name, layer in model.named_modules():
            layer_type = get_layer_type(layer)

            if _exclude_layer(layer):
                continue

            if layer_type == 'BN':
                new_layer = []
                module, last_name = get_module_of_layer(model, name)
                module._modules[str(last_name)] = nn.Sequential(*new_layer)

        return model
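
get_module_of_layer, used here and in Example #1, resolves a dotted named_modules() name to its parent module so the child can be swapped in place. A hedged sketch of what it likely does:

def get_module_of_layer(model, name):
    # Walk "block.0.conv" down to model.block._modules['0'] and return the
    # parent module together with the final attribute name ('conv').
    parts = name.split('.')
    module = model
    for part in parts[:-1]:
        module = module._modules[part]
    return module, parts[-1]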
Example #5
	def get_pruned_layers(self):

		origin_t_layers = OrderedDict()
		new_t_layers = OrderedDict()

		check_idx = -100
		hook_layers = []
		# Assume the origin and new models have the same number of layers.
		for idx, (origin_data, new_data) in enumerate(zip(self.origin_model.named_modules(),\
														self.new_model.named_modules())):
			origin_name, origin_layer = origin_data
			new_name, new_layer = new_data

			if _exclude_layer(origin_layer):
				continue
			layer_type = get_layer_type(origin_layer)

			if layer_type == 'Conv' and origin_layer.out_channels != new_layer.out_channels:
				prev_origin_layer = origin_layer
				prev_origin_name = origin_name
				prev_new_layer = new_layer
				prev_new_name = new_name
				check_idx = idx

			if layer_type == 'BN' and check_idx + 1 == idx:
				origin_t_layers[origin_name] = origin_layer
				hook_1 = origin_layer.register_forward_hook(self.origin_forward_hook)

				new_t_layers[new_name] = new_layer
				hook_2 = new_layer.register_forward_hook(self.new_forward_hook)

				hook_layers.extend([hook_1, hook_2])
			elif check_idx + 1 == idx:
				origin_t_layers[prev_origin_name] = prev_origin_layer
				hook_1 = prev_origin_layer.register_forward_hook(self.origin_forward_hook)

				new_t_layers[prev_new_name] = prev_new_layer
				hook_2 = prev_new_layer.register_forward_hook(self.new_forward_hook)
				hook_layers.extend([hook_1, hook_2])
			
		return origin_t_layers, new_t_layers, hook_layers				
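
One possible calibration pass using the returned hooks; this is only a sketch, and matcher, loader, and device are placeholders for whatever object exposes get_pruned_layers(), the two models, and the data.

import torch

def calibrate(matcher, loader, device):
    origin_t_layers, new_t_layers, hooks = matcher.get_pruned_layers()
    with torch.no_grad():
        images, _ = next(iter(loader))
        matcher.origin_model(images.to(device))   # fires origin_forward_hook
        matcher.new_model(images.to(device))      # fires new_forward_hook
    for h in hooks:
        h.remove()                                # always detach hooks afterwards
    return origin_t_layers, new_t_layers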
Example #6
	def get_fd_layers(self):
		
		origin_t_layers = OrderedDict()
		new_t_layers = OrderedDict()
		new_tmp_layers = []

		for name, layer in self.new_model.named_modules():

			if _exclude_layer(layer):
				continue
			layer_type = get_layer_type(layer)

			if layer_type == 'Conv' or layer_type == 'BN':
				new_tmp_layers.append([name, layer])
		
		new_tmp_len = len(new_tmp_layers)
		new_layers_idx = 0
		check_idx = -100
		hook_layers = []
		for idx, (name, layer) in enumerate(self.origin_model.named_modules()):

			if _exclude_layer(layer):
				continue
			if new_tmp_len < new_layers_idx+1:
				break
			layer_type = get_layer_type(layer)


			if layer_type == 'BN' and check_idx + 1 == idx:
				origin_t_layers[name] = layer
				hook_1 = layer.register_forward_hook(self.origin_forward_hook)

				new_t_layers[new_tmp_layers[new_layers_idx][0]] = new_tmp_layers[new_layers_idx][1]
				hook_2 = new_tmp_layers[new_layers_idx][1].register_forward_hook(self.new_forward_hook)

				hook_layers.extend([hook_1, hook_2])
			elif check_idx + 1 == idx:
				origin_t_layers[prev_origin_name] = prev_origin_layer
				hook_1 = prev_origin_layer.register_forward_hook(self.origin_forward_hook)

				new_t_layers[prev_new_name] = prev_new_layer
				hook_2 = prev_new_layer.register_forward_hook(self.new_forward_hook)
				hook_layers.extend([hook_1, hook_2])

			if layer_type == 'Conv' and name != new_tmp_layers[new_layers_idx][0]:

				new_layers_idx += 1
				while name in new_tmp_layers[new_layers_idx][0]:
					new_layers_idx += 1
				
				prev_origin_layer = layer
				prev_origin_name = name
				prev_new_layer = new_tmp_layers[new_layers_idx-1][1]
				prev_new_name = new_tmp_layers[new_layers_idx-1][0]
				check_idx = idx
				
			elif (layer_type == 'Conv' and name == new_tmp_layers[new_layers_idx][0]) or \
				(layer_type == 'BN' and name == new_tmp_layers[new_layers_idx][0]):
				new_layers_idx += 1

			
		return origin_t_layers, new_t_layers, hook_layers
Example #7
    def set_hooking(self):
        def save_fmaps(key):
            def forward_hook(module, inputs, outputs):

                if key not in self.activations:
                    self.activations[key] = inputs[0]
                else:

                    self.activations[key] = torch.cat(
                        (self.activations[key], inputs[0]), dim=0)

            return forward_hook

        prev_names = []
        prev_layers = []
        last_conv = 0
        for name, layer in reversed(list(self.model.named_modules())):

            if _exclude_layer(layer):
                continue
            layer_type = get_layer_type(layer)

            if layer_type == 'Conv':
                if last_conv == 0:

                    prev_i_layer = prev_layers[last_conv]
                    prev_i_name = prev_names[last_conv]

                    while get_layer_type(prev_i_layer) == 'BN':
                        last_conv += 1
                        prev_i_layer = prev_layers[last_conv]
                        prev_i_name = prev_names[last_conv]

                    target_layer = prev_layers[last_conv - 1]
                    target_name = prev_names[last_conv - 1]
                    self.hook_layers.append(
                        target_layer.register_forward_hook(
                            save_fmaps(target_name)))
                    last_conv = True  # non-zero flag: hooks around the last conv are registered
                self.hook_layers.append(
                    layer.register_forward_hook(save_fmaps(name)))

            if last_conv == 0:
                prev_names.append(name)
                prev_layers.append(layer)

        batches = 0
        batch_size = self.train_loader.batch_size
        with torch.no_grad():
            for images, labels in self.train_loader:
                batches += batch_size
                images = images.to(self.device)
                out = self.model(images)

                if batches >= self.cluster_args.NUM_SAMPLES:
                    break

        for key, val in self.node_graph.items():
            if 'group' in val:
                if val['input_convs'] is not None:

                    for i_input in val['input_convs']:
                        self.conv2target_conv[i_input] = val['torch_name']

            if get_layer_type(
                    val['layer']) == "Linear" and 'input_convs' in val:
                if 'Conv' in val['input_convs'][0]:
                    for i_input in val['input_convs']:
                        self.conv2target_conv[i_input] = target_name
                else:
                    for i_input in val['input_convs']:
                        self.conv2target_conv[i_input] = val['torch_name']

        for i_hook in self.hook_layers:
            i_hook.remove()
Example #8
    def cluster(self, node_graph):

        node_graph = self.set_cluster_idx(node_graph)

        new_model = copy.deepcopy(self.model)

        i = 0
        clustering_info = []
        for idx, data in enumerate(new_model.named_modules()):
            name, layer = data
            if idx == 0 or _exclude_layer(layer):
                continue

            layer_type = get_layer_type(layer)
            if layer_type == 'Conv':
                prev_cluster_idx = []
                if 'input_convs' in node_graph[i]:
                    prev_cluster_idx = self.get_prev_cluster_idx(
                        node_graph=node_graph, index=i)
                cluster_idx = node_graph[i]['cluster_idx']

                keep_prev_idx = list(
                    set(range(layer.in_channels)) - set(prev_cluster_idx))
                keep_idx = list(
                    set(range(layer.out_channels)) - set(cluster_idx))

                w = layer.weight.data[:, keep_prev_idx, :, :].clone()
                layer.weight.data = w[keep_idx, :, :, :].clone()
                if layer.bias is not None:
                    layer.bias.data = layer.bias.data[keep_idx].clone()

                clustering_info.append(
                    f'Out channels are clustered: [{layer.out_channels:4d}] -> [{len(keep_idx):4d}] at "{name}" layer'
                )
                layer.out_channels = len(keep_idx)
                layer.in_channels = len(keep_prev_idx)

            elif layer_type == 'GroupConv':
                cluster_idx = node_graph[i]['cluster_idx']

                keep_idx = list(
                    set(range(layer.out_channels)) - set(cluster_idx))

                layer.weight.data = layer.weight.data[
                    keep_idx, :, :, :].clone()
                if layer.bias is not None:
                    layer.bias.data = layer.bias.data[keep_idx].clone()

                clustering_info.append(
                    f'Out channels are clustered: [{layer.out_channels:4d}] -> [{len(keep_idx):4d}] at "{name}" layer'
                )
                layer.out_channels = len(keep_idx)
                layer.in_channels = len(keep_idx)
                layer.groups = len(keep_idx)

            elif layer_type == 'BN':
                prev_cluster_idx = self.get_prev_cluster_idx(
                    node_graph=node_graph, index=i)
                keep_idx = list(
                    set(range(layer.num_features)) - set(prev_cluster_idx))

                layer.running_mean.data = layer.running_mean.data[
                    keep_idx].clone()
                layer.running_var.data = layer.running_var.data[
                    keep_idx].clone()
                if layer.affine:
                    layer.weight.data = layer.weight.data[keep_idx].clone()
                    layer.bias.data = layer.bias.data[keep_idx].clone()

                clustering_info.append(
                    f'Out channels are clustered: [{layer.num_features:4d}] -> [{len(keep_idx):4d}] at "{name}" layer'
                )
                layer.num_features = len(keep_idx)

            elif layer_type == 'Linear':
                if 'input_convs' in node_graph[i]:
                    prev_cluster_idx = self.get_prev_cluster_idx(
                        node_graph=node_graph, index=i)
                    keep_idx = list(
                        set(range(layer.in_features)) - set(prev_cluster_idx))

                    layer.weight.data = layer.weight.data[:, keep_idx].clone()

                    layer.in_features = len(keep_idx)
            i += 1

        return new_model, clustering_info, node_graph
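
The two-step weight slicing in the 'Conv' branch above can be sanity-checked in isolation: a Conv2d weight has shape [out_channels, in_channels, kH, kW], so dropped input channels are removed along dim 1 first and dropped output channels along dim 0 second. The indices below are made up for illustration.

import torch

w = torch.randn(8, 4, 3, 3)          # 8 output channels, 4 input channels
keep_prev_idx = [0, 2, 3]            # keep 3 of the 4 input channels
keep_idx = [0, 1, 4, 6]              # keep 4 of the 8 output channels

w = w[:, keep_prev_idx, :, :].clone()
w = w[keep_idx, :, :, :].clone()
print(w.shape)                       # torch.Size([4, 3, 3, 3])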
Example #9
    def set_hooking(self):
        def save_fmaps(key):
            def forward_hook(module, inputs, outputs):
                if key not in self.activations:
                    self.activations[key] = outputs
                else:

                    self.activations[key] = torch.cat(
                        (self.activations[key], outputs), dim=0)

            return forward_hook

        group_to_name = defaultdict()
        bn_to_conv = defaultdict()

        for key, val in reversed(list(self.node_graph.items())):

            layer_type = get_layer_type(val['layer'])
            if layer_type == 'BN':
                bn_to_conv[val['torch_name']] = val['input_convs']

            elif layer_type == 'Conv':
                if val['group'] not in group_to_name:
                    group_to_name[val['group']] = [val['name']]
                else:
                    group_to_name[val['group']].append(val['name'])

        for name, layer in self.model.named_modules():

            if _exclude_layer(layer):
                continue

            layer_type = get_layer_type(layer)
            if layer_type == 'BN':
                link_convs = bn_to_conv[name]
                for i_link_conv in link_convs:

                    self.hook_layers.append(
                        layer.register_forward_hook(save_fmaps(i_link_conv)))

        batches = 0
        batch_size = self.train_loader.batch_size
        with torch.no_grad():
            for images, labels in self.train_loader:
                batches += batch_size
                images = images.to(self.device)
                out = self.model(images)

                if batches >= self.cluster_args.NUM_SAMPLES:
                    break

        for group, names in group_to_name.items():
            for idx, i_name in enumerate(names):
                if idx == 0:
                    group_output = self.activations[i_name]
                else:
                    group_output += self.activations[i_name]

            for i_name in names:
                self.activations[i_name] = group_output

        for i_hook in self.hook_layers:
            i_hook.remove()
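
The group aggregation at the end shares one summed activation tensor across every conv of a group; a tiny standalone illustration with made-up names and tensors:

import torch
from collections import defaultdict

activations = {'conv_a': torch.ones(2, 4), 'conv_b': torch.full((2, 4), 2.0)}
group_to_name = defaultdict(list, {0: ['conv_a', 'conv_b']})

for group, names in group_to_name.items():
    for idx, i_name in enumerate(names):
        group_output = activations[i_name] if idx == 0 else group_output + activations[i_name]
    for i_name in names:
        activations[i_name] = group_output

print(activations['conv_a'][0])      # tensor([3., 3., 3., 3.])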
Example #10
    def prune(self, node_graph):
        group_to_ratio = load_strategy(
            strategy_name=self.pruning_cfg.STRATEGY,
            group_set=self.group_set,
            pruning_ratio=self.pruning_cfg.STRATEGY_ARGS.PRUNING_RATIO).build()

        node_graph = self.set_prune_idx(group_to_ratio, node_graph)

        new_model = copy.deepcopy(self.model)

        i = 0
        pruning_info = []
        for idx, data in enumerate(new_model.named_modules()):
            name, layer = data
            if idx == 0 or _exclude_layer(layer):
                continue

            layer_type = get_layer_type(layer)

            if layer_type == 'Conv':
                prev_prune_idx = []
                if 'input_convs' in node_graph[i]:
                    prev_prune_idx = self.get_prev_prune_idx(
                        node_graph=node_graph, index=i)
                prune_idx = node_graph[i]['prune_idx']

                keep_prev_idx = list(
                    set(range(layer.in_channels)) - set(prev_prune_idx))
                keep_idx = list(
                    set(range(layer.out_channels)) - set(prune_idx))

                w = layer.weight.data[:, keep_prev_idx, :, :].clone()
                layer.weight.data = w[keep_idx, :, :, :].clone()
                if layer.bias is not None:
                    layer.bias.data = layer.bias.data[keep_idx].clone()

                pruning_info.append(
                    f'Out channels are pruned: [{layer.out_channels:4d}] -> [{len(keep_idx):4d}] at "{name}" layer'
                )
                layer.out_channels = len(keep_idx)
                layer.in_channels = len(keep_prev_idx)

            elif layer_type == 'GroupConv':
                prune_idx = node_graph[i]['prune_idx']

                keep_idx = list(
                    set(range(layer.out_channels)) - set(prune_idx))

                layer.weight.data = layer.weight.data[
                    keep_idx, :, :, :].clone()
                if layer.bias is not None:
                    layer.bias.data = layer.bias.data[keep_idx].clone()

                pruning_info.append(
                    f'Out channels are pruned: [{layer.out_channels:4d}] -> [{len(keep_idx):4d}] at "{name}" layer'
                )
                layer.out_channels = len(keep_idx)
                layer.in_channels = len(keep_idx)
                layer.groups = len(keep_idx)

            elif layer_type == 'BN':
                prev_prune_idx = self.get_prev_prune_idx(node_graph=node_graph,
                                                         index=i)
                keep_idx = list(
                    set(range(layer.num_features)) - set(prev_prune_idx))

                layer.running_mean.data = layer.running_mean.data[
                    keep_idx].clone()
                layer.running_var.data = layer.running_var.data[
                    keep_idx].clone()
                if layer.affine:
                    layer.weight.data = layer.weight.data[keep_idx].clone()
                    layer.bias.data = layer.bias.data[keep_idx].clone()

                pruning_info.append(
                    f'Out channels are pruned: [{layer.num_features:4d}] -> [{len(keep_idx):4d}] at "{name}" layer'
                )
                layer.num_features = len(keep_idx)

            elif layer_type == 'Linear':
                if 'input_convs' in node_graph[i]:
                    prev_prune_idx = self.get_prev_prune_idx(
                        node_graph=node_graph, index=i)
                    keep_idx = list(
                        set(range(layer.in_features)) - set(prev_prune_idx))

                    layer.weight.data = layer.weight.data[:, keep_idx].clone()

                    layer.in_features = len(keep_idx)
            i += 1

        return new_model, pruning_info, node_graph
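
All of the pruning and clustering code above indexes node_graph entries by the same small set of keys. A hedged sketch of what one entry is assumed to contain, reconstructed only from the keys accessed in these examples; the real graph builder may add more fields.

import torch.nn as nn

conv_layer = nn.Conv2d(3, 16, kernel_size=3)
node_graph = {
    0: {
        'id': hash('features.0'),     # compared against hash(name) in get_prune_idx
        'layer': conv_layer,          # the torch module itself
        'torch_name': 'features.0',   # the named_modules() name
        'name': 'Conv_0',             # graph-level node name
        'group': 0,                    # convs in one group are pruned together
        'input_convs': None,           # producer convs feeding this node, if any
        'prune_idx': [],               # filled in by set_prune_idx
        'cluster_idx': [],             # filled in by set_cluster_idx
    },
}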