def modify_parameter_precision(weights, biases, config, attributes):
    """Optionally binarize or quantize the parameters of a layer.

    The transformation is selected by the ``[cell]`` section of ``config``:
    ``binarize_weights`` is checked first and, if set, takes precedence;
    otherwise ``quantize_weights`` triggers fixed-point quantization using
    the format stored in ``attributes['Qm.f']``. Biases are only quantized
    when ``attributes['quantize_bias']`` is truthy; they are never binarized.

    Whatever branch is taken (including none), the ``'quantize_bias'`` and
    ``'Qm.f'`` entries are removed from ``attributes`` in place, because
    Keras would not understand them when building the parsed model.

    Parameters
    ----------
    weights, biases
        Layer parameters to transform.
    config
        Object providing ``getboolean(section, option)``, e.g. a
        ``configparser.ConfigParser``.
    attributes : dict
        Layer attributes; mutated in place.

    Returns
    -------
    tuple
        ``(weights, biases)``, possibly transformed.

    Raises
    ------
    AssertionError
        If quantization is requested but ``attributes`` lacks ``'Qm.f'``.
    """
    if config.getboolean('cell', 'binarize_weights'):
        from snntoolbox.utils.utils import binarize
        print("Binarizing weights.")
        weights = binarize(weights)
    elif config.getboolean('cell', 'quantize_weights'):
        assert 'Qm.f' in attributes, (
            "In the [cell] section of the configuration file, "
            "'quantize_weights' was set to True. For this to "
            "work, the layer needs to specify the fixed point "
            "number format 'Qm.f'.")
        from snntoolbox.utils.utils import reduce_precision
        int_bits, frac_bits = attributes.get('Qm.f')
        print("Quantizing weights to Q{}.{}.".format(int_bits, frac_bits))
        weights = reduce_precision(weights, int_bits, frac_bits)
        if attributes.get('quantize_bias', False):
            biases = reduce_precision(biases, int_bits, frac_bits)

    # Drop precision-only keys so the attribute dict stays Keras-compatible.
    for obsolete_key in ('quantize_bias', 'Qm.f'):
        attributes.pop(obsolete_key, None)

    return weights, biases
def parse(self):
    """Extract the essential information about a neural network.

    This method serves to abstract the conversion process of a network from
    the language the input model was built in (e.g. Keras or Lasagne).

    This method iterates over all layers of the input model and writes the
    layer specifications and parameters into `_layer_list`. The keys are
    chosen in accordance with Keras layer attributes to facilitate
    instantiation of a new, parsed Keras model (done in a later step by
    `build_parsed_model`).

    This function applies several simplifications and adaptations to prepare
    the model for conversion to spiking. These modifications include:

    - Removing layers only used during training (Dropout,
      BatchNormalization, ...), i.e. skipping every layer type not listed in
      the ``[restrictions] snn_layers`` config option.
    - Absorbing the parameters of BatchNormalization layers into the
      parameters of the preceding layer. This does not affect performance
      because batch-norm-parameters are constant at inference time.
    - Removing ReLU activation layers, because their function is inherent to
      the spike generation mechanism. The information which nonlinearity was
      used in the original model is preserved in the ``activation`` key in
      `_layer_list`. If the output layer employs the softmax function, a
      spiking version is used when testing the SNN in INIsim or MegaSim
      simulators.
    - Inserting a Flatten layer between Conv and FC layers, if the input
      model did not explicitly include one.
    - Replacing GlobalAveragePooling2D by AveragePooling2D plus Flatten.
    - Optionally replacing MaxPooling2D by AveragePooling2D
      (``[conversion] max2avg_pool``) and binarizing / quantizing the
      parameters of Dense and Conv2D layers (``[cell]`` section).
    """

    layers = self.get_layer_iterable()
    # NOTE(review): `eval` executes arbitrary code taken from the
    # configuration file. If `snn_layers` is always a plain set/list
    # literal, `ast.literal_eval` would be a safer drop-in; confirm first.
    snn_layers = eval(self.config.get('restrictions', 'snn_layers'))

    # Maps str(id(layer)) of parsed input layers (and synthetic keys for
    # inserted layers) to their position in `_layer_list`.
    name_map = {}
    idx = 0
    # True while a freshly appended Flatten layer must become the inbound of
    # the next parsed layer.
    inserted_flatten = False

    for layer in layers:
        layer_type = self.get_type(layer)

        # Absorb BatchNormalization layer into parameters of previous layer
        if layer_type == 'BatchNormalization':
            parameters_bn = list(self.get_batchnorm_parameters(layer))
            inbound = self.get_inbound_layers_with_parameters(layer)
            assert len(inbound) == 1, \
                "Could not find unique layer with parameters " \
                "preceeding BatchNorm layer."
            prev_layer = inbound[0]
            prev_layer_idx = name_map[str(id(prev_layer))]
            parameters = list(
                self._layer_list[prev_layer_idx]['parameters'])
            print("Absorbing batch-normalization parameters into " +
                  "parameters of previous {}.".format(
                      self.get_type(prev_layer)))
            self._layer_list[prev_layer_idx]['parameters'] = \
                absorb_bn_parameters(*(parameters + parameters_bn))

        if layer_type == 'GlobalAveragePooling2D':
            print("Replacing GlobalAveragePooling by AveragePooling "
                  "plus Flatten.")
            # NOTE(review): the last two input dimensions are taken as the
            # spatial extent, which presumes a channels-first layout —
            # confirm for channels-last models.
            pool_size = [layer.input_shape[-2], layer.input_shape[-1]]
            self._layer_list.append({
                'layer_type': 'AveragePooling2D',
                'name': self.get_name(layer, idx, 'AveragePooling2D'),
                'input_shape': layer.input_shape,
                'pool_size': pool_size,
                'inbound': self.get_inbound_names(layer, name_map)
            })
            name_map['AveragePooling2D' + str(idx)] = idx
            idx += 1
            # Zero-pad the index to two digits, e.g. '05Flatten_512'.
            num_str = '{:02d}'.format(idx)
            shape_string = str(np.prod(layer.output_shape[1:]))
            self._layer_list.append({
                'name': num_str + 'Flatten_' + shape_string,
                'layer_type': 'Flatten',
                'inbound': [self._layer_list[-1]['name']]
            })
            name_map['Flatten' + str(idx)] = idx
            idx += 1
            inserted_flatten = True

        if layer_type not in snn_layers:
            print("Skipping layer {}.".format(layer_type))
            continue

        if not inserted_flatten:
            inserted_flatten = self.try_insert_flatten(
                layer, idx, name_map)
            # Assumes try_insert_flatten returns a bool: adding it advances
            # idx by one exactly when a Flatten layer was appended.
            idx += inserted_flatten

        print("Parsing layer {}.".format(layer_type))

        if layer_type == 'MaxPooling2D' and \
                self.config.getboolean('conversion', 'max2avg_pool'):
            print("Replacing max by average pooling.")
            layer_type = 'AveragePooling2D'

        # A Flatten layer inserted just before this one replaces the
        # original inbound connection.
        if inserted_flatten:
            inbound = [self._layer_list[-1]['name']]
            inserted_flatten = False
        else:
            inbound = self.get_inbound_names(layer, name_map)

        attributes = self.initialize_attributes(layer)

        attributes.update({
            'layer_type': layer_type,
            'name': self.get_name(layer, idx),
            'inbound': inbound
        })

        if layer_type == 'Dense':
            self.parse_dense(layer, attributes)

        if layer_type == 'Conv2D':
            self.parse_convolution(layer, attributes)

        if layer_type in {'Dense', 'Conv2D'}:
            weights, bias = attributes['parameters']
            # Shared binarize/quantize logic; it also strips 'quantize_bias'
            # and 'Qm.f', which Keras would not understand when building the
            # parsed model.
            weights, bias = modify_parameter_precision(
                weights, bias, self.config, attributes)
            attributes['parameters'] = (weights, bias)
            self.absorb_activation(layer, attributes)

        if 'Pooling' in layer_type:
            self.parse_pooling(layer, attributes)

        if layer_type == 'Concatenate':
            self.parse_concatenate(layer, attributes)

        self._layer_list.append(attributes)

        # Map layer index to layer id. Needed for inception modules.
        name_map[str(id(layer))] = idx

        idx += 1

    print('')