예제 #1
0
 def __init__(self):
     """Create the conv/ReLU submodules and the quantization stubs.

     Only construction happens here; the order in which the submodules
     are applied is decided by ``forward``, which is outside this block.
     """
     # Zero-argument super() resolves the same parent as the explicit
     # two-argument form, but keeps working if the class is renamed.
     super().__init__()
     # Quantization config attached directly to this submodule.
     self.qconfig = default_qconfig
     # 3 input channels, 3 output channels, 3x3 kernel, no bias term;
     # parameters are explicitly cast to torch.float.
     self.mod1 = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
     self.mod2 = nn.ReLU()
     self.quant = QuantStub()
     self.dequant = DeQuantStub()
예제 #2
0
 def __init__(self):
     """Register three activation layers followed by quant/dequant stubs.

     Submodules are created and attached one at a time, in the order
     listed below, so the module's child ordering is
     sigmoid, hardsigmoid, tanh, quant, dequant.
     """
     super().__init__()
     submodule_factories = (
         ("sigmoid", torch.nn.Sigmoid),
         ("hardsigmoid", torch.nn.Hardsigmoid),
         ("tanh", torch.nn.Tanh),
         ("quant", QuantStub),
         ("dequant", DeQuantStub),
     )
     # setattr goes through the same attribute-assignment path as
     # ``self.name = module``, so each layer is registered as a child.
     for attr_name, factory in submodule_factories:
         setattr(self, attr_name, factory())
예제 #3
0
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Build the MobileNet V3 model and attach quantization stubs.

        Args:
           *args: Positional arguments, forwarded unchanged to the parent
               constructor (the floating point MobileNetV3).
           **kwargs: Keyword arguments, forwarded unchanged to the parent
               constructor.
        """
        super().__init__(*args, **kwargs)
        # Right-hand side is evaluated left to right and assigned left to
        # right, so quant is created and registered before dequant.
        self.quant, self.dequant = QuantStub(), DeQuantStub()
예제 #4
0
 def __init__(self):
     """Create six FloatFunctional instances plus quant/dequant stubs.

     The attribute names suggest the operation each instance is meant
     for (cat, add, mul, add_relu, scalar add, scalar mul); the
     ``forward`` method that actually uses them is not part of this
     block.
     """
     # Zero-argument super() resolves the same parent as the explicit
     # two-argument form, but keeps working if the class is renamed.
     super().__init__()
     # A separate FloatFunctional instance is created for every attribute
     # rather than sharing one object between them.
     self.mycat = nnq.FloatFunctional()
     self.myadd = nnq.FloatFunctional()
     self.mymul = nnq.FloatFunctional()
     self.myadd_relu = nnq.FloatFunctional()
     self.my_scalar_add = nnq.FloatFunctional()
     self.my_scalar_mul = nnq.FloatFunctional()
     self.quant = QuantStub()
     self.dequant = DeQuantStub()
 def __init__(self, w, b, m, v):
     """Create a BatchNorm3d + ReLU pair with caller-supplied statistics.

     Args:
         w: Value wrapped in ``torch.nn.Parameter`` and installed as the
             batch-norm weight.
         b: Value wrapped in ``torch.nn.Parameter`` and installed as the
             batch-norm bias.
         m: Assigned directly to the batch-norm ``running_mean``.
         v: Assigned directly to the batch-norm ``running_var``.

     NOTE(review): all four are presumably tensors with 4 elements to
     match ``BatchNorm3d(4)`` -- confirm against the callers.
     """
     # Zero-argument super() resolves the same parent as the explicit
     # two-argument form, but keeps working if the class is renamed.
     super().__init__()
     self.bn = torch.nn.BatchNorm3d(4)
     self.relu = torch.nn.ReLU()
     # Overwrite the freshly initialised affine parameters and running
     # statistics with the values passed in.
     self.bn.weight = torch.nn.Parameter(w)
     self.bn.bias = torch.nn.Parameter(b)
     self.bn.running_mean = m
     self.bn.running_var = v
     self.q = QuantStub()
     self.dq = DeQuantStub()
예제 #6
0
 def test_linear_bn_workflow(self):
     """Run fuse -> prepare_qat -> convert on a Linear + BatchNorm1d model.

     After conversion, the module at index 1 must be exactly
     ``nnq.Linear`` and the module at index 2 exactly ``nn.Identity``.
     """
     qengine = torch.backends.quantized.engine
     m = nn.Sequential(
         QuantStub(),
         nn.Linear(4, 4),
         nn.BatchNorm1d(4),
     )
     data = torch.randn(4, 4)
     # QAT qconfig is chosen for whichever quantized engine is active.
     m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
     # '1' and '2' are the Sequential child names of the Linear and the
     # BatchNorm1d layers defined above.
     m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
     mp = prepare_qat(m)
     # Single forward pass through the prepared model before converting.
     mp(data)
     mq = convert(mp)
     # Exact-type comparison is kept (isinstance would also accept
     # subclasses). assertIs reports both operands on failure, whereas
     # assertTrue(a == b) only reports "False is not true".
     self.assertIs(type(mq[1]), nnq.Linear)
     self.assertIs(type(mq[2]), nn.Identity)
예제 #7
0
 def __init__(self):
     """Attach an ``Act`` instance and the quant/dequant stubs.

     How the three submodules are chained is decided by ``forward``,
     which is outside this block.
     """
     super().__init__()
     # Right-hand side is evaluated left to right and the targets are
     # assigned left to right, so the children are registered in the
     # order act, quant, dequant.
     self.act, self.quant, self.dequant = Act(), QuantStub(), DeQuantStub()