def int_weight(self, x: torch.Tensor):
    """Return the integer-grid representation of the quantized weight ``x``.

    Steps:
    1. Pass ``x`` through ``self.tensor_quant``, together with the attribute
       named by ``ZERO_HW_SENTINEL_NAME``.
    2. Divide the quantized result by the scale that ``tensor_quant`` returns.
    3. Round with ``round_ste``.
    4. Cast with ``.int()``, which in PyTorch means ``torch.int32``.

    Args:
        x: weight tensor to quantize.

    Returns:
        The rounded, scale-normalized quantized weight as an int32 tensor.
    """
    sentinel = getattr(self, ZERO_HW_SENTINEL_NAME)
    # tensor_quant yields a 3-tuple; only the quantized value and its scale are used here.
    quantized, quant_scale, _unused = self.tensor_quant(x, sentinel)
    # Presumably the quantized values are multiples of quant_scale, so dividing lands near
    # integers and rounding absorbs float error — confirm against tensor_quant.
    return round_ste(quantized / quant_scale).int()
 def __int__(self):
     """Return ``self.tensor / self.scale`` passed through ``round_ste``.

     The result is whatever ``round_ste`` produces (presumably a tensor of
     integer-valued entries), not a Python ``int``.

     NOTE(review): if ``round_ste`` returns a tensor, as it appears to, the
     builtin ``int(obj)`` raises ``TypeError``, because Python requires
     ``__int__`` to return an ``int``. This method then only works through an
     explicit ``obj.__int__()`` call; confirm that this is the intended use.
     """
     unscaled = self.tensor / self.scale
     return round_ste(unscaled)
# Example No. 3 (original listing header: "Ejemplo n.º 3")
# 0
 def forward(self, x: torch.Tensor):
     """Apply ``round_ste`` to ``x`` and return the result unchanged.

     Going by its name, ``round_ste`` presumably rounds in the forward pass
     and passes gradients straight through. Confirm against its definition.
     """
     rounded = round_ste(x)
     return rounded