def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Prune a BatchNorm module in place according to the mask on its input.

    Removes the channels masked out by ``input_masks[0]`` from the module's
    weight, bias and running statistics, and updates ``num_features``.
    No-op when the node carries no input mask.
    """
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    old_num_channels = int(node_module.weight.size(0))
    # Number of retained channels == number of True entries in the mask.
    new_num_channels = int(torch.sum(bool_mask))

    node_module.num_features = new_num_channels
    node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
    node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])
    # Running statistics must not receive gradients.
    node_module.running_mean = torch.nn.Parameter(node_module.running_mean[bool_mask],
                                                  requires_grad=False)
    node_module.running_var = torch.nn.Parameter(node_module.running_var[bool_mask],
                                                 requires_grad=False)

    nncf_logger.info(
        'Pruned BatchNorm {} by input mask. Old num features: {}, new num features:'
        ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Prune a convolution's input channels according to the mask on its input.

    For a depthwise conv only ``groups``/``in_channels`` metadata is updated
    (its weight is pruned by the output mask instead); otherwise the masked
    channels are removed from dim 1 of the weight. No-op without an input mask.
    """
    input_mask = nx_node['input_masks'][0]
    if input_mask is None:
        return
    bool_mask = torch.tensor(input_mask, dtype=torch.bool)
    new_num_channels = int(torch.sum(input_mask))
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)
    is_depthwise = nx_node['is_depthwise']
    old_num_channels = int(node_module.weight.size(1))

    if is_depthwise:
        # In the depthwise case output channels are pruned by the input mask;
        # here we only fix up module metadata for the new input-channel count.
        node_module.groups = new_num_channels
        node_module.in_channels = new_num_channels
    else:
        node_module.in_channels = new_num_channels
        # Boolean indexing along dim 1 keeps only the unmasked input channels;
        # equivalent to (and simpler than) broadcasting the mask over dim 0.
        node_module.weight = torch.nn.Parameter(node_module.weight[:, bool_mask])

    nncf_logger.info(
        'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Compute and attach pruning-mask attributes for a convolution node."""
    input_masks = get_input_masks(nx_node, nx_graph)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    accept_pruned_input = True
    is_depthwise = False
    # The output mask, if any, is held by the filter-pruning pre-op '0'.
    output_mask = node_module.pre_ops['0'].op.binary_filter_pruning_mask if node_module.pre_ops else None

    # Grouped convolutions cannot be pruned by output filters...
    if node_module.groups != 1:
        if node_module.weight.size(1) == 1:
            # ...except the depthwise case, where the input mask passes through.
            is_depthwise = True
            output_mask = input_masks[0]
        else:
            output_mask = None
            accept_pruned_input = False

    nx_node['input_masks'] = input_masks
    nx_node['output_mask'] = output_mask
    nx_node['accept_pruned_input'] = accept_pruned_input
    nx_node['is_depthwise'] = is_depthwise
def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Prune a transposed convolution's output filters by the node's output mask.

    Removes masked filters from dim 1 of the weight (and from the bias, if
    present) and updates ``out_channels``. No-op without an output mask.
    """
    output_mask = nx_node['output_mask']
    if output_mask is None:
        return
    bool_mask = torch.tensor(output_mask, dtype=torch.bool)
    new_num_channels = int(torch.sum(bool_mask))
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)
    # ConvTranspose weights are laid out (in_channels, out_channels, *kernel),
    # so output filters live along dim 1.
    old_num_channels = int(node_module.weight.size(1))

    node_module.out_channels = new_num_channels
    # Boolean indexing on dim 1 replaces the equivalent broadcast-mask/view dance.
    node_module.weight = torch.nn.Parameter(node_module.weight[:, bool_mask])
    if node_module.bias is not None:
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

    nncf_logger.info(
        'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], old_num_channels, node_module.out_channels))
def mask_propagation(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Attach input/output pruning-mask attributes to the node."""
    input_masks = get_input_masks(nx_node, nx_graph)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    node_module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    # The pruning mask, if any, is stored by the filter-pruning pre-op '0'.
    output_mask = node_module.pre_ops['0'].op.binary_filter_pruning_mask if node_module.pre_ops else None

    nx_node['input_masks'] = input_masks
    nx_node['output_mask'] = output_mask
    nx_node['accept_pruned_input'] = True
def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Prune a transposed convolution's input channels by the incoming mask."""
    mask = nx_node['input_masks'][0]
    if mask is None:
        return
    keep = torch.tensor(mask, dtype=torch.bool)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    # ConvTranspose weights are laid out (in_channels, out_channels, *kernel):
    # input channels are dim 0, so a plain boolean index prunes them.
    prev_in_channels = int(module.weight.size(0))
    module.in_channels = int(torch.sum(keep))
    module.weight = torch.nn.Parameter(module.weight[keep])

    nncf_logger.info(
        'Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], prev_in_channels, module.in_channels))
def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Prune a convolution's output filters according to the node's output mask."""
    pruning_mask = nx_node['output_mask']
    if pruning_mask is None:
        return
    keep = torch.tensor(pruning_mask, dtype=torch.bool)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    conv = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    prev_out_channels = int(conv.weight.size(0))
    conv.out_channels = int(torch.sum(pruning_mask))
    conv.weight = torch.nn.Parameter(conv.weight[keep])
    # NOTE(review): the depthwise bias is deliberately left untouched here —
    # presumably handled together with input pruning; confirm against the
    # depthwise input_prune path.
    if conv.bias is not None and not nx_node['is_depthwise']:
        conv.bias = torch.nn.Parameter(conv.bias[keep])

    nncf_logger.info(
        'Pruned Convolution {} by pruning mask. Old output filters number: {}, new filters number:'
        ' {}.'.format(nx_node['key'], prev_out_channels, conv.out_channels))
def input_prune(self, model: NNCFNetwork, nx_node: dict, graph: NNCFGraph, nx_graph: nx.DiGraph):
    """Prune a wrapped elementwise module's learnable weight by the input mask."""
    mask = nx_node['input_masks'][0]
    if mask is None:
        return
    keep = torch.tensor(mask, dtype=torch.bool)
    nncf_node = graph._nx_node_to_nncf_node(nx_node)
    module = model.get_module_by_scope(nncf_node.op_exec_context.scope_in_model)

    # Only NNCF-wrapped user modules carry a prunable weight here.
    if isinstance(module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
        assert module.target_weight_dim_for_compression == 0, \
            "Implemented only for target_weight_dim_for_compression == 0"
        old_channels = int(module.weight.size(0))
        new_channels = int(torch.sum(mask))
        module.weight = torch.nn.Parameter(module.weight[keep])
        module.n_channels = new_channels
        nncf_logger.info(
            'Pruned Elementwise {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(nx_node['key'], old_channels, new_channels))