def draw_lang_model_to_file(model, png_fname):
    """Draw a language model graph to a PNG file.

    Caveat: the PNG that is produced has some problems, which we suspect are due to
    PyTorch issues related to RNN ONNX export.
    """
    try:
        # if dataset == 'wikitext2':
        batch_size = 1
        seq_len = 255
        dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size, 1).to(device)
        hidden = model.init_hidden(batch_size)
        dummy_input = (dummy_input, hidden)
        # else:
        #     msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset)
        #     return
        g = distiller.SummaryGraph(model, dummy_input)
        distiller.draw_model_to_file(g, png_fname)
        msglogger.info("Network PNG image generation completed")
    except FileNotFoundError as e:
        msglogger.info("An error has occurred while generating the network PNG image.")
        msglogger.info("Please check that you have graphviz installed.")
        msglogger.info("\t$ sudo apt-get install graphviz")
        raise e
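# --- Hedged usage sketch (not from the source) ---
# TinyRNNLM below is a hypothetical stand-in for a word language model: all that
# draw_lang_model_to_file() relies on is an init_hidden() method and a forward()
# accepting (tokens, hidden). The commented-out call assumes the hosting module
# defines 'device', 'msglogger' and 'distiller', as the function above does.
import torch
import torch.nn as nn

class TinyRNNLM(nn.Module):
    def __init__(self, vocab=1000, emb=64, hid=64):
        super().__init__()
        self.hid = hid
        self.encoder = nn.Embedding(vocab, emb)
        self.rnn = nn.LSTM(emb, hid)
        self.decoder = nn.Linear(hid, vocab)

    def init_hidden(self, batch_size):
        weight = next(self.parameters())
        return (weight.new_zeros(1, batch_size, self.hid),
                weight.new_zeros(1, batch_size, self.hid))

    def forward(self, tokens, hidden):
        # The dummy input above is shaped (seq_len, batch_size, 1), so drop the last dim
        emb = self.encoder(tokens.squeeze(-1))
        output, hidden = self.rnn(emb, hidden)
        return self.decoder(output), hidden

# draw_lang_model_to_file(TinyRNNLM().to(device), 'rnn_lm.png')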
def prepare_model(self, dummy_input=None):
    """
    Traverses the model and replaces sub-modules with quantized counterparts according to the bit-width
    and overrides configuration provided to __init__(), and according to the replacement_factory as
    defined by the Quantizer sub-class being used.

    Note:
        If multiple sub-modules within the model actually reference the same module, then that module
        is replaced only once, according to the configuration (bit-width and/or overrides) of the
        first encountered reference.
        Toy Example - say a module is constructed using this bit of code:

            shared_relu = nn.ReLU()
            self.relu1 = shared_relu
            self.relu2 = shared_relu

        When traversing the model, a replacement will be generated when 'self.relu1' is encountered.
        Let's call it 'new_relu1'. When 'self.relu2' is encountered, it'll simply be replaced with a
        reference to 'new_relu1'. Any override configuration made specifically for 'self.relu2' will
        be ignored. A warning message will be shown.
    """
    msglogger.info('Preparing model for quantization using {0}'.format(self.__class__.__name__))

    self.model.quantizer_metadata["dummy_input"] = dummy_input
    if dummy_input is not None:
        summary_graph = distiller.SummaryGraph(self.model, dummy_input)
        self.adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

    self._pre_prepare_model(dummy_input)

    self._pre_process_container(self.model)
    for module_name, module in self.model.named_modules():
        qbits = self.module_qbits_map[module_name]
        curr_parameters = dict(module.named_parameters())
        for param_name, param in curr_parameters.items():
            n_bits = qbits.bias if param_name.endswith('bias') else qbits.wts
            if n_bits is None:
                continue
            fp_attr_name = param_name
            if self.train_with_fp_copy:
                hack_float_backup_parameter(module, param_name, n_bits)
                fp_attr_name = FP_BKP_PREFIX + param_name
            self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))

            param_full_name = '.'.join([module_name, param_name])
            msglogger.info(
                "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))

    # If an optimizer was passed, assume we need to update it
    if self.optimizer:
        optimizer_type = type(self.optimizer)
        new_optimizer = optimizer_type(self._get_updated_optimizer_params_groups(), **self.optimizer.defaults)
        self.optimizer.__setstate__({'param_groups': new_optimizer.param_groups})

    self._post_prepare_model()

    msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
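# --- Hedged illustration of the shared-module caveat (not from the source) ---
# 'relu1' and 'relu2' below reference the same nn.ReLU instance, so prepare_model()
# replaces it only once, using the configuration of the first reference it
# encounters; any override aimed specifically at 'relu2' is ignored with a warning.
import torch.nn as nn

class SharedReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        shared_relu = nn.ReLU()
        self.relu1 = shared_relu
        self.relu2 = shared_relu   # same object as relu1, not a copy

    def forward(self, x):
        return self.relu2(self.fc2(self.relu1(self.fc1(x))))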
def on_epoch_begin(self, model, zeros_mask_dict, meta, **kwargs):
    msglogger.debug("Pruner {} is about to prune".format(self.pruner.name))
    self.is_last_epoch = meta['current_epoch'] == (meta['ending_epoch'] - 1)
    if self.levels is not None:
        self.pruner.levels = self.levels

    meta['model'] = model
    is_initialized = self.is_initialized

    if self.fold_bn:
        # Cache this information (required for BN-folding) to improve performance
        self.named_modules = OrderedDict(model.named_modules())
        dummy_input = torch.randn(model.input_shape)
        self.sg = distiller.SummaryGraph(model, dummy_input)

    for param_name, param in model.named_parameters():
        if self.fold_bn:
            param = self._fold_batchnorm(model, param_name, param, self.named_modules, self.sg)
        if not is_initialized:
            # Initialize the maskers
            masker = zeros_mask_dict[param_name]
            masker.use_double_copies = self.use_double_copies
            masker.mask_on_forward_only = self.mask_on_forward_only
            # register for the backward hook of the parameters
            if self.mask_gradients:
                masker.backward_hook_handle = param.register_hook(masker.mask_gradient)

            self.is_initialized = True
            if not self.skip_first_minibatch:
                self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
        else:
            self.pruner.set_param_mask(param, param_name, zeros_mask_dict, meta)
def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None):
    """
    Scans the model for sequences of modules of the specified types and "fuses" them. As an example,
    consider the following sequence of 3 modules: 'm_1' --> 'm_2' --> 'm_3'. Assuming they match the
    specified sequence of types, they will be fused such that the "fused" module replaces 'm_1', and
    'm_2' and 'm_3' are replaced with identity operations.

    For a sequence of modules to be fused, it must not contain splits. That is - no module in the
    sequence can have more than a single output. For example, consider the following sequence:

        m_1 --> m_2 --> m_3
         |
         |--> m_4

    Even if m_1, m_2 and m_3 match the types sequence, they can't be fused because m_1's output also
    goes to m_4.

    The fused module is generated by the user-specified function 'fuse_fn'.

    To infer the order of modules it is required to perform a forward pass on the model, hence the
    need to pass a dummy input (or a pre-computed adjacency map).

    Args:
        model (nn.Module): Model instance on which the transformation is performed
        types_sequence (list or tuple): Sequence of module types. Each item in the sequence may itself
            be a list / tuple. For example - to fuse all possible convolution types with ReLU, pass:
            [[nn.Conv1d, nn.Conv2d, nn.Conv3d], nn.ReLU]
        fuse_fn (function): Function that takes a list of modules to be fused, and returns a single
            fused module. If the sequence cannot be fused, this function should return None
        dummy_input (torch.Tensor or tuple): Dummy input to the model. Required if adjacency_map is None
        adjacency_map (OrderedDict): Pre-computed adjacency map, via SummaryGraph.adjacency_map().
            Must be based on the passed model, otherwise results are unexpected. If None, then the
            adjacency map will be created internally using the passed dummy_input.
    """
    distiller.assign_layer_fq_names(model)

    if adjacency_map is None:
        if dummy_input is None:
            raise ValueError('Must pass either valid adjacency map instance or valid dummy input')
        summary_graph = distiller.SummaryGraph(model, dummy_input)
        adjacency_map = summary_graph.adjacency_map(dedicated_modules_only=False)

    named_modules = OrderedDict(model.named_modules())
    in_sequence_idx = 0
    curr_sequence = []
    for node_name, adj_entry in adjacency_map.items():
        module = named_modules.get(node_name, None)
        if module is None:
            reset = True
        else:
            reset = False
            if isinstance(module, types_sequence[in_sequence_idx]):
                curr_sequence.append(module)
                in_sequence_idx += 1
                if in_sequence_idx == len(types_sequence):
                    _fuse_sequence(curr_sequence, named_modules, fuse_fn)
                    reset = True
                elif len(adj_entry.successors) > 1:
                    msglogger.debug(node_name + " is connected to multiple outputs, not fuse-able")
                    reset = True
            elif isinstance(module, types_sequence[0]):
                # Current module breaks the current sequence, check if it's the start of a new sequence
                in_sequence_idx = 1
                curr_sequence = [module]
            else:
                reset = True
        if reset:
            in_sequence_idx = 0
            curr_sequence = []
    return model
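# --- Hedged usage sketch (not from the source) ---
# Assumes fuse_modules() and its _fuse_sequence() helper are in scope and distiller
# is importable. ConvReluNet and fuse_conv_relu() are illustrative stand-ins: the
# fuse_fn wraps a matched Conv2d --> ReLU pair in a single nn.Sequential, which then
# replaces the conv while the ReLU becomes an identity operation. Tuples are used for
# the multi-type entry so they can be passed directly to isinstance().
import torch
import torch.nn as nn

class ConvReluNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

def fuse_conv_relu(sequence):
    # Returning None instead would tell fuse_modules() to leave the sequence untouched
    return nn.Sequential(*sequence)

fused_model = fuse_modules(ConvReluNet(),
                           types_sequence=[(nn.Conv1d, nn.Conv2d, nn.Conv3d), nn.ReLU],
                           fuse_fn=fuse_conv_relu,
                           dummy_input=torch.randn(1, 3, 32, 32))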
def create_graph(model, input_tensor=None):
    # Use the caller-provided input if given; otherwise fall back to a default
    # CIFAR-sized dummy batch (16 x 3 x 32 x 32)
    if input_tensor is not None:
        dummy_input = input_tensor
    else:
        dummy_input = torch.randn(16, 3, 32, 32)
    return distiller.SummaryGraph(model, dummy_input)
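# --- Hedged usage example (not from the source) ---
# torchvision's resnet18 is just an illustrative model choice; it tolerates the
# default 32x32 dummy batch because it ends with adaptive average pooling.
import torch
import torchvision.models as models

graph_small = create_graph(models.resnet18())
graph_large = create_graph(models.resnet18(), input_tensor=torch.randn(1, 3, 224, 224))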