Example #1
    def from_float(cls, mod):
        if hasattr(mod, "weight_fake_quant"):
            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
            # ".from_float only works for " + cls.__QAT_MODULE.__name__
            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean,
                    mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, "activation_post_process"), \
                "Input QAT module must have observer attached"
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                " nnq." + cls.__name__ + ".from_float only works for " + \
                cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod))
            assert hasattr(mod, "qconfig"), \
                "Input float module must have qconfig defined."
            activation_post_process = None if not hasattr(
                mod,
                "activation_post_process") else mod.activation_post_process
            if type(mod) == cls._NNI_CONV_RELU_MODULE:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        return cls.get_qconv(mod, activation_post_process, weight_post_process)
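
The fuse_conv_bn_weights call above folds the BatchNorm statistics into the convolution's weight and bias before quantization. A minimal sketch of that fold, assuming a Conv2d weight of shape [out_channels, in_channels, kH, kW] and per-channel BN parameters (the real helper lives in torch.nn.utils.fusion):

import torch

def fold_conv_bn(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    bn_scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    # Broadcast the per-channel scale over the kernel dimensions
    folded_w = conv_w * bn_scale.reshape([-1] + [1] * (conv_w.dim() - 1))
    folded_b = (conv_b - bn_rm) * bn_scale + bn_b
    return torch.nn.Parameter(folded_w), torch.nn.Parameter(folded_b)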
Example #2
    def to_float(self):
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels, self.out_channels, self.kernel_size, self.stride,
            self.padding, self.dilation, self.groups, self.bias is not None,
            self.padding_mode)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight, conv.bias, self.bn.running_mean,
                self.bn.running_var, self.bn.eps, self.bn.weight, self.bn.bias)

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(
                *modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
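
A hedged usage sketch for the to_float method above, paired with the from_float shown in the other examples. It assumes a recent eager-mode PyTorch where the torch.nn.intrinsic modules used in these examples expose both methods; the layer sizes and qconfig choice are illustrative:

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Fused float Conv+BN+ReLU with a QAT qconfig attached
fused = nni.ConvBnReLU2d(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
fused.qconfig = get_default_qat_qconfig("fbgemm")

qat_mod = nniqat.ConvBnReLU2d.from_float(fused)  # float -> QAT (fake quant inserted)
# ... QAT fine-tuning would happen here ...
float_mod = qat_mod.to_float()                   # QAT -> float, with BN folded into the conv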
Example #3
File: conv.py  Project: ytknzw/pytorch
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \
            # '.from_float only works for ' + cls.__QAT_MODULE.__name__
            if type(mod) == nniqat.ConvBn2d:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean,
                    mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, 'activation_post_process'), \
                'Input QAT module must have observer attached'
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(mod, 'qconfig'), \
                'Input float module must have qconfig defined.'
            # workaround for sequential, ConvReLU2d should probably
            # inherit from Conv2d instead
            if type(mod) == nni.ConvReLU2d:
                activation_post_process = mod[1].activation_post_process
                mod = mod[0]
            else:
                activation_post_process = mod.activation_post_process
            weight_post_process = mod.qconfig.weight()

        return cls.get_qconv(mod, activation_post_process, weight_post_process)
Example #4
File: conv.py  Project: zyl001/pytorch
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            #     cls.__QAT_MODULE.__name__
            if type(mod) == nniqat.ConvBn2d:
                mod.weight, mod.bias = \
                    fuse_conv_bn_weights(mod.weight, mod.bias, mod.running_mean,
                                         mod.running_var, mod.eps, mod.gamma, mod.beta)
            assert hasattr(
                mod,
                'observer'), 'Input QAT module must have observer attached'
            weight_observer = mod.weight_fake_quant
            activation_observer = mod.observer
        else:
            assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(
                mod, 'qconfig'), 'Input float module must have qconfig defined'
            # workaround for sequential, ConvReLU2d should probably
            # inherit from Conv2d instead
            if type(mod) == nni.ConvReLU2d:
                activation_observer = mod[1].observer
                mod = mod[0]
            else:
                activation_observer = mod.observer
            weight_observer = mod.qconfig.weight()
            weight_observer(mod.weight)
        act_scale, act_zp = activation_observer.calculate_qparams()
        assert weight_observer.dtype == torch.qint8, 'Weight observer must have a dtype of qint8'
        wt_scale, wt_zp = weight_observer.calculate_qparams()
        # Scale bias to activation_scale/2^16, this quantizes bias
        # to about 24 bits of precision
        bias_scale = float(act_scale / (2**16))

        qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale),
                                        int(wt_zp), torch.qint8)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias
                    is not None, mod.padding_mode)
        qconv.set_weight(qweight)
        if mod.bias is not None:
            qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0,
                                          torch.qint32)
        else:
            qbias = None
        qconv.bias = qbias
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
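
The bias handling above quantizes the float bias to qint32 with a scale of act_scale / 2**16, which, as the comment notes, keeps roughly 24 bits of precision. A small sketch of the same idea with the current per-tensor API; the torch.quantize_linear call in this older example corresponds to today's torch.quantize_per_tensor, and the values below are illustrative:

import torch

act_scale = 0.05                       # illustrative activation scale
bias = torch.randn(8)                  # illustrative float bias
bias_scale = float(act_scale / (2 ** 16))
qbias = torch.quantize_per_tensor(bias, bias_scale, 0, torch.qint32)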
Example #5
    def from_float(cls, mod):
        if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU3d:
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                mod.bn.running_mean,
                mod.bn.running_var,
                mod.bn.eps,
                mod.bn.weight,
                mod.bn.bias,
            )
        return super(ConvReLU3d, cls).from_float(mod)
Example #6
File: conv.py  Project: khabya/DeepStack
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.quantization
              utilities or provided by the user
        """
        if hasattr(mod, 'weight_fake_quant'):
            # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \
            # '.from_float only works for ' + cls.__QAT_MODULE.__name__
            if type(mod) == nniqat.ConvBn2d:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, 'activation_post_process'), \
                'Input QAT module must have observer attached'
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__
            assert hasattr(mod, 'qconfig'), \
                'Input float module must have qconfig defined.'
            # workaround for sequential, ConvReLU2d should probably
            # inherit from Conv2d instead
            if type(mod) == nni.ConvReLU2d:
                activation_post_process = mod[1].activation_post_process
                mod = mod[0]
            else:
                activation_post_process = mod.activation_post_process
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        act_scale, act_zp = activation_post_process.calculate_qparams()
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)

        return qconv
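
from_float is normally not called by hand; torch.quantization.convert invokes it when swapping observed or QAT modules for their quantized counterparts. A minimal eager-mode sketch of that flow (layer sizes and the qconfig choice are illustrative):

import torch
import torch.nn as nn
import torch.quantization as tq

model = nn.Sequential(tq.QuantStub(), nn.Conv2d(3, 8, 3), nn.ReLU(), tq.DeQuantStub())
model.qconfig = tq.get_default_qconfig("fbgemm")

tq.prepare(model, inplace=True)        # attach observers (activation_post_process)
model(torch.randn(1, 3, 32, 32))       # calibration pass
tq.convert(model, inplace=True)        # calls each quantized module's from_float under the hood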
Example #7
    def quantize(self, quantizer, node, load_arg):
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            weight, bias = fuse_conv_bn_weights(weight, bias,
                                                self.bn.running_mean,
                                                self.bn.running_var,
                                                self.bn.eps, self.bn.weight,
                                                self.bn.bias)

        min_val, max_val = float(weight.min()), float(weight.max())

        act_scale, act_zp = self.scale_zeropoint()

        weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp,
                                            torch.qint8)

        ctor = torch.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.nn.quantized.Conv2d

        qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # we can't just delete this because submodules' forwards (which are no longer used)
            # try to call it, so replace it with something that does nothing.
            setattr(quantizer.modules[parent_name], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node(
            'call_module', self.conv_node.target,
            (load_arg(self.conv_node.args[0]), ), {})
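
_minmax_scale_zeropoint above is project-internal; the usual affine mapping from an observed [min, max] range to (scale, zero_point) looks roughly like the sketch below, written here for a signed qint8 range of [-128, 127] (exact clamping and epsilon rules vary between observers):

def minmax_scale_zeropoint(min_val, max_val, qmin=-128, qmax=127):
    # Make sure zero is exactly representable, as quantized kernels assume
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:
        return 1.0, 0
    zero_point = int(round(qmin - min_val / scale))
    return scale, max(qmin, min(qmax, zero_point))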
Example #8
    def from_float(cls, mod):
        if type(mod) == torch.nn._intrinsic.qat.ConvBnReLU2d:
            mod.weight, mod.bias = \
                fuse_conv_bn_weights(mod.weight, mod.bias, mod.running_mean,
                                     mod.running_var, mod.eps, mod.gamma, mod.beta)
        return super(ConvReLU2d, cls).from_float(mod)
Example #9
    def traceHook(self, module, input: Union[T, Tuple[T, T]], output: T):
        """
        Pytorch NN module forward hook. Used to intercept NN layer execution order and dependency. This function will be
            called by PyTorch during inference.
        TODO: Add support for transposed convolution
        TODO: Add support fof MaxPool2d properly
        :param module: The PyTorch NN module to be hooked.
        :param input: Input tensor. Can be a tuple if the module has multiple inputs.
        :param output:
        :return: The modified output.
        """
        # Modify the output
        # output[0]: layer id of the layer
        # output[1]: INT8 activation precision

        print("Tracing module {}. Type; {}. {}".format(self.layerID,
                                                       type(module), module))
        # Extract the information: input precision(s), input id
        inputIDs = []
        inputPrecisions = []
        inputChannels = []
        inputGroupsSeenbySource = []
        inputHeights = []
        inputWidths = []
        input0FracBits = None
        outputPrecisionScale = None
        isAfterInput = False
        if isinstance(module, QuantStub) is False:
            if isinstance(input, T):
                nElements = input.numel()
                idx = int(input.view(nElements)[self.ID_IDX].item())
                # print('Input idx: {}'.format(idx))
                inputIDs.append(idx)
                # inputPrecision = input.view(nElements)[self.PRECISION_IDX].item()
                inputPrecision = torch.pow(
                    torch.tensor(0.5, dtype=torch.float),
                    torch.tensor(self.layerList[idx].outputFracBits,
                                 dtype=torch.float))
                # print('Input precision: {}'.format(inputPrecision))
                inputPrecisions.append(inputPrecision)
                inputChannels.append(self.layerList[idx].outputChannels)
                inputGroupsSeenbySource.append(
                    self.layerList[idx].outputCurrentNumGroups)
                inputHeights.append(self.layerList[idx].outputHeight)
                inputWidths.append(self.layerList[idx].outputWidth)
                if (self.layerList[idx].operationType == 'quantstub'):
                    isAfterInput = True
            elif isinstance(input, tuple):
                for tensor in input:
                    nElements = tensor.numel()
                    # print('Input: {}'.format(tensor.view(nElements)[0:2]))
                    idx = int(tensor.view(nElements)[self.ID_IDX].item())
                    # print('Input idx: {}'.format(idx))
                    inputIDs.append(idx)
                    # inputPrecision = tensor.view(nElements)[self.PRECISION_IDX].item()
                    inputPrecision = torch.pow(
                        torch.tensor(0.5, dtype=torch.float),
                        torch.tensor(self.layerList[idx].outputFracBits,
                                     dtype=torch.float))
                    # print('Input precision: {}'.format(inputPrecision))
                    inputPrecisions.append(inputPrecision)
                    inputChannels.append(self.layerList[idx].outputChannels)
                    inputGroupsSeenbySource.append(
                        self.layerList[idx].outputCurrentNumGroups)
                    inputHeights.append(self.layerList[idx].outputHeight)
                    inputWidths.append(self.layerList[idx].outputWidth)

                    if (self.layerList[idx].operationType == 'quantstub'):
                        isAfterInput = True
            else:
                raise TypeError(
                    'The input argument is neither a tensor nor a tuple of tensors'
                )

            input0FracBits = torch.round(torch.log2(
                1.0 / inputPrecisions[0])).view(1)[0].item()
            outputPrecisionScale = inputPrecisions[0]

        # Determine output precision scales
        if hasattr(module, 'activation_post_process'):
            outputPrecisionScale = module.activation_post_process.scale.view(1)
        elif isinstance(module, cm.EltwiseAdd):
            outputPrecisionScale = module.quant.activation_post_process.scale.view(
                1)
        else:
            if isinstance(module, (cm.MaxPool2dRelu, cm.AvgPool2dRelu)):
                # Number of fraction bits = log2(1/scale)
                # Number of integer bits = 8 - number of fraction bits
                # Number of fraction bits new = log2(1/scale_new) = number of fraction bits - 1 = log2(1/scale) - 1
                # log2(scale) + 1 = log2(scale_new)
                # scale_new = 2 * scale
                # iprecision0 = inputPrecisions[0].view(1)[0].item()
                # iprecision1 = inputPrecisions[1].view(1)[0].item()
                # import math
                # if  math.isclose(iprecision0, iprecision1):
                #     outputPrecisionScale *= 2.0
                # else:
                #     outputPrecisionScale = torch.tensor(iprecision1) if iprecision1 > iprecision0 else torch.tensor(iprecision0)
                outputPrecisionScale = module.quant.activation_post_process.scale.view(
                    1)
            else:
                pass
        outputFracBits = int(
            torch.round(torch.log2(1.0 /
                                   outputPrecisionScale)).view(1)[0].item())

        # Instantiate and insert a layer, register input/output adjacencies
        # If this is a convolution-based layer. Even the fused layer types are Conv2d thanks to inheritance
        newLayer = None
        outputChannels = output.size()[1]
        outputRelu = False

        # For the list that convolution layer-like types after qat is applied,
        # see
        # https://github.com/pytorch/pytorch/blob/20ac7362009dd8e0aca6e72fc9357773136a83b8/torch/quantization/quantization_mappings.py#L54
        if isinstance(
                module,
            (nnqat.Linear, nnqat.Conv2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d,
             nniqat.ConvReLU2d, nniqat.LinearReLU, cqat_modules.Linear,
             cqat_modules.Conv2d, cqat_modules.ConvBn2d,
             cqat_modules.ConvBnReLU2d, cqat_modules.ConvReLU2d,
             cqat_modules.LinearReLU)):
            if isinstance(module,
                          (nniqat.ConvReLU2d, nniqat.LinearReLU,
                           nniqat.ConvBnReLU2d, cqat_modules.ConvBnReLU2d,
                           cqat_modules.ConvReLU2d, cqat_modules.LinearReLU)):
                outputRelu = True

            # Determine padding and kernel size
            padding = 0
            kernelSize = inputWidths[0]
            kernelStride = kernelSize
            groups = 1
            if isinstance(
                    module,
                (nnqat.Conv2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d,
                 nniqat.ConvReLU2d, cqat_modules.Conv2d, cqat_modules.ConvBn2d,
                 cqat_modules.ConvBnReLU2d, cqat_modules.ConvReLU2d)):
                # Extract the padding for convolution.
                # Assume that the horizontal and vertical paddings are the same
                padding = module.padding[0]
                kernelSize = module.kernel_size[0]
                kernelStride = module.stride[0]
                groups = module.groups

            weight = module.weight
            bias = None
            if module.bias is not None:
                bias = module.bias
            # Perform batchnorm folding and update the quantization parameters if necessary
            if self.foldBN and isinstance(
                    module,
                (nniqat.ConvBn2d, nniqat.ConvBnReLU2d, cqat_modules.ConvBn2d,
                 cqat_modules.ConvBnReLU2d)):
                """
                Assumptions:
                    - This is a fused model
                    - When the fusion occurred, the model was not in eval mode, so its bn parameters are not folded at this moment
                    - Warning: If bn folding has already occurred, the following folding will mess up the weights and bias
                    -   Reason: PyTorch's BN folding function only changes conv weights and biases, but does not affect BN parameters
                """
                # PyTorch v1.5.0
                # weight, bias = fuse_conv_bn_weights(
                #     module.weight, module.bias, module.running_mean, module.running_var,
                #     module.eps, module.gamma, module.beta)
                # PyTorch v1.6.0
                weight, bias = fuse_conv_bn_weights(module.weight, module.bias,
                                                    module.bn.running_mean,
                                                    module.bn.running_var,
                                                    module.bn.eps,
                                                    module.bn.weight,
                                                    module.bn.bias)

            # Determine weight frac bits
            # The quantization observer for weight of fused conv_bn already monitors the folded weights
            # see: https://github.com/pytorch/pytorch/blob/7f73f1d591afba823daa4a99a939217fb54d7688/torch/nn/intrinsic/qat/modules/conv_fused.py#L113
            weight_quantizer = module.weight_fake_quant
            # weight_post_process.enable_observer()
            # weight_post_process(weight)
            # weight_post_process.disable_observer()
            # weight_quantizer = custom_quant.RoundedMinMaxObserver()
            # weight_quantizer.forward(weight)
            weightPrecisionScale, _ = weight_quantizer.calculate_qparams()
            weightPrecisionScale = weightPrecisionScale.view(1)
            weightFracBits = int(
                torch.ceil(torch.log2(1.0 /
                                      weightPrecisionScale)).view(1)[0].item())
            # Just use the quantizer to quantize the weights before exporting
            weight = weight_quantizer.forward(weight)

            hasBias = False if bias is None else True

            # Quantize bias
            # if hasBias:
            #     bias = cqat_modules.quantize_bias(module, bias)

            # Determine the SpW parameters
            flagFoundSpWInfo = False
            pruneRangeInCluster = None
            pruneCluster = None
            sparsity = None
            for _, hook in module._forward_pre_hooks.items():
                if isinstance(hook, cm_prune.balancedPruningMethod):
                    flagFoundSpWInfo = True
                    pruneRangeInCluster = hook.pruneRangeInCluster
                    sparsity = hook.sparsity
                    pruneCluster = hook.clusterSize

            if not flagFoundSpWInfo:
                print(
                    "Using default sparsity pruning parameters for module {}".
                    format(module))
                pruneRangeInCluster = self.defaultPruneRangeInCluster
                sparsity = 0.0
                pruneCluster = self.defaultPruneCluster

            needToPermuteWeight = True
            if isinstance(
                    module,
                (nnqat.Linear, nniqat.LinearReLU, cqat_modules.Linear)):
                needToPermuteWeight = False

            newLayer = ConvInfo(
                outputFracBits=int(outputFracBits),
                outputChannels=outputChannels,
                outputRelu=outputRelu,
                inputFracBits=int(input0FracBits),
                inputHeight=inputHeights[0],
                inputWidth=inputWidths[0],
                inputBorderPadding=padding,
                inputTransConvPadding=0,
                inputChannels=inputChannels[0],
                inputGroupsSeenBySource=inputGroupsSeenbySource[0],
                weightFracBits=int(weightFracBits),
                kernelSize=kernelSize,
                kernelStride=kernelStride,
                hasBias=hasBias,
                pruneRangeInCluster=pruneRangeInCluster,
                pruneClusterSize=pruneCluster,
                sparsity=sparsity,
                channelGroups=groups,
                layerID=self.layerID,
                isAfterInput=isAfterInput,
                needToPermuteWeight=needToPermuteWeight)

            # Extract parameters
            # Extract weights
            newLayer.weightParameterFileStartPosition = self.parameterCount
            self.parameterCount += weight.numel()
            self.parameters.append(weight)
            self.parameterKeys.append(str(self.layerID) + '_weight')

            newLayer.biasParameterFileStartPosition = self.parameterCount
            self.parameterCount += outputChannels
            self.parameterKeys.append(str(self.layerID) + '_bias')
            if hasBias is False:
                bias = torch.zeros([outputChannels])

            self.parameters.append(bias)

        # if this is an average pooling layer
        elif isinstance(module, cm.AvgPool2dRelu):
            # An average pooling layer should be seen as a special case of a depth-wise convolution layer
            padding: int = 0
            if hasattr(module.padding, '__getitem__'):
                padding = module.padding[0]
            else:
                padding = module.padding
            newLayer = AvgPoolInfo(
                outputFracBits=outputFracBits,
                outputRelu=False,
                inputFracBits=int(input0FracBits),
                inputHeight=inputHeights[0],
                inputWidth=inputWidths[0],
                inputBorderPadding=padding,
                inputChannels=inputChannels[0],
                inputGroupsSeenBySource=inputGroupsSeenbySource[0],
                kernelSize=module.kernel_size,
                kernelStride=module.stride,
                divisor=module.divisor_override,
                layerID=self.layerID)
        elif isinstance(module, cm.MaxPool2dRelu):
            padding: int = 0
            if hasattr(module.padding, '__getitem__'):
                padding = module.padding[0]
            else:
                padding = module.padding
            newLayer = MaxPoolInfo(
                outputFracBits=outputFracBits,
                outputRelu=module.relu,
                inputFracBits=int(input0FracBits),
                inputHeight=inputHeights[0],
                inputWidth=inputWidths[0],
                inputBorderPadding=padding,
                inputChannels=inputChannels[0],
                inputGroupsSeenBySource=inputGroupsSeenbySource[0],
                kernelSize=module.kernel_size,
                kernelStride=module.stride,
                layerID=self.layerID)
        elif isinstance(module, cm.EltwiseAdd):
            assert inputHeights[0] == inputHeights[
                1], "Input heights do not match for eltwise-add operation"
            assert inputWidths[0] == inputWidths[
                1], "Input widths do not match for eltwise-add operation"
            assert inputChannels[0] == inputChannels[
                1], "Input channels do not match for eltwise-add operation"
            input1FracBits = int(
                torch.round(torch.log2(1.0 /
                                       inputPrecisions[1])).view(1)[0].item())
            newLayer = EltAddInfo(
                outputFracBits=outputFracBits,
                outputRelu=module.relu,
                inputHeight=inputHeights[0],
                inputWidth=inputWidths[0],
                inputChannels=inputChannels[0],
                inputLeftFracBits=int(input0FracBits),
                inputRightFracBits=int(input1FracBits),
                inputLeftGroupsSeenBySource=int(inputGroupsSeenbySource[0]),
                inputRightGroupsSeenBySource=int(inputGroupsSeenbySource[1]),
                layerID=self.layerID)
        elif isinstance(module, QuantStub):
            newLayer = QuantStubInfo(outputFracBits=outputFracBits,
                                     outputChannels=input[0].size()[1],
                                     outputHeight=input[0].size()[2],
                                     outputWidth=input[0].size()[3],
                                     layerID=self.layerID)
        elif isinstance(module, DeQuantStub):
            newLayer = DeQuantStubInfo(
                inputFracBits=int(input0FracBits),
                inputChannels=input[0].size()[1],
                inputGroupsSeenBySource=inputGroupsSeenbySource[0],
                inputHeight=input[0].size()[2]
                if len(input[0].size()) > 2 else 1,
                inputWidth=input[0].size()[3]
                if len(input[0].size()) > 2 else 1,
                layerID=self.layerID)
        else:
            raise TypeError(
                'The tracer hook can only be applied to QuantStub, DeQuantStub, Conv2d types, Linear types, '
                'AvgPool2d types, MaxPool2d types, and EltwiseAdd types. Input module type is {}'.
                format(type(module)))

        # Insert the layer and save the connections
        self.insertLayer(newLayer)
        if isinstance(module, QuantStub) is False:
            for idx in inputIDs:
                self.addForwardEdge(sourceLayerId=idx,
                                    sinkLayerId=self.layerID)
                self.addBackward(sourceLayerId=idx, sinkLayerId=self.layerID)

        # Propagate the layer id, and output precision to the next layer
        outputNumel = output.numel()
        output.view(outputNumel)[self.ID_IDX] = self.layerID

        self.layerID += 1
        self.layerCount += 1
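
The hook above converts observer scales into fixed-point fraction-bit counts. The relationship it relies on is simply scale = 2 ** -frac_bits, so a scale maps back to a bit count as in this small sketch (function names are illustrative):

import math

def frac_bits_from_scale(scale: float) -> int:
    # scale = 2 ** -frac_bits  =>  frac_bits = log2(1 / scale)
    return int(round(math.log2(1.0 / scale)))

def scale_from_frac_bits(frac_bits: int) -> float:
    return 2.0 ** -frac_bits

assert frac_bits_from_scale(0.125) == 3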