def test_load_gpu_model_on_cpu_with_thinning():
    # Issue #148
    # 1. create a GPU model and remove 50% of the filters in one of the layers (thinning)
    # 2. save the thinned model in a checkpoint file
    # 3. load the checkpoint and place it on the CPU
    gpu_model = create_model(False, 'cifar10', 'resnet20_cifar')
    conv_pname = "module.layer1.0.conv1.weight"
    conv_p = distiller.model_find_param(gpu_model, conv_pname)
    pruner = distiller.pruning.L1RankedStructureParameterPruner("test_pruner", group_type="Filters",
                                                                desired_sparsity=0.5, weights=conv_pname)
    zeros_mask_dict = distiller.create_model_masks_dict(gpu_model)
    pruner.set_param_mask(conv_p, conv_pname, zeros_mask_dict, meta=None)

    # Use the mask to prune
    zeros_mask_dict[conv_pname].apply_mask(conv_p)
    distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
    assert hasattr(gpu_model, 'thinning_recipes')

    scheduler = distiller.CompressionScheduler(gpu_model)
    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None)

    CPU_DEVICE_ID = -1
    cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    load_checkpoint(cpu_model, "checkpoint.pth.tar")
    assert distiller.model_device(cpu_model) == 'cpu'
def export_img_classifier_to_onnx(model, onnx_fname, dataset, add_softmax=True, **kwargs):
    """Export a PyTorch image classifier to ONNX.

    Args:
        add_softmax: when True, adds a softmax layer to the output of the model.
        kwargs: arguments to be passed to torch.onnx.export
    """
    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
    # PyTorch doesn't support exporting modules wrapped in DataParallel
    non_para_model = distiller.make_non_parallel_copy(model)

    try:
        if add_softmax:
            # Explicitly add a softmax layer, because it is needed for the ONNX inference phase.
            # TorchVision models use nn.CrossEntropyLoss for computing the loss,
            # instead of adding a softmax layer
            non_para_model.original_forward = non_para_model.forward
            softmax = torch.nn.Softmax(dim=-1)
            non_para_model.forward = lambda input: softmax(non_para_model.original_forward(input))
        torch.onnx.export(non_para_model, dummy_input, onnx_fname, **kwargs)
        msglogger.info("Exported the model to ONNX format at %s" % os.path.realpath(onnx_fname))
    finally:
        del non_para_model
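# Hedged usage sketch for export_img_classifier_to_onnx(). It reuses create_model() exactly
# as the tests in this file do; the output file name and the 'verbose' kwarg (which is simply
# forwarded to torch.onnx.export) are illustrative assumptions, not part of the original code.
def example_export_to_onnx():
    model = create_model(False, 'cifar10', 'resnet20_cifar')
    export_img_classifier_to_onnx(model, 'resnet20_cifar.onnx', 'cifar10',
                                  add_softmax=True, verbose=False)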
def model_summary(model, what, dataset=None):
    if what.startswith('png'):
        draw_img_classifier_to_file(model, 'model.png', dataset, what == 'png_w_params')
    elif what == 'sparsity':
        pylogger = PythonLogger(msglogger)
        csvlogger = CsvLogger()
        distiller.log_weights_sparsity(model, -1, loggers=[pylogger, csvlogger])
    elif what == 'compute':
        try:
            dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
        except ValueError as e:
            print(e)
            return
        df = model_performance_summary(model, dummy_input, 1)
        t = tabulate(df, headers='keys', tablefmt='psql', floatfmt=".5f")
        total_macs = df['MACs'].sum()
        print(t)
        print("Total MACs: " + "{:,}".format(total_macs))
    elif what == 'model':
        # Print the simple form of the model
        print(model)
    elif what == 'modules':
        # Print the names of the leaf modules.
        # Remember that in PyTorch not every node is a module (e.g. F.relu), and that
        # parameterless modules, like nn.MaxPool2d, can be used multiple times in the
        # same model, but will appear only once in the modules list.
        nodes = []
        for name, module in model.named_modules():
            # Only print leaf modules
            if len(module._modules) == 0:
                nodes.append([name, module.__class__.__name__])
        print(tabulate(nodes, headers=['Name', 'Type']))
    else:
        raise ValueError("%s is not a supported summary type" % what)
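# Quick usage sketch for model_summary(). Note that 'sparsity', 'model' and 'modules' need no
# dataset, while 'compute' and the 'png' variants do. The model argument is assumed to be any
# image-classification model created elsewhere in this file (e.g. via create_model()).
def example_model_summaries(model):
    model_summary(model, 'sparsity')
    model_summary(model, 'modules')
    model_summary(model, 'compute', dataset='cifar10')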
def model_performance_summary(model, dummy_input, batch_size=1):
    """Collect performance data"""
    def install_perf_collector(m):
        if isinstance(m, torch.nn.Conv2d):
            hook_handles.append(m.register_forward_hook(
                partial(conv_visitor, df=df, model=model, memo=memo)))
        elif isinstance(m, torch.nn.Linear):
            hook_handles.append(m.register_forward_hook(
                partial(fc_visitor, df=df, model=model, memo=memo)))

    df = pd.DataFrame(columns=['Name', 'Type', 'Attrs', 'IFM', 'IFM volume',
                               'OFM', 'OFM volume', 'Weights volume', 'MACs'])

    hook_handles = []
    memo = []

    model = distiller.make_non_parallel_copy(model)
    model.apply(install_perf_collector)
    # Now run the forward path and collect the data
    dummy_input = dummy_input.to(distiller.model_device(model))
    model(dummy_input)
    # Unregister from the forward hooks
    for handle in hook_handles:
        handle.remove()

    return df
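# A minimal sketch of driving model_performance_summary() directly: build a dummy input on the
# model's device (the same get_dummy_input() pattern used throughout this file) and sum the
# per-layer MACs column of the returned DataFrame. The dataset default is an assumption.
def example_performance_summary(model, dataset='cifar10'):
    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
    df = model_performance_summary(model, dummy_input, batch_size=1)
    print("Total MACs: {:,}".format(df['MACs'].sum()))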
def draw_img_classifier_to_file(model, png_fname, dataset=None, display_param_nodes=False,
                                rankdir='TB', styles=None, input_shape=None):
    """Draw a PyTorch image classifier to a PNG file.

    This is a helper function that simplifies the interface of draw_model_to_file().

    Args:
        model: PyTorch model instance
        png_fname (string): PNG file name
        dataset (string): one of 'imagenet' or 'cifar10'.  This is required in order to
            create a dummy input of the correct shape.
        display_param_nodes (boolean): if True, draw the parameter nodes
        rankdir: diagram direction.  'TB'/'BT' is Top-to-Bottom/Bottom-to-Top;
            'LR'/'RL' is Left-to-Right/Right-to-Left.
        styles: a dictionary of styles.  Key is module name.  Value is
            a legal pydot style dictionary.  For example:
            styles['conv1'] = {'shape': 'oval',
                               'fillcolor': 'gray',
                               'style': 'rounded, filled'}
        input_shape (tuple): list of integers representing the input shape.
            Used only if 'dataset' is None
    """
    dummy_input = distiller.get_dummy_input(dataset=dataset,
                                            device=distiller.model_device(model),
                                            input_shape=input_shape)
    non_para_model = distiller.make_non_parallel_copy(model)
    try:
        g = SummaryGraph(non_para_model, dummy_input)
        draw_model_to_file(g, png_fname, display_param_nodes, rankdir, styles)
        print("Network PNG image generation completed")
    except FileNotFoundError:
        print("An error has occurred while generating the network PNG image.")
        print("Please check that you have graphviz installed.")
        print("\t$ sudo apt-get install graphviz")
    finally:
        del non_para_model
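# Sketch of calling draw_img_classifier_to_file() without naming a dataset, relying instead on
# the input_shape fallback described in the docstring. The (1, 3, 32, 32) shape is an assumption
# matching a CIFAR-10-sized input.
def example_draw_with_input_shape(model):
    draw_img_classifier_to_file(model, 'model.png', dataset=None,
                                display_param_nodes=False,
                                input_shape=(1, 3, 32, 32))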
def load_checkpoint(model, chkpt_file, optimizer=None):
    """Load a PyTorch training checkpoint.

    Args:
        model: the PyTorch model to which we will load the parameters
        chkpt_file: the checkpoint file
        optimizer: the optimizer to which we will load the serialized state
    """
    compression_scheduler = None
    start_epoch = 0

    if os.path.isfile(chkpt_file):
        msglogger.info("=> loading checkpoint %s", chkpt_file)
        checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
        msglogger.info("Checkpoint keys:\n{}".format("\n\t".join(k for k in checkpoint.keys())))
        start_epoch = checkpoint['epoch'] + 1
        best_top1 = checkpoint.get('best_top1', None)
        if best_top1 is not None:
            msglogger.info("   best top@1: %.3f", best_top1)

        if 'compression_sched' in checkpoint:
            compression_scheduler = distiller.CompressionScheduler(model)
            compression_scheduler.load_state_dict(checkpoint['compression_sched'],
                                                  distiller.model_device(model))
            msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                           checkpoint['epoch'])
        else:
            msglogger.info("Warning: compression schedule data does not exist in the checkpoint")

        if 'thinning_recipes' in checkpoint:
            if 'compression_sched' not in checkpoint:
                raise KeyError("Found thinning_recipes key, but missing mandatory key compression_sched")
            msglogger.info("Loaded a thinning recipe from the checkpoint")
            # Cache the recipes in case we need them later
            model.thinning_recipes = checkpoint['thinning_recipes']
            distiller.execute_thinning_recipes_list(model,
                                                    compression_scheduler.zeros_mask_dict,
                                                    model.thinning_recipes)

        if 'quantizer_metadata' in checkpoint:
            msglogger.info('Loaded quantizer metadata from the checkpoint')
            qmd = checkpoint['quantizer_metadata']
            quantizer = qmd['type'](model, **qmd['params'])
            quantizer.prepare_model()

        msglogger.info("=> loaded checkpoint '%s' (epoch %d)", chkpt_file, checkpoint['epoch'])
        model.load_state_dict(checkpoint['state_dict'])
        return model, compression_scheduler, start_epoch
    else:
        raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)
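# Minimal resume-from-checkpoint sketch for this version of load_checkpoint(), which returns
# (model, compression_scheduler, start_epoch). The checkpoint path is a placeholder, not a file
# that necessarily exists.
def example_resume_from_checkpoint():
    model = create_model(False, 'cifar10', 'resnet20_cifar')
    model, compression_scheduler, start_epoch = load_checkpoint(model, 'checkpoint.pth.tar')
    return model, compression_scheduler, start_epoch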
def test_load_gpu_model_on_cpu():
    # Issue #148
    CPU_DEVICE_ID = -1
    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    model, compression_scheduler, start_epoch = load_checkpoint(
        model, '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
    assert compression_scheduler is not None
    assert start_epoch == 180
    assert distiller.model_device(model) == 'cpu'
def test_load_gpu_model_on_cpu_lean_checkpoint():
    CPU_DEVICE_ID = -1
    CPU_DEVICE_NAME = 'cpu'
    checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'

    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    model = load_lean_checkpoint(model, checkpoint_filename, model_device=CPU_DEVICE_NAME)
    assert distiller.model_device(model) == CPU_DEVICE_NAME
def test_load_gpu_model_on_cpu_lean_checkpoint():
    CPU_DEVICE_ID = -1
    checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'

    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    model, compression_scheduler, optimizer, start_epoch, train_steps = load_checkpoint(
        model, checkpoint_filename, lean_checkpoint=True)
    assert compression_scheduler is None
    assert optimizer is None
    assert distiller.model_device(model) == 'cpu'
def create_graph(dataset, model):
    dummy_input = None
    if dataset == 'imagenet':
        dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
    elif dataset == 'cifar10':
        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)

    dummy_input = dummy_input.to(distiller.model_device(model))
    return SummaryGraph(model, dummy_input)
def test_load_gpu_model_on_cpu():
    # Issue #148
    CPU_DEVICE_ID = -1
    CPU_DEVICE_NAME = 'cpu'
    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'

    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    model, compression_scheduler, optimizer, start_epoch = load_checkpoint(model, checkpoint_filename)
    assert compression_scheduler is not None
    assert optimizer is not None
    assert distiller.utils.optimizer_device_name(optimizer) == CPU_DEVICE_NAME
    assert start_epoch == 1
    assert distiller.model_device(model) == CPU_DEVICE_NAME
def prepare_model(self, dummy_input=None):
    """Traverses the model and replaces sub-modules with quantized counterparts according to the
    bit-width and overrides configuration provided to __init__(), and according to the
    replacement_factory as defined by the Quantizer sub-class being used.

    Note:
        If multiple sub-modules within the model actually reference the same module, then that module
        is replaced only once, according to the configuration (bit-width and/or overrides) of the
        first encountered reference.

        Toy example - say a module is constructed using this bit of code:

            shared_relu = nn.ReLU()
            self.relu1 = shared_relu
            self.relu2 = shared_relu

        When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
        Let's call it 'new_relu1'. When 'self.relu2' is encountered, it'll simply be replaced with a
        reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2' will
        be ignored. A warning message will be shown.
    """
    msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

    self.model.quantizer_metadata["dummy_input"] = dummy_input
    if dummy_input is not None:
        summary_graph = distiller.SummaryGraph(self.model, dummy_input)
        self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

    model_device = distiller.model_device(self.model)

    self._pre_prepare_model(dummy_input)

    self._pre_process_container(self.model)
    for module_name, module in self.model.named_modules():
        qbits = self.module_qbits_map[module_name]
        curr_parameters = dict(module.named_parameters())
        for param_name, param in curr_parameters.items():
            n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
            if n_bits is None:
                continue
            fp_attr_name = param_name
            if self.train_with_fp_copy:
                hack_float_backup_parameter(module, param_name, n_bits)
                fp_attr_name = FP_BKP_PREFIX + param_name
            self.params_to_quantize.append(
                _ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

            param_full_name = '.'.join([module_name, param_name])
            msglogger.debug("Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

    # If an optimizer was passed, assume we need to update it
    if self.optimizer:
        for pg in self._get_new_optimizer_params_groups():
            self.optimizer.add_param_group(pg)

    self._post_prepare_model()

    # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers
    self.model.to(model_device)

    distiller.assign_layer_fq_names(self.model)
    self.prepared = True

    msglogger.debug('Quantized model:\n\n{0}\n'.format(self.model))
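# Hedged sketch of how prepare_model() is typically driven from the outside. The quantizer here
# is assumed to be an instance of some concrete Quantizer sub-class, constructed elsewhere; the
# dummy input merely lets prepare_model() build an adjacency map via SummaryGraph before the
# module replacements are applied. The dataset default is an assumption.
def example_prepare_for_quantization(quantizer, dataset='cifar10'):
    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(quantizer.model))
    quantizer.prepare_model(dummy_input)
    return quantizer.model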
def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
    """Test removal of arbitrary channels.

    The test receives a specification of channels to remove.
    Based on this specification, the channels are pruned and then physically
    removed from the model (via a "thinning" process).
    """
    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)

    pair = config.module_pairs[0]
    conv2 = common.find_module_by_name(model, pair[1])
    assert conv2 is not None

    # Test that we can access the weights tensor of the second convolution in the pair
    conv2_p = distiller.model_find_param(model, pair[1] + ".weight")
    assert conv2_p is not None

    assert conv2_p.dim() == 4
    num_channels = conv2_p.size(1)
    cnt_nnz_channels = num_channels - len(channels_to_remove)
    mask = create_channels_mask(conv2_p, channels_to_remove)
    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels

    # Cool, so now we have a mask for pruning our channels.
    # Use the mask to prune
    zeros_mask_dict[pair[1] + ".weight"].mask = mask
    zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
    all_channels = set([ch for ch in range(num_channels)])
    nnz_channels = set(distiller.non_zero_channels(conv2_p))
    channels_removed = all_channels - nnz_channels
    logger.info("Channels removed {}".format(channels_removed))

    # Now, let's do the actual network thinning
    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
    conv1 = common.find_module_by_name(model, pair[0])
    assert conv1
    assert conv1.out_channels == cnt_nnz_channels
    assert conv2.in_channels == cnt_nnz_channels
    assert conv1.weight.size(0) == cnt_nnz_channels
    assert conv2.weight.size(1) == cnt_nnz_channels
    if config.bn_name is not None:
        bn1 = common.find_module_by_name(model, config.bn_name)
        assert bn1.running_var.size(0) == cnt_nnz_channels
        assert bn1.running_mean.size(0) == cnt_nnz_channels
        assert bn1.num_features == cnt_nnz_channels
        assert bn1.bias.size(0) == cnt_nnz_channels
        assert bn1.weight.size(0) == cnt_nnz_channels

    dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    # Let's test saving and loading a thinned model.
    # We save 3 times, and load twice, to make sure to cover some corner cases:
    #  - Make sure that after loading, the model still has hold of the thinning recipes
    #  - Make sure that after a 2nd load, there is no problem loading (in this case, the
    #    tensors are already thin, so this is a new flow)
    # (1)
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
    model_2 = create_model(False, config.dataset, config.arch, parallel=is_parallel)
    model(dummy_input)
    model_2(dummy_input)
    conv2 = common.find_module_by_name(model_2, pair[1])
    assert conv2 is not None
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    run_forward_backward(model, optimizer, dummy_input)

    # (2)
    compression_scheduler = distiller.CompressionScheduler(model)
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done")

    # (3)
    save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done 2")
def create_graph(input_shape, model):
    dummy_input = distiller.get_dummy_input(device=distiller.model_device(model), input_shape=input_shape)
    return SummaryGraph(model, dummy_input)
def create_graph(dataset, model, arch=None):
    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model), model_name=arch)
    return SummaryGraph(model, dummy_input)
def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
    self._src_model = model
    model_clone = distiller.make_non_parallel_copy(model)

    # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
    # See documentation of _DistillerModuleList class for details on why this is done
    model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)

    with torch.onnx.set_training(model_clone, False):
        device = distiller.model_device(model_clone)
        dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
        trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)

        # As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope
        # names of nodes in the trace graph. These can make it impossible, in some cases, to derive
        # the connectivity of the model using the original module names. So we try to detect these
        # cases and apply workarounds.

        # Issue #1:
        #   Gemm ops (aka "Linear" / "addmm" / "FC") get the scope name of the last non-Gemm node
        #   that came before them.
        #   Note that if the node prior to the Gemm node isn't the result of a dedicated module call,
        #   then this issue doesn't occur. For simplicity we just track all Gemms.
        # TODO: This should be fixed in PyTorch 1.2.0, revisit when it's released
        aten_addmm_nodes_scope_names = []
        onnx_gemm_count = 0

        # Issue #2:
        #   Dropout ops are removed by ONNX trace optimization. However, the op BEFORE the original
        #   dropout op gets the scope name of the dropout op
        pre_dropout_nodes_scope_names = OrderedDict()

        prev_non_dropout_op = None
        for node in trace.graph().nodes():
            kind = node.kind()
            if 'aten' not in kind:
                continue
            if kind == 'aten::dropout':
                if prev_non_dropout_op:
                    pre_dropout_nodes_scope_names[node.scopeName()] = prev_non_dropout_op.scopeName()
            else:
                prev_non_dropout_op = node
                if kind == 'aten::addmm':
                    aten_addmm_nodes_scope_names.append(node.scopeName())

        # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
        # composing a GEMM operation; etc.
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

        graph = trace.graph()
        self.ops = OrderedDict()
        self.module_ops_map = defaultdict(list)
        self.params = OrderedDict()
        self.edges = []
        self.temp = OrderedDict()

        in_out = list(graph.inputs()) + list(graph.outputs())
        for param in in_out:
            self.__add_param(param)

        for node in graph.nodes():
            new_op = self.__create_op(node)

            if apply_scope_name_workarounds:
                # Here we apply the workaround to the Gemm nodes scope name issue mentioned above
                if new_op['type'] == 'Gemm':
                    new_op['orig-name'] = aten_addmm_nodes_scope_names[onnx_gemm_count]
                    new_op['name'] = new_op['orig-name']
                    onnx_gemm_count += 1

                # Here we apply the workaround to the issue of dropout op scope name overriding
                # previous op's scope name
                if new_op['name'] in pre_dropout_nodes_scope_names:
                    new_op['orig-name'] = pre_dropout_nodes_scope_names[new_op['name']]
                    new_op['name'] = new_op['orig-name']

            # Convert the graph node's scope name to a PyTorch module name
            module_name = onnx_name_2_pytorch_name(new_op['orig-name'])

            # Get name from before conversion to DistillerModuleList
            module_name = converted_module_names_map[module_name]

            if len(module_name) == 0:
                # Special case where the module name is an empty string - this happens
                # when the op is called from the "top-level" of the model
                new_op['name'] = 'top_level_op'
            else:
                new_op['name'] = module_name

            # Save the calling module name in the op dict. Denormalize it so it can
            # be directly matched with the actual model
            module_name = distiller.denormalize_module_name(self._src_model, module_name)
            new_op['module-name'] = module_name

            # The node's scope name in the graph corresponds to the module from which the op was called.
            # This means that when ops are invoked from the same module via functional calls or direct
            # operations on tensors, these ops will have the SAME MODEL NAME associated with them.
            # For example:
            #   t = t1 + t2
            #   t = F.relu(t)
            # In this case the add operation and the ReLU operation will have the same name, which is
            # derived from the module they're contained in.
            #
            # Another case where different ops will have the same module name is when a module is reused:
            #   out = self.conv1(x)
            #   out = self.relu(out)    <=== First use of self.relu
            #   out = self.conv2(out)
            #   out = self.relu(out)    <=== Second use of self.relu
            # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
            #
            # Operators with the same name create very confusing graphs (in ResNet, for example),
            # so we "unroll" them.
            same_module_cnt = len(self.module_ops_map[module_name])
            if same_module_cnt:
                new_op['name'] += "__" + str(same_module_cnt)
            self.module_ops_map[module_name].append(new_op['name'])

            # Finally we register the new op in the ops collection
            msglogger.debug("new sgraph node - Scope name: {} ; Type: {} ; Display name {}".format(
                new_op['orig-name'], new_op['type'], new_op['name']))
            self.ops[new_op['name']] = new_op

            for input_ in node.inputs():
                self.__add_input(new_op, input_)
                self.edges.append(SummaryGraph.Edge(input_.uniqueName(), new_op['name']))

            for output in node.outputs():
                self.__add_output(new_op, output)
                self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))

            new_op['attrs'] = OrderedDict([(attr_name, node[attr_name])
                                           for attr_name in node.attributeNames()])

    self.__merge_pad_avgpool()
    self.add_macs_attr()
    self.add_footprint_attr()
    self.add_arithmetic_intensity_attr()
    del model_clone
def _create_graph(dataset, model):
    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
    return SummaryGraph(model, dummy_input)
def ranked_filter_pruning(config, ratio_to_prune, is_parallel, rounding_fn=math.floor):
    """Test L1 ranking and pruning of filters.

    First we rank and prune the filters of a Convolutional layer using
    an L1RankedStructureParameterPruner.  Then we physically remove the
    filters from the model (via a "thinning" process).
    """
    logger.info("executing: %s (invoked by %s)" % (inspect.currentframe().f_code.co_name,
                                                   inspect.currentframe().f_back.f_code.co_name))

    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)

    for pair in config.module_pairs:
        # Test that we can access the weights tensor of the first convolution in layer 1
        conv1_p = distiller.model_find_param(model, pair[0] + ".weight")
        assert conv1_p is not None
        num_filters = conv1_p.size(0)

        # Test that there are no zero-filters
        assert distiller.sparsity_3D(conv1_p) == 0.0

        # Create a filter-ranking pruner
        pruner = distiller.pruning.L1RankedStructureParameterPruner("filter_pruner",
                                                                    group_type="Filters",
                                                                    desired_sparsity=ratio_to_prune,
                                                                    weights=pair[0] + ".weight",
                                                                    rounding_fn=rounding_fn)
        pruner.set_param_mask(conv1_p, pair[0] + ".weight", zeros_mask_dict, meta=None)

        conv1 = common.find_module_by_name(model, pair[0])
        assert conv1 is not None
        # Test that the mask has the correct fraction of filters pruned.
        # For example, if we ask for 10% but the layer has only 16 filters,
        # we have to settle for pruning 1/16 of the filters.
        expected_cnt_removed_filters = int(ratio_to_prune * conv1.out_channels)
        expected_pruning = expected_cnt_removed_filters / conv1.out_channels
        masker = zeros_mask_dict[pair[0] + ".weight"]
        assert masker is not None
        assert distiller.sparsity_3D(masker.mask) == expected_pruning

        # Use the mask to prune
        assert distiller.sparsity_3D(conv1_p) == 0
        masker.apply_mask(conv1_p)
        assert distiller.sparsity_3D(conv1_p) == expected_pruning

        # Remove filters
        conv2 = common.find_module_by_name(model, pair[1])
        assert conv2 is not None
        assert conv1.out_channels == num_filters
        assert conv2.in_channels == num_filters

    # Test thinning
    distiller.remove_filters(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
    assert conv1.out_channels == num_filters - expected_cnt_removed_filters
    assert conv2.in_channels == num_filters - expected_cnt_removed_filters

    # Test the thinned model
    dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    return model, zeros_mask_dict