def forward(self, input):
    """Run the shift-based 2D convolution on ``input``.

    Steps, in order:

    1. Clamp ``self.shift.data`` in place to ``self.shift_range``.
    2. Round ``self.shift`` and ``self.sign`` with ``self.rounding`` and
       combine them into a weight ``self.shift_base ** shift * sign`` via
       ``ste.unsym_grad_mul``.
    3. Pass the input (and the bias, when present) through
       ``ste.round_fixed_point``.
    4. For ``padding_mode == 'circular'`` pad explicitly with ``F.pad`` and
       hand zero padding to the convolution; otherwise forward
       ``self.padding`` unchanged.
    5. Dispatch to ``Conv2dShiftFunction`` when ``self.use_kernel`` is set,
       else to ``torch.nn.functional.conv2d``.

    NOTE(review): ``ste`` presumably holds straight-through-estimator
    helpers; their exact gradient behaviour is not visible here.

    Args:
        input: Tensor fed to the convolution.

    Returns:
        The result of the selected convolution implementation.
    """
    # Keep the stored exponents inside the permitted range before use.
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

    exponents = ste.round(self.shift, self.rounding)
    signs = ste.sign(ste.round(self.sign, self.rounding))
    # Computed unconditionally; only the non-kernel path consumes it.
    weight_ps = ste.unsym_grad_mul(self.shift_base**exponents, signs)

    quantized_input = ste.round_fixed_point(input)
    quantized_bias = (
        None if self.bias is None else ste.round_fixed_point(self.bias)
    )

    if self.padding_mode != 'circular':
        padded_input = quantized_input
        conv_padding = self.padding
    else:
        # F.pad takes the last dimension first, hence padding[1] before
        # padding[0].
        circular_pad = (
            (self.padding[1] + 1) // 2,
            self.padding[1] // 2,
            (self.padding[0] + 1) // 2,
            self.padding[0] // 2,
        )
        padded_input = F.pad(quantized_input, circular_pad, mode='circular')
        conv_padding = _pair(0)

    if not self.use_kernel:
        return torch.nn.functional.conv2d(
            padded_input, weight_ps, quantized_bias, self.stride,
            conv_padding, self.dilation, self.groups)

    return Conv2dShiftFunction.apply(
        padded_input, self.shift, self.sign, self.shift_base,
        quantized_bias, self.conc_weight, self.stride, conv_padding,
        self.dilation, self.groups, self.use_kernel, self.use_cuda,
        self.rounding, self.shift_range)
def forward(self, input):
    """Run the shift-based 2D convolution with optional quantization.

    Steps, in order:

    1. Clamp ``self.shift.data`` in place to ``self.shift_range``.
    2. Round ``self.shift`` with ``self.rounding``. Round ``self.sign``
       with ``ste.round`` when ``self.threshold`` is ``None``, otherwise
       with ``ste.myround(self.sign, self.threshold)``; take the sign of
       the result.
    3. Build the weight ``2 ** shift * sign`` via ``ste.unsym_grad_mul``.
    4. Pass the input through ``ste.round_fixed_point`` with
       ``quant_bits=self.quant_bits`` and the bias (when present) through
       ``ste.round_fixed_point`` with its default arguments.
    5. For ``padding_mode == 'circular'`` pad explicitly with ``F.pad`` and
       hand zero padding to the convolution.
    6. Dispatch to ``Conv2dShiftFunction`` when ``self.use_kernel`` is set,
       else to ``torch.nn.functional.conv2d``.
    7. When ``self.quant_bits > 0`` wrap the output with ``quantize_grad``.

    Args:
        input: Tensor fed to the convolution.

    Returns:
        The convolution output, passed through ``quantize_grad`` when
        ``self.quant_bits > 0``.
    """
    # Keep the stored exponents inside the permitted range before use.
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
    shift_rounded = ste.round(self.shift, self.rounding)

    if self.threshold is None:
        sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
    else:
        sign_rounded_signed = ste.sign(
            ste.myround(self.sign, self.threshold))

    # Computed unconditionally; only the non-kernel path consumes it.
    weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)

    # The input is rounded the same way whatever the sign of quant_bits;
    # interpreting quant_bits is left to round_fixed_point itself.
    input_fixed_point = ste.round_fixed_point(
        input, quant_bits=self.quant_bits)

    # NOTE(review): the bias is rounded without quant_bits, unlike the
    # input - confirm this asymmetry is intended.
    if self.bias is not None:
        bias_fixed_point = ste.round_fixed_point(self.bias)
    else:
        bias_fixed_point = None

    if self.padding_mode == 'circular':
        # F.pad takes the last dimension first, hence padding[1] before
        # padding[0].
        expanded_padding = ((self.padding[1] + 1) // 2,
                            self.padding[1] // 2,
                            (self.padding[0] + 1) // 2,
                            self.padding[0] // 2)
        input_padded = F.pad(input_fixed_point, expanded_padding,
                             mode='circular')
        padding = _pair(0)
    else:
        input_padded = input_fixed_point
        padding = self.padding

    if self.use_kernel:
        output = Conv2dShiftFunction.apply(
            input_padded, self.shift, self.sign, bias_fixed_point,
            self.conc_weight, self.stride, padding, self.dilation,
            self.groups, self.use_kernel, self.use_cuda, self.rounding,
            self.shift_range)
    else:
        output = torch.nn.functional.conv2d(
            input_padded, weight_ps, bias_fixed_point, self.stride,
            padding, self.dilation, self.groups)

    # Quantize backpropagation.
    if self.quant_bits > 0:
        output = quantize_grad(output, num_bits=self.quant_bits,
                               flatten_dims=(1, -1))
    return output
def forward(self, input):
    """Apply the shift-based linear transform to ``input``.

    ``self.shift.data`` is first clamped in place to ``self.shift_range``.
    The rounded shift and the sign of the rounded ``self.sign`` are then
    combined into a weight ``self.shift_base ** shift * sign`` through
    ``ste.unsym_grad_mul``.

    Args:
        input: Tensor fed to the linear layer.

    Returns:
        ``LinearShiftFunction.apply(...)`` when ``self.use_kernel`` is
        set, otherwise ``torch.nn.functional.linear`` evaluated with the
        constructed weight and ``self.bias``.
    """
    # Keep the stored exponents inside the permitted range before use.
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

    exponents = ste.round(self.shift, rounding=self.rounding)
    signs = ste.sign(ste.round(self.sign, rounding=self.rounding))
    # Computed unconditionally; only the non-kernel path consumes it.
    weight_ps = ste.unsym_grad_mul(self.shift_base**exponents, signs)

    if not self.use_kernel:
        return torch.nn.functional.linear(input, weight_ps, self.bias)

    return LinearShiftFunction.apply(
        input, self.shift, self.sign, self.bias, self.shift_base,
        self.conc_weight, self.use_kernel, self.use_cuda, self.rounding,
        self.shift_range)
def reset_parameters(self):
    """Initialise the weight, shift exponent ``p``, sign ``s`` and bias.

    ``self.distribution == 'kaiming_normal'``:
        the weight gets Kaiming-normal values (``fan_out``, ``relu``),
        ``self.set_mask()`` is called, ``s`` is drawn uniformly from
        ``[-1, 1]`` and the weight is multiplied in place by the absolute
        value of the thresholded sign.

    Any other value:
        ``p`` is drawn from a uniform (``'uniform'``) or normal
        (``'normal'``) distribution, or left untouched for other values,
        then clamped to ``self.shift_range`` and rounded
        deterministically. ``s`` is drawn uniformly from ``[-1, 1]`` and
        the weight is overwritten with ``sign * 2 ** p``.

    When a bias exists it is drawn uniformly from
    ``[-1 / sqrt(fan_in), 1 / sqrt(fan_in)]`` and ``'use bias'`` is
    printed.

    NOTE(review): ``round`` is called with a threshold as second argument,
    so it is presumably a project helper shadowing the builtin - confirm.
    """
    if self.distribution == 'kaiming_normal':
        init.kaiming_normal_(
            self.weight, mode='fan_out', nonlinearity='relu')
        self.set_mask()

        # Quantize: scale the weights by |sign|.
        self.s.data.uniform_(-1, 1)
        signs = ste.sign(round(self.s, self.sign_threshold))
        self.weight.data *= abs(signs)
    else:
        if self.distribution == 'uniform':
            self.p.data.uniform_(-self.min_p - 0.5, -1 + 0.5)
        elif self.distribution == 'normal':
            self.p.data.normal_(-self.min_p / 2, 1)

        clamped = ste.clamp(self.p.data, *self.shift_range)
        self.p.data = clamped
        self.p.data = ste.round(self.p.data, 'deterministic')

        self.s.data.uniform_(-1, 1)
        signs = ste.sign(round(self.s, self.sign_threshold))
        # The weight is fully determined by the sign and the exponent.
        self.weight.data = signs * (2**self.p.data)

    if self.bias is None:
        return

    print('use bias')
    fan_in = init._calculate_fan_in_and_fan_out(self.weight)[0]
    limit = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -limit, limit)
def forward(self, input):
    """Apply the power-of-two shift linear transform to ``input``.

    ``self.shift.data`` is first clamped in place to ``self.shift_range``.
    The rounded shift and the sign of the rounded ``self.sign`` are then
    combined into a weight ``2 ** shift * sign`` through
    ``ste.unsym_grad_mul``.

    Args:
        input: Tensor fed to the linear layer.

    Returns:
        ``LinearShiftFunction.apply(...)`` (which also receives
        ``self.act_integer_bits`` and ``self.act_fraction_bits``) when
        ``self.use_kernel`` is set, otherwise
        ``torch.nn.functional.linear`` evaluated with the constructed
        weight and ``self.bias``.
    """
    # Keep the stored exponents inside the permitted range before use.
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

    exponents = ste.round(self.shift, rounding=self.rounding)
    signs = ste.sign(ste.round(self.sign, rounding=self.rounding))
    # Computed unconditionally; only the non-kernel path consumes it.
    weight_ps = ste.unsym_grad_mul(2**exponents, signs)

    # TODO: round bias and input to fixed point
    if not self.use_kernel:
        return torch.nn.functional.linear(input, weight_ps, self.bias)

    return LinearShiftFunction.apply(
        input, self.shift, self.sign, self.bias, self.conc_weight,
        self.use_kernel, self.use_cuda, self.rounding, self.shift_range,
        self.act_integer_bits, self.act_fraction_bits)
def forward(self, input): shift_rounded = ste.round(self.shift) sign_rounded_signed = ste.sign(ste.round(self.sign)) weight_ps = ste.unsym_grad_mul( 2**shift_rounded, sign_rounded_signed ) # 2**utils.stochastic_rounding(shift) * sign.round().sign() input_fixed_point = ste.round_fixed_point(input) if self.bias is not None: bias_fixed_point = ste.round_fixed_point(self.bias) else: bias_fixed_point = None if self.padding_mode == 'circular': expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, (self.padding[0] + 1) // 2, self.padding[0] // 2) input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular') padding = _pair(0) else: input_padded = input_fixed_point padding = self.padding if self.use_kernel: return Conv2dShiftFunction.apply(input_padded, self.shift, self.sign, bias_fixed_point, self.conc_weight, self.stride, padding, self.dilation, self.groups, self.use_kernel, self.use_cuda) else: return torch.nn.functional.conv2d(input_padded, weight_ps, bias_fixed_point, self.stride, padding, self.dilation, self.groups) '''