def _get_quant_module(model, node):
    # Pick an unsigned or signed 8-bit identity quantizer depending on the
    # signedness of the node's inputs, and register it on the model.
    if are_inputs_unsigned(model, node, []):
        quant_module = qnn.QuantIdentity(Uint8ActPerTensorFixedPoint, return_quant_tensor=True)
    else:
        quant_module = qnn.QuantIdentity(Int8ActPerTensorFixedPoint, return_quant_tensor=True)
    quant_module_name = node.name + '_quant'
    model.add_module(quant_module_name, quant_module)
    return quant_module, quant_module_name
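# The graph handlers here rely on a handful of Brevitas/PyTorch imports. A likely
# set is sketched below; the exact module paths are assumptions based on Brevitas'
# public API, and `are_inputs_unsigned` is a helper defined elsewhere in this file
# rather than a library call.
#
#   import torch
#   import brevitas.nn as qnn
#   from brevitas.core.utils import StatelessBuffer
#   from brevitas.graph.base import InsertModuleCallAfter
#   from brevitas.graph.utils import get_module
#   from brevitas.quant.fixed_point import (
#       Int8ActPerTensorFixedPoint, Uint8ActPerTensorFixedPoint)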
def __init__(
        self,
        negative_slope,
        alpha_quant=qnn.QuantIdentity(Uint8ActPerTensorFixedPoint, bit_width=16),
        input_quant=qnn.QuantIdentity(
            Int8ActPerTensorFixedPoint, bit_width=16, scaling_stats_momentum=None),
        output_quant=qnn.QuantIdentity(Int8ActPerTensorFixedPoint, return_quant_tensor=True)):
    super(FlexMLQuantLeakyReLU, self).__init__()
    self.alpha_quant = alpha_quant
    self.input_quant = input_quant
    self.output_quant = output_quant
    self.negative_slope = StatelessBuffer(torch.tensor(negative_slope))
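# Usage sketch (hypothetical example, not from the original source):
# FlexMLQuantLeakyReLU is intended as a quantized stand-in for
# nn.LeakyReLU(negative_slope). Note that the default quantizers in the
# signature above are module instances created once at class-definition time,
# so instances built with the defaults share those quantizer modules.
#
#   quant_lrelu = FlexMLQuantLeakyReLU(negative_slope=0.01)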
def flexml_inp_placeholder_handler(model):
    # Insert an Int8 identity quantizer after every graph input (placeholder node).
    rewriters = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            inp_quant = qnn.QuantIdentity(Int8ActPerTensorFixedPoint, return_quant_tensor=True)
            name = node.name + '_quant'
            model.add_module(name, inp_quant)
            rewriters.append(InsertModuleCallAfter(name, node))
    for rewriter in rewriters:
        model = rewriter.apply(model)
    return model
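# Usage sketch (assumptions: `float_model` is traced into an FX GraphModule,
# e.g. with torch.fx.symbolic_trace, so that `model.graph` and `add_module`
# are available to the handler):
#
#   traced = torch.fx.symbolic_trace(float_model)
#   traced = flexml_inp_placeholder_handler(traced)
#   # every graph input is now followed by an Int8 QuantIdentity module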
def output_quant_handler(model, node, rewriters, is_sign_preserving):
    # Add an output quantizer after `node` if at least one of its users is not
    # already a quantized activation (QuantReLU, QuantIdentity, FlexMLQuantLeakyReLU).
    quant_module = None
    quant_module_name = None
    for user in node.users:
        output_quant = True
        if user.op == 'call_module':
            user_module = get_module(model, user.target)
            if isinstance(user_module, (qnn.QuantReLU, qnn.QuantIdentity, FlexMLQuantLeakyReLU)):
                output_quant = False
        if output_quant:
            # Only create and register the output quantizer once per node.
            if quant_module_name is None and quant_module is None:
                if is_sign_preserving and are_inputs_unsigned(model, node, []):
                    quant_module = qnn.QuantIdentity(
                        act_quant=Uint8ActPerTensorFixedPoint, return_quant_tensor=True)
                else:
                    quant_module = qnn.QuantIdentity(
                        act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True)
                quant_module_name = node.name + '_output_quant'
                model.add_module(quant_module_name, quant_module)
                rewriters.append(InsertModuleCallAfter(quant_module_name, node))
def __init__(self, VGG_type='A', batch_norm=False, bit_width=8, num_classes=1000, pretrained_model=None):
    super(QuantVGG, self).__init__()
    self.logger = get_logger(
        name=("{}{}".format(__name__, dist.get_rank()) if dist.is_initialized() else __name__))
    self.inp_quant = qnn.QuantIdentity(
        bit_width=bit_width, act_quant=INPUT_QUANTIZER, return_quant_tensor=RETURN_QUANT_TENSOR)
    self.features = make_layers(cfgs[VGG_type], batch_norm, bit_width)
    self.avgpool = qnn.QuantAdaptiveAvgPool2d((7, 7))
    self.classifier = nn.Sequential(
        qnn.QuantLinear(
            512 * 7 * 7, 4096, bias=True,
            bias_quant=BIAS_QUANTIZER,
            weight_quant=WEIGHT_QUANTIZER,
            weight_bit_width=bit_width,
            weight_scaling_min_val=SCALING_MIN_VAL,
            return_quant_tensor=RETURN_QUANT_TENSOR),
        qnn.QuantReLU(bit_width=bit_width, act_quant=ACT_QUANTIZER, return_quant_tensor=RETURN_QUANT_TENSOR),
        qnn.QuantDropout(),
        qnn.QuantLinear(
            4096, 4096, bias=True,
            bias_quant=BIAS_QUANTIZER,
            weight_quant=WEIGHT_QUANTIZER,
            weight_bit_width=bit_width,
            weight_scaling_min_val=SCALING_MIN_VAL,
            return_quant_tensor=RETURN_QUANT_TENSOR),
        qnn.QuantReLU(bit_width=bit_width, act_quant=ACT_QUANTIZER, return_quant_tensor=RETURN_QUANT_TENSOR),
        nn.Dropout(),
        qnn.QuantLinear(
            4096, num_classes, bias=False,
            weight_quant=WEIGHT_QUANTIZER,
            weight_scaling_min_val=SCALING_MIN_VAL,
            weight_bit_width=bit_width,
            return_quant_tensor=False),
    )
    # Cache the quantized bias of each quantized linear layer for inference.
    self.classifier[0].cache_inference_quant_bias = True
    self.classifier[3].cache_inference_quant_bias = True
    self.classifier[6].cache_inference_quant_bias = True
    if is_master():
        print_config(self.logger)
    if pretrained_model is None:
        self._initialize_weights()
    else:
        pre_model = None
        if pretrained_model == 'pytorch':
            self.logger.info("Initializing with pretrained model from PyTorch")
            # Use torchvision's pretrained VGG16 weights.
            pre_model = models.vgg16(pretrained=True)
        else:
            pre_model = VGG_net(VGG_type=VGG_type, batch_norm=batch_norm, num_classes=num_classes)
            loaded_model = torch.load(pretrained_model)['state_dict']
            # Check whether the checkpoint comes from a DataParallel model;
            # keys() returns an 'odict_keys' view that does not support indexing,
            # hence next(iter(...)).
            if next(iter(loaded_model.keys())).startswith('module'):
                # A DataParallel checkpoint is wrapped under 'module.', so load it
                # into a DataParallel wrapper and then unwrap the state dict.
                pre_model = torch.nn.DataParallel(pre_model)
                pre_model.load_state_dict(loaded_model)
                unwrapped_sd = pre_model.module.state_dict()
                pre_model = VGG_net(VGG_type=VGG_type, batch_norm=batch_norm, num_classes=num_classes)
                pre_model.load_state_dict(unwrapped_sd)
            else:
                pre_model.load_state_dict(loaded_model)
        self._initialize_custom_weights(pre_model)
    self.logger.info("Initialization Done")
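# Usage sketch (hypothetical; the VGG_type keys come from this file's `cfgs`
# dict, and the quantizer globals such as INPUT_QUANTIZER, WEIGHT_QUANTIZER,
# BIAS_QUANTIZER and ACT_QUANTIZER are assumed to be defined at module level):
#
#   model = QuantVGG(VGG_type='A', batch_norm=True, bit_width=8,
#                    num_classes=1000, pretrained_model=None)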