def get_test_threshold(input_shape):
    """Build a deterministic, randomly filled per-channel threshold tensor.

    Reseeds torch's global RNG with 0 before sampling, so every call with the
    same ``input_shape`` returns identical values.

    Args:
        input_shape: Shape of the activation input; forwarded to
            ``get_per_channel_scale_shape(..., is_weights=False)`` to derive
            the threshold shape.

    Returns:
        A tensor of the derived threshold shape (default floating dtype)
        holding integer values sampled uniformly from [-10, 9].
    """
    # NOTE: this reseeds the *global* generator, so it also affects any
    # random ops the caller performs afterwards.
    torch.manual_seed(0)
    threshold_shape = get_per_channel_scale_shape(input_shape, is_weights=False)
    # torch.zeros already returns a fresh tensor; the previous legacy
    # torch.Tensor(...) wrapper around it was redundant.
    retval = torch.zeros(threshold_shape)
    # random_(from, to) samples from the discrete uniform range [from, to - 1].
    retval.random_(-10, 10)
    return retval
def __init__(self, input_shape, enabled=False, desc=""):
    """Create zero-initialised scale and per-channel threshold parameters.

    Args:
        input_shape: Shape of the activation input; stored on the module and
            used to derive the per-channel threshold shape.
        enabled: Passed to the parent constructor and used as the
            ``requires_grad`` flag of both learnable parameters.
        desc: Accepted for interface compatibility; not used here.
    """
    super().__init__(enabled)
    self.input_shape = input_shape

    # Single-element scale; torch.Tensor([0]) is already zero-valued.
    initial_scale = torch.Tensor([0])
    self.scale = torch.nn.Parameter(initial_scale, requires_grad=enabled)

    # Registered as a buffer (not a plain attribute) so the initialisation
    # flag is saved in, and restored from, the module state dict.
    self.register_buffer('scale_initialized', torch.IntTensor([0]))

    per_channel_shape = get_per_channel_scale_shape(self.input_shape, is_weights=False)
    self.threshold = torch.nn.Parameter(torch.zeros(per_channel_shape), requires_grad=enabled)

    self.bin = activation_bin_scale_threshold_op
def __init__(self, num_bits=8, input_shape=None, is_weights=True, per_channel=False):
    """Configure a reference symmetric quantizer with a learnable scale.

    Args:
        num_bits: Bit width of the signed integer quantization grid.
        input_shape: Shape of the quantized tensor; only consulted when
            ``per_channel`` is true.
        is_weights: Stored on the module and forwarded to
            ``get_per_channel_scale_shape`` (presumably distinguishes weights
            from activations; confirm against that helper).
        per_channel: If true, the scale shape comes from
            ``get_per_channel_scale_shape``; otherwise a single-element scale
            is used.
    """
    super().__init__()
    self.input_shape = input_shape
    self.is_weights = is_weights

    if per_channel:
        scale_shape = get_per_channel_scale_shape(self.input_shape, self.is_weights)
    else:
        scale_shape = 1
    self.scale = nn.Parameter(torch.ones(scale_shape))

    self.num_bits = num_bits
    # Signed symmetric grid [-2**(b-1), 2**(b-1) - 1], e.g. [-128, 127] for 8 bits.
    half_range = 2 ** (self.num_bits - 1)
    self.level_high = half_range - 1
    self.level_low = -half_range

    self.quantize = ReferenceQuantizeSymmetric.apply