def __init__(self, quantized=False):
    super(ModForWrapping, self).__init__()
    self.qconfig = default_qconfig
    if quantized:
        self.mycat = nnq.QFunctional()
        self.myadd = nnq.QFunctional()
    else:
        self.mycat = nnq.FloatFunctional()
        self.myadd = nnq.FloatFunctional()
        # Observers are only attached in the float path: FloatFunctional
        # records activation statistics during calibration, while
        # QFunctional already carries fixed quantization parameters.
        self.mycat.observer = DummyObserver()
        self.myadd.observer = DummyObserver()
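# The wrapped functionals replace free-function torch.cat/torch.add calls so
# the quantization workflow can observe and swap them. A minimal sketch of a
# matching forward, assuming both operands share a shape; the body is an
# illustration, not the test's verbatim definition:
def forward(self, x, y):
    # FloatFunctional and QFunctional expose the same cat/add methods, so
    # this forward works in both the float and the quantized configuration.
    z = self.mycat.cat([x, y], dim=0)
    return self.myadd.add(z, z)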
def init(self, M, N, K, L, dim, contig, dtype):
    f_input = (torch.rand(M, N, K) - 0.5) * 256
    self.qf = nnq.QFunctional()
    scale = 1.0
    zero_point = 0
    self.qf.scale = scale
    self.qf.zero_point = zero_point
    assert contig in ('none', 'one', 'all')
    q_input = torch.quantize_per_tensor(f_input, scale, zero_point, dtype)
    # Build a non-contiguous twin of q_input: reverse the dims, force a
    # contiguous copy in the reversed layout, then reverse back so only the
    # strides (not the logical shape or values) differ from q_input.
    permute_dims = tuple(range(q_input.ndim - 1, -1, -1))
    q_input_non_contig = q_input.permute(permute_dims).contiguous()
    q_input_non_contig = q_input_non_contig.permute(permute_dims)
    if contig == 'all':
        self.input = (q_input, q_input)
    elif contig == 'one':
        self.input = (q_input, q_input_non_contig)
    elif contig == 'none':
        self.input = (q_input_non_contig, q_input_non_contig)
    self.inputs = {
        "input": self.input,
        "dim": dim,
    }
    self.set_module_name('qcat')
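# A sketch of the op this benchmark measures and of how the operator_benchmark
# harness could drive it. The forward mirrors the keys placed in self.inputs
# above; the class name QCatBenchmark and the config values are assumptions
# for illustration, not the repo's exact definitions.
def forward(self, input, dim: int):
    # QFunctional.cat concatenates the quantized tensors along `dim` and
    # re-quantizes the result with the scale/zero_point assigned in init().
    return self.qf.cat(input, dim=dim)

# At module level, the harness would then generate tests along these lines
# (values are placeholders):
#
#   qcat_configs = op_bench.cross_product_configs(
#       M=[2], N=[4], K=[8], L=[2], dim=[0],
#       contig=['all', 'one', 'none'], dtype=[torch.quint8],
#   )
#   op_bench.generate_pt_test(qcat_configs, QCatBenchmark)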
def __init__(self):
    super(QFunctionalWrapper, self).__init__()
    self.qfunc = nnq.QFunctional()
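# A minimal sketch of how such a wrapper might forward to the quantized
# functional; the choice of `add` and the two-tensor signature are assumed
# here for illustration, since the wrapper's actual op is not shown:
def forward(self, x, y):
    # QFunctional.add performs a fused quantized add, re-quantizing the sum
    # with self.qfunc.scale and self.qfunc.zero_point.
    return self.qfunc.add(x, y)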