Example #1
def _get_quant_module(model, node):
    if are_inputs_unsigned(model, node, []):
        quant_module = qnn.QuantIdentity(Uint8ActPerTensorFixedPoint,
                                         return_quant_tensor=True)
    else:
        quant_module = qnn.QuantIdentity(Int8ActPerTensorFixedPoint,
                                         return_quant_tensor=True)
    quant_module_name = node.name + '_quant'
    model.add_module(quant_module_name, quant_module)
    return quant_module, quant_module_name
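A minimal usage sketch for the helper above; the traced `float_model` and the rewriter utilities it relies on (`are_inputs_unsigned`, and later `InsertModuleCallAfter`) are assumed to come from the same FlexML quantization module as the excerpt:

import torch.fx as fx

# Hypothetical driver: trace a float model and attach a matching QuantIdentity
# after the first call_module node found in the graph.
traced = fx.symbolic_trace(float_model)
for node in traced.graph.nodes:
    if node.op == 'call_module':
        quant_module, quant_module_name = _get_quant_module(traced, node)
        # quant_module_name can then be passed to a graph rewriter such as
        # InsertModuleCallAfter(quant_module_name, node), as in Example #3.
        break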
Example #2
    def __init__(self,
                 negative_slope,
                 alpha_quant=qnn.QuantIdentity(Uint8ActPerTensorFixedPoint,
                                               bit_width=16),
                 input_quant=qnn.QuantIdentity(Int8ActPerTensorFixedPoint,
                                               bit_width=16,
                                               scaling_stats_momentum=None),
                 output_quant=qnn.QuantIdentity(Int8ActPerTensorFixedPoint,
                                                return_quant_tensor=True)):
        super(FlexMLQuantLeakyReLU, self).__init__()
        self.alpha_quant = alpha_quant
        self.input_quant = input_quant
        self.output_quant = output_quant
        self.negative_slope = StatelessBuffer(torch.tensor(negative_slope))
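Assuming the rest of the class (its forward pass, `StatelessBuffer`, and the Brevitas quantizer imports) is defined as in the full source, the module can be used as a drop-in quantized replacement for nn.LeakyReLU:

import torch

# Hypothetical usage; 0.01 is the conventional LeakyReLU slope and the shape is arbitrary.
quant_leaky = FlexMLQuantLeakyReLU(negative_slope=0.01)
out = quant_leaky(torch.randn(1, 16, 8, 8))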
Example #3
def flexml_inp_placeholder_handler(model):
    rewriters = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            inp_quant = qnn.QuantIdentity(Int8ActPerTensorFixedPoint,
                                          return_quant_tensor=True)
            name = node.name + '_quant'
            model.add_module(name, inp_quant)
            rewriters.append(InsertModuleCallAfter(name, node))
    for rewriter in rewriters:
        model = rewriter.apply(model)
    return model
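A usage sketch, assuming `float_model` is a traceable float nn.Module and that `InsertModuleCallAfter` is importable alongside the handler:

import torch
import torch.fx as fx

# Trace the model, insert an Int8 QuantIdentity after every graph input,
# then run a dummy forward pass (the input shape here is arbitrary).
traced = fx.symbolic_trace(float_model)
traced = flexml_inp_placeholder_handler(traced)
out = traced(torch.randn(1, 3, 224, 224))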
Example #4
def output_quant_handler(model, node, rewriters, is_sign_preserving):
    quant_module = None
    quant_module_name = None
    for user in node.users:
        output_quant = True
        if user.op == 'call_module':
            user_module = get_module(model, user.target)
            if isinstance(
                    user_module,
                (qnn.QuantReLU, qnn.QuantIdentity, FlexMLQuantLeakyReLU)):
                output_quant = False
        if output_quant:
            if quant_module_name is None and quant_module is None:
                if is_sign_preserving and are_inputs_unsigned(model, node, []):
                    quant_module = qnn.QuantIdentity(
                        act_quant=Uint8ActPerTensorFixedPoint,
                        return_quant_tensor=True)
                else:
                    quant_module = qnn.QuantIdentity(
                        act_quant=Int8ActPerTensorFixedPoint,
                        return_quant_tensor=True)
                quant_module_name = node.name + '_output_quant'
                model.add_module(quant_module_name, quant_module)
            rewriters.append(InsertModuleCallAfter(quant_module_name, node))
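A sketch of how the handler might be driven, assuming `traced` is an FX GraphModule produced as in the earlier examples; the node filter (element-wise adds) and the sign-preserving flag are illustrative assumptions rather than part of the excerpt:

import operator
import torch

rewriters = []
for node in traced.graph.nodes:
    # Adding two unsigned operands yields an unsigned result, so adds are
    # treated here as sign-preserving.
    if node.op == 'call_function' and node.target in (torch.add, operator.add):
        output_quant_handler(traced, node, rewriters, is_sign_preserving=True)
for rewriter in rewriters:
    traced = rewriter.apply(traced)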
Example #5
    def __init__(self, VGG_type='A', batch_norm=False, bit_width=8, num_classes=1000, pretrained_model=None):
        super(QuantVGG, self).__init__()
        self.logger = get_logger(name=("{}{}".format(__name__, dist.get_rank()) if dist.is_initialized() else __name__))
        self.inp_quant = qnn.QuantIdentity(bit_width=bit_width, act_quant=INPUT_QUANTIZER, return_quant_tensor=RETURN_QUANT_TENSOR)
        self.features = make_layers(cfgs[VGG_type], batch_norm, bit_width)
        self.avgpool = qnn.QuantAdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            qnn.QuantLinear(512 * 7 * 7, 4096,
                            bias=True,
                            bias_quant=BIAS_QUANTIZER,
                            weight_quant=WEIGHT_QUANTIZER,
                            weight_bit_width=bit_width,
                            weight_scaling_min_val=SCALING_MIN_VAL,
                            return_quant_tensor=RETURN_QUANT_TENSOR),
            qnn.QuantReLU(bit_width=bit_width,
                          act_quant=ACT_QUANTIZER,
                          return_quant_tensor=RETURN_QUANT_TENSOR),
            qnn.QuantDropout(),
            qnn.QuantLinear(4096, 4096,
                            bias=True,
                            bias_quant=BIAS_QUANTIZER,
                            weight_quant=WEIGHT_QUANTIZER,
                            weight_bit_width=bit_width,
                            weight_scaling_min_val=SCALING_MIN_VAL,
                            return_quant_tensor=RETURN_QUANT_TENSOR),
            qnn.QuantReLU(bit_width=bit_width,
                          act_quant=ACT_QUANTIZER,
                          return_quant_tensor=RETURN_QUANT_TENSOR),
            nn.Dropout(),
            qnn.QuantLinear(4096, num_classes,
                            bias=False,
                            weight_quant=WEIGHT_QUANTIZER,
                            weight_scaling_min_val=SCALING_MIN_VAL,
                            weight_bit_width=bit_width,
                            return_quant_tensor=False),
        )
        self.classifier[0].cache_inference_quant_bias = True
        self.classifier[3].cache_inference_quant_bias = True
        self.classifier[6].cache_inference_quant_bias = True

        if is_master():
            print_config(self.logger)

        if pretrained_model is None:
            self._initialize_weights()
        else:
            pre_model = None
            if pretrained_model == 'pytorch':
                self.logger.info(
                    "Initializing with pretrained model from PyTorch")
                # use pytorch's pretrained model
                pre_model = models.vgg16(pretrained=True)
            else:
                pre_model = VGG_net(VGG_type=VGG_type, batch_norm=batch_norm, num_classes=num_classes)
                loaded_model = torch.load(pretrained_model)['state_dict']
                # Check whether the model was trained with DataParallel; keys() returns an
                # 'odict_keys' view, which does not support indexing, hence next(iter(...)).
                if next(iter(loaded_model.keys())).startswith('module'):
                    # A DataParallel-trained model has its weights wrapped under 'module.'
                    pre_model = torch.nn.DataParallel(pre_model)
                    pre_model.load_state_dict(loaded_model)
                    unwrapped_sd = pre_model.module.state_dict()
                    pre_model = VGG_net(VGG_type=VGG_type, batch_norm=batch_norm, num_classes=num_classes)
                    pre_model.load_state_dict(unwrapped_sd)
                else:
                    pre_model.load_state_dict(loaded_model)
            self._initialize_custom_weights(pre_model)
        self.logger.info("Initialization Done")