import pytest
import numpy as np
import torch

import distiller
import common


def test_magnitude_pruning():
    # Create a 4-D tensor of 1s
    a = torch.ones(3, 64, 32, 32)
    # Change one element
    a[1, 4, 17, 31] = 0.2

    # Create a masks dictionary and populate it with one ParameterMasker
    zeros_mask_dict = {}
    masker = distiller.ParameterMasker('a')
    zeros_mask_dict['a'] = masker

    # Try to use a MagnitudeParameterPruner without defining a default threshold
    with pytest.raises(AssertionError):
        pruner = distiller.pruning.MagnitudeParameterPruner("test", None)

    # Now define the default threshold
    thresholds = {"*": 0.4}
    pruner = distiller.pruning.MagnitudeParameterPruner("test", thresholds)
    assert distiller.sparsity(a) == 0

    # Create a mask for parameter 'a'
    pruner.set_param_mask(a, 'a', zeros_mask_dict, None)
    assert common.almost_equal(distiller.sparsity(zeros_mask_dict['a'].mask),
                               1 / distiller.volume(a))

    # Let's now use the masker to prune a parameter
    masker = zeros_mask_dict['a']
    masker.apply_mask(a)
    assert common.almost_equal(distiller.sparsity(a), 1 / distiller.volume(a))

    # We can use the masker on other tensors, if we want (and if they have the correct shape).
    # Remember that the mask was already created, so we're not thresholding - we are pruning.
    b = torch.ones(3, 64, 32, 32)
    b[:] = 0.3
    masker.apply_mask(b)
    assert common.almost_equal(distiller.sparsity(b), 1 / distiller.volume(b))
def test_threshold_mask():
    # Create a 4-D tensor of 1s
    a = torch.ones(3, 64, 32, 32)
    # Change one element
    a[1, 4, 17, 31] = 0.2

    # Create and apply a mask
    mask = distiller.threshold_mask(a, threshold=0.3)
    assert np.sum(distiller.to_np(mask)) == (distiller.volume(a) - 1)
    assert mask[1, 4, 17, 31] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1 / distiller.volume(a))
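
# For reference, a minimal sketch of the semantics the test above relies on.
# This is an assumption for illustration - not necessarily distiller's exact
# implementation: elements whose absolute value exceeds the threshold survive
# (mask value 1), and all others are masked to 0.
def threshold_mask_sketch(param, threshold):
    # Hypothetical helper; mirrors the behavior checked by test_threshold_mask
    return torch.gt(torch.abs(param), threshold).type(param.type())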
def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None):
    in_features_shape = input[0].size()
    out_features_shape = output.size()

    mod_name = distiller.model_find_module_name(model, self)
    df.loc[len(df.index)] = ([mod_name, self.__class__.__name__,
                              attrs if attrs is not None else '',
                              distiller.size_to_str(in_features_shape), distiller.volume(input[0]),
                              distiller.size_to_str(out_features_shape), distiller.volume(output),
                              int(weights_vol), int(macs)])
def conv_visitor(self, input, output, df, model, memo):
    assert isinstance(self, torch.nn.Conv2d)
    if self in memo:
        return

    weights_vol = distiller.volume(self.weight)

    # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups
    # Bias is ignored
    macs = (distiller.volume(output) *
            (self.in_channels / self.groups * self.kernel_size[0] * self.kernel_size[1]))
    attrs = 'k=' + '(' + (', ').join(['%d' % v for v in self.kernel_size]) + ')'
    module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
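
# An illustrative sanity check of the MACs formula used above. The layer and
# input sizes here are made up for the example; only torch and distiller.volume
# from the surrounding code are assumed.
def _conv_macs_example():
    conv = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
    ofm = conv(torch.randn(1, 64, 32, 32))  # OFM shape: 1x128x32x32
    # volume(OFM) * (#IFM * K^2) / #Groups = (128*32*32) * (64*3*3) / 1
    macs = distiller.volume(ofm) * (conv.in_channels / conv.groups *
                                    conv.kernel_size[0] * conv.kernel_size[1])
    assert int(macs) == 75497472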
def get_model_representation(self):
    """Initialize an embedding representation of the entire model.

    At runtime, a specific row in the embedding matrix is chosen (depending on
    the current state) and the dynamic fields in the resulting state-embedding
    vector are updated.
    """
    num_states = self.net_wrapper.num_pruned_layers()
    network_obs = np.empty(shape=(num_states, ObservationLen))
    for state_id, layer_id in enumerate(self.net_wrapper.model_metadata.pruned_idxs):
        layer = self.net_wrapper.get_layer(layer_id)
        layer_macs = self.net_wrapper.layer_macs(layer)
        conv_module = distiller.model_find_module(self.model, layer.name)
        obs = [state_id,
               conv_module.out_channels,
               conv_module.in_channels,
               layer.ifm_h,
               layer.ifm_w,
               layer.stride[0],
               layer.k,
               distiller.volume(conv_module.weight),
               layer_macs,
               0, 0, 0]
        network_obs[state_id, :] = np.array(obs)

    # Feature normalization: min-max scale each feature (column) to [0, 1]
    for feature in range(ObservationLen):
        feature_vec = network_obs[:, feature]
        fmin = min(feature_vec)
        fmax = max(feature_vec)
        if fmax - fmin > 0:
            network_obs[:, feature] = (feature_vec - fmin) / (fmax - fmin)
    # msglogger.debug("model representation=\n{}".format(network_obs))
    return network_obs
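
# A tiny worked example of the per-feature min-max scaling performed above
# (the numbers are made up): each column of the observation matrix is mapped
# to [0, 1], and constant columns are left untouched to avoid division by zero.
def _minmax_scaling_example():
    col = np.array([10.0, 30.0, 20.0])
    scaled = (col - col.min()) / (col.max() - col.min())
    assert np.allclose(scaled, [0.0, 1.0, 0.5])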
def fc_visitor(self, input, output, df, model, memo):
    assert isinstance(self, torch.nn.Linear)
    if self in memo:
        return

    # Multiply-accumulate operations: MACs = #IFM * #OFM
    # Bias is ignored
    weights_vol = macs = distiller.volume(self.weight)
    module_visitor(self, input, output, df, model, weights_vol, macs)
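
# For a fully-connected layer the rule above reduces to a single product
# (sizes made up for illustration): each of the #OFM outputs is a dot-product
# over the #IFM inputs, so MACs == #IFM * #OFM == the weight volume.
def _fc_macs_example():
    fc = torch.nn.Linear(512, 10)
    assert distiller.volume(fc.weight) == 512 * 10  # 5,120 MACs per sample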
def test_weights_size_attr(dataset, arch, parallel):
    model = create_model(False, dataset, arch, parallel=parallel)
    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))

    distiller.assign_layer_fq_names(model)
    for name, mod in model.named_modules():
        if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
            op = sgraph.find_op(name)
            assert op is not None
            assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None):
    in_features_shape = input[0].size()
    out_features_shape = output.size()

    param_name = distiller.model_find_param_name(model, self.weight.data)
    if param_name is None:
        return
    mod_name = param_name[:param_name.find(".weight")]
    df.loc[len(df.index)] = ([mod_name, self.__class__.__name__,
                              attrs if attrs is not None else '',
                              distiller.size_to_str(in_features_shape), distiller.volume(input[0]),
                              distiller.size_to_str(out_features_shape), distiller.volume(output),
                              weights_vol, int(macs)])
def create_png(sgraph, display_param_nodes=False, rankdir="TB", styles=None): """Create a PNG object containing a graphiz-dot graph of the network, as represented by SummaryGraph 'sgraph'. Args: sgraph (SummaryGraph): the SummaryGraph instance to draw. display_param_nodes (boolean): if True, draw the parameter nodes rankdir: diagram direction. 'TB'/'BT' is Top-to-Bottom/Bottom-to-Top 'LR'/'R/L' is Left-to-Rt/Rt-to-Left styles: a dictionary of styles. Key is module name. Value is a legal pydot style dictionary. For example: styles['conv1'] = {'shape': 'oval', 'fillcolor': 'gray', 'style': 'rounded, filled'} """ def annotate_op_node(op): if op["type"] == "Conv": return [ "sh={}".format(distiller.size2str( op["attrs"]["kernel_shape"])), "g={}".format(str(op["attrs"]["group"])), ] return "" op_nodes = [op["name"] for op in sgraph.ops.values()] data_nodes = [] param_nodes = [] for id, param in sgraph.params.items(): n_data = (id, str(distiller.volume(param["shape"])), str(param["shape"])) if data_node_has_parent(sgraph, id): data_nodes.append(n_data) else: param_nodes.append(n_data) edges = sgraph.edges if not display_param_nodes: # Use only the edges that don't have a parameter source non_param_ids = op_nodes + [dn[0] for dn in data_nodes] edges = [edge for edge in sgraph.edges if edge.src in non_param_ids] param_nodes = None op_nodes_desc = [(op["name"], op["type"], *annotate_op_node(op)) for op in sgraph.ops.values()] pydot_graph = create_pydot_graph(op_nodes_desc, data_nodes, param_nodes, edges, rankdir, styles) png = pydot_graph.create_png() return png
def __remove_structures(self, idx, fraction_to_prune, prune_what="channels"): """Physically remove channels and corresponding filters from the model""" if idx not in range(self.num_layers()): raise ValueError("idx=%d is not in correct range (0-%d)" % (idx, self.num_layers())) if fraction_to_prune < 0: raise ValueError("fraction_to_prune=%f is illegal" % (fraction_to_prune)) if fraction_to_prune == 0: return 0 if fraction_to_prune == 1.0: # For now, prevent the removal of entire layers fraction_to_prune = ALMOST_ONE layer = self.conv_layers[idx] conv_pname = layer.name + ".weight" conv_p = distiller.model_find_param(self.model, conv_pname) msglogger.info("ADC: removing %.1f%% %s from %s" % (fraction_to_prune*100, prune_what, conv_pname)) if prune_what == "channels": calculate_sparsity = distiller.sparsity_ch remove_structures = distiller.remove_channels group_type = "Channels" elif prune_what == "filters": calculate_sparsity = distiller.sparsity_3D group_type = "Filters" remove_structures = distiller.remove_filters else: raise ValueError("unsupported structure {}".format(prune_what)) # Create a channel-ranking pruner pruner = distiller.pruning.L1RankedStructureParameterPruner("adc_pruner", group_type, fraction_to_prune, conv_pname) pruner.set_param_mask(conv_p, conv_pname, self.zeros_mask_dict, meta=None) if (self.zeros_mask_dict[conv_pname].mask is None or calculate_sparsity(self.zeros_mask_dict[conv_pname].mask) == 0): msglogger.info("__remove_structures: aborting because there are no channels to prune") return 0 # Use the mask to prune self.zeros_mask_dict[conv_pname].apply_mask(conv_p) if PERFORM_THINNING: remove_structures(self.model, self.zeros_mask_dict, self.app_args.dataset, self.app_args.dataset, optimizer=None) conv_p = distiller.model_find_param(self.model, conv_pname) return distiller.volume(conv_p) / layer.weights_vol actual_sparsity = calculate_sparsity(conv_p) return actual_sparsity
def conv_visitor(self, input, output, df, model, memo):
    assert isinstance(self, torch.nn.Conv2d)
    if self in memo:
        return

    # The weight tensor has shape (#OFM, #IFM/#Groups, K, K)
    weights_vol = (self.out_channels * self.in_channels // self.groups *
                   self.kernel_size[0] * self.kernel_size[1])

    # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups
    # Bias is ignored
    macs = (distiller.volume(output) *
            (self.in_channels / self.groups * self.kernel_size[0] * self.kernel_size[1]))
    attrs = 'k=' + '(' + (', ').join(['%d' % v for v in self.kernel_size]) + ')'
    module_visitor(self, input, output, df, model, weights_vol, macs, attrs)
def rank_and_prune_blocks(fraction_to_prune, param, param_name=None,
                          zeros_mask_dict=None, model=None, binary_map=None,
                          block_shape=None, magnitude_fn=distiller.norms.l1_norm,
                          group_size=1):
    """Block-wise pruning for 4D tensors.

    The block shape is specified using a tuple:
    [block_repetitions, block_depth, block_height, block_width].
    The dimension 'block_repetitions' specifies in how many consecutive filters
    the "basic block" (shaped as [block_depth, block_height, block_width])
    repeats to produce a (4D) "super block".

    For example:

      block_pruner:
        class: L1RankedStructureParameterPruner_AGP
        initial_sparsity : 0.05
        final_sparsity: 0.70
        group_type: Blocks
        kwargs:
          block_shape: [1,8,1,1]  # [block_repetitions, block_depth, block_height, block_width]

    Currently the only supported block shape is: block_repetitions x block_depth x 1 x 1
    """
    if len(block_shape) != 4:
        raise ValueError("The block shape must be specified as a 4-element tuple")
    block_repetitions, block_depth, block_height, block_width = block_shape
    if not block_width == block_height == 1:
        raise ValueError("Currently the only supported block shape is: "
                         "block_repetitions x block_depth x 1 x 1")
    super_block_volume = distiller.volume(block_shape)
    num_super_blocks = distiller.volume(param) / super_block_volume
    if distiller.volume(param) % super_block_volume != 0:
        raise ValueError("The super-block size must divide the weight tensor exactly.")

    num_filters, num_channels = param.size(0), param.size(1)
    kernel_size = param.size(2) * param.size(3)

    # Build the dimensions of a 3-D view in which each super block is a column
    if block_depth > 1:
        view_dims = (num_filters * num_channels // (block_repetitions * block_depth),
                     block_repetitions * block_depth,
                     kernel_size,)
    else:
        view_dims = (num_filters // block_repetitions,
                     block_repetitions,
                     -1,)

    def rank_blocks(fraction_to_prune, param):
        # Create a view where each block is a column
        view1 = param.view(*view_dims)
        # Next, compute the magnitude of each column (block)
        block_mags = magnitude_fn(view1, dim=1)
        block_mags = block_mags.view(-1)  # flatten
        k = int(fraction_to_prune * block_mags.size(0))
        if k == 0:
            msglogger.info("Too few blocks (%d) - can't prune %.1f%% blocks",
                           block_mags.size(0), 100 * fraction_to_prune)
            return None, None
        bottomk, _ = torch.topk(block_mags, k, largest=False, sorted=True)
        return bottomk, block_mags

    def binary_map_to_mask(binary_map, param):
        a = binary_map.view(view_dims[0], view_dims[2])
        c = a.unsqueeze(1)
        d = c.expand(*view_dims).contiguous()
        return d.view(num_filters, num_channels, param.size(2), param.size(3))

    if binary_map is None:
        bottomk_blocks, block_mags = rank_blocks(fraction_to_prune, param)
        if bottomk_blocks is None:
            # A None return means that fraction_to_prune is too low to prune anything
            return
        threshold = bottomk_blocks[-1]
        binary_map = block_mags.gt(threshold).type(param.data.type())

    if zeros_mask_dict is not None:
        zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                       magnitude_fn, param_name,
                       distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
                       fraction_to_prune, binary_map.sum().item(), num_super_blocks)
    return binary_map
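
# A minimal usage sketch of rank_and_prune_blocks, in the spirit of the tests
# above. The tensor size, the parameter name 'w', and the 50% pruning goal are
# all made up for illustration; ParameterMasker and apply_mask are used as in
# test_magnitude_pruning.
def _block_pruning_example():
    param = torch.randn(64, 64, 3, 3)
    zeros_mask_dict = {'w': distiller.ParameterMasker('w')}
    # Prune half of the 1x8x1x1 blocks with the lowest L1 magnitude
    rank_and_prune_blocks(fraction_to_prune=0.5, param=param, param_name='w',
                          zeros_mask_dict=zeros_mask_dict,
                          block_shape=[1, 8, 1, 1])
    zeros_mask_dict['w'].apply_mask(param)  # zero the selected blocks in-place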