def forward(self, input):
    """Run the shift convolution on ``input``.

    The weight is rebuilt on every call from the learnable ``shift`` and
    ``sign`` parameters. Both are rounded with ``ste.round`` using
    ``self.rounding``, and the result is combined as
    ``shift_base ** shift`` times ``sign(sign)`` through
    ``ste.unsym_grad_mul``. The input and bias are rounded with
    ``ste.round_fixed_point`` before the convolution.

    Side effect: ``self.shift.data`` is clamped in place to
    ``self.shift_range`` on every call.

    Args:
        input: activation tensor fed to the convolution (presumably
            ``(N, C, H, W)``; TODO confirm against callers).

    Returns:
        The convolution result. It comes from ``Conv2dShiftFunction`` when
        ``self.use_kernel`` is truthy, otherwise from
        ``torch.nn.functional.conv2d`` with the reconstructed weight.
    """
    # Keep the learnable exponents inside the configured range.
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

    shift_rounded = ste.round(self.shift, self.rounding)
    sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
    # NOTE(review): unsym_grad_mul presumably multiplies its operands with a
    # custom (asymmetric) gradient rule; confirm in ste.
    weight_ps = ste.unsym_grad_mul(self.shift_base**shift_rounded, sign_rounded_signed)

    input_fixed_point = ste.round_fixed_point(input)
    if self.bias is not None:
        bias_fixed_point = ste.round_fixed_point(self.bias)
    else:
        bias_fixed_point = None

    if self.padding_mode != 'zeros':
        # F.pad expects (left, right, top, bottom), i.e. last dim first.
        # Pad each side by the full per-dimension amount, as nn.Conv2d does,
        # so every padding_mode gives the same output size as zero padding.
        expanded_padding = (self.padding[1], self.padding[1],
                            self.padding[0], self.padding[0])
        input_padded = F.pad(input_fixed_point, expanded_padding,
                             mode=self.padding_mode)
        # Padding was applied explicitly above, so the conv itself uses none.
        padding = _pair(0)
    else:
        input_padded = input_fixed_point
        padding = self.padding

    if self.use_kernel:
        return Conv2dShiftFunction.apply(
            input_padded, self.shift, self.sign, self.shift_base,
            bias_fixed_point, self.conc_weight, self.stride, padding,
            self.dilation, self.groups, self.use_kernel, self.use_cuda,
            self.rounding, self.shift_range)
    else:
        return torch.nn.functional.conv2d(
            input_padded, weight_ps, bias_fixed_point, self.stride,
            padding, self.dilation, self.groups)
def forward(self, input):
    """Run the quantized shift convolution on ``input``.

    The weight is rebuilt on every call as ``2 ** round(shift)`` combined
    with the sign term through ``ste.unsym_grad_mul``:

    - ``shift`` is rounded with ``ste.round`` using ``self.rounding``.
    - The sign term comes from ``ste.round`` (when ``self.threshold`` is
      None) or from ``ste.myround`` with ``self.threshold``.

    The input is rounded with ``ste.round_fixed_point`` at
    ``self.quant_bits``, and the bias with ``ste.round_fixed_point`` at its
    default precision.

    Side effect: ``self.shift.data`` is clamped in place to
    ``self.shift_range`` on every call.

    Args:
        input: activation tensor fed to the convolution (presumably
            ``(N, C, H, W)``; TODO confirm against callers).

    Returns:
        The convolution result, from ``Conv2dShiftFunction`` when
        ``self.use_kernel`` is truthy, otherwise from
        ``torch.nn.functional.conv2d``. When ``self.quant_bits > 0`` the
        output is additionally passed through ``quantize_grad``.
    """
    # Keep the learnable exponents inside the configured range.
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
    shift_rounded = ste.round(self.shift, self.rounding)

    if self.threshold is None:
        sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
    else:
        # A configured threshold replaces the default rounding of the sign.
        sign_rounded_signed = ste.sign(ste.myround(self.sign, self.threshold))
    weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)

    # The input is rounded at the configured precision regardless of the
    # sign of quant_bits (both former branches made this exact call).
    input_fixed_point = ste.round_fixed_point(input, quant_bits=self.quant_bits)
    if self.bias is not None:
        bias_fixed_point = ste.round_fixed_point(self.bias)
    else:
        bias_fixed_point = None

    if self.padding_mode != 'zeros':
        # F.pad expects (left, right, top, bottom), i.e. last dim first.
        # Pad each side by the full per-dimension amount, as nn.Conv2d does,
        # so every padding_mode gives the same output size as zero padding.
        expanded_padding = (self.padding[1], self.padding[1],
                            self.padding[0], self.padding[0])
        input_padded = F.pad(input_fixed_point, expanded_padding,
                             mode=self.padding_mode)
        # Padding was applied explicitly above, so the conv itself uses none.
        padding = _pair(0)
    else:
        input_padded = input_fixed_point
        padding = self.padding

    if self.use_kernel:
        output = Conv2dShiftFunction.apply(
            input_padded, self.shift, self.sign, bias_fixed_point,
            self.conc_weight, self.stride, padding, self.dilation,
            self.groups, self.use_kernel, self.use_cuda, self.rounding,
            self.shift_range)
    else:
        output = torch.nn.functional.conv2d(
            input_padded, weight_ps, bias_fixed_point, self.stride,
            padding, self.dilation, self.groups)

    # Quantize backpropagation (per the helper's name this targets the
    # gradient of ``output``; confirm in quantize_grad).
    if self.quant_bits > 0:
        output = quantize_grad(output, num_bits=self.quant_bits,
                               flatten_dims=(1, -1))
    return output
def forward(self, input):
    """Apply the shift-based linear layer to ``input``.

    The effective weight is ``shift_base ** round(shift)`` combined with
    ``sign(round(sign))`` via ``ste.unsym_grad_mul``. Both roundings use
    ``self.rounding``.

    Side effect: ``self.shift.data`` is clamped in place to
    ``self.shift_range`` on every call.

    Returns:
        ``torch.nn.functional.linear(input, weight, self.bias)`` when
        ``self.use_kernel`` is falsy, otherwise the result of
        ``LinearShiftFunction.apply`` on the raw parameters.
    """
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

    exponent = ste.round(self.shift, rounding=self.rounding)
    polarity = ste.sign(ste.round(self.sign, rounding=self.rounding))
    magnitude = self.shift_base**exponent
    weight = ste.unsym_grad_mul(magnitude, polarity)

    # Plain PyTorch path: use the reconstructed weight directly.
    if not self.use_kernel:
        return torch.nn.functional.linear(input, weight, self.bias)

    # Custom-kernel path: the kernel receives the raw shift/sign parameters.
    kernel_args = (input, self.shift, self.sign, self.bias, self.shift_base,
                   self.conc_weight, self.use_kernel, self.use_cuda,
                   self.rounding, self.shift_range)
    return LinearShiftFunction.apply(*kernel_args)
def forward(self, input):
    """Apply the shift-based linear layer to ``input``.

    The effective weight is ``2 ** round(shift)`` combined with
    ``sign(round(sign))`` via ``ste.unsym_grad_mul``. Both roundings use
    ``self.rounding``. The input and bias are passed through unrounded.

    Side effect: ``self.shift.data`` is clamped in place to
    ``self.shift_range`` on every call.

    Returns:
        ``torch.nn.functional.linear(input, weight, self.bias)`` when
        ``self.use_kernel`` is falsy, otherwise the result of
        ``LinearShiftFunction.apply``. That call also receives the
        activation bit widths.
    """
    self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

    exponent = ste.round(self.shift, rounding=self.rounding)
    polarity = ste.sign(ste.round(self.sign, rounding=self.rounding))
    magnitude = 2**exponent
    weight = ste.unsym_grad_mul(magnitude, polarity)

    # TODO: round bias and input to fixed point
    if not self.use_kernel:
        return torch.nn.functional.linear(input, weight, self.bias)

    # Custom-kernel path: raw parameters plus activation integer/fraction bits.
    kernel_args = (input, self.shift, self.sign, self.bias, self.conc_weight,
                   self.use_kernel, self.use_cuda, self.rounding,
                   self.shift_range, self.act_integer_bits,
                   self.act_fraction_bits)
    return LinearShiftFunction.apply(*kernel_args)
def forward(self, input):
    """Run the shift convolution on ``input``.

    The weight is built as ``2 ** round(shift)`` combined with
    ``sign(round(sign))`` via ``ste.unsym_grad_mul``. The input and bias
    are rounded with ``ste.round_fixed_point``, and circular padding is
    applied when ``self.padding_mode == 'circular'``. The convolution then
    runs through ``Conv2dShiftFunction`` (when ``self.use_kernel`` is
    truthy) or ``torch.nn.functional.conv2d``.

    In this variant ``self.shift`` is not clamped, and ``ste.round`` is
    called without an explicit rounding mode.
    """
    shift_rounded = ste.round(self.shift)
    sign_rounded_signed = ste.sign(ste.round(self.sign))
    weight_ps = ste.unsym_grad_mul(
        2**shift_rounded, sign_rounded_signed
    )  # 2**utils.stochastic_rounding(shift) * sign.round().sign()
    input_fixed_point = ste.round_fixed_point(input)
    if self.bias is not None:
        bias_fixed_point = ste.round_fixed_point(self.bias)
    else:
        bias_fixed_point = None
    if self.padding_mode == 'circular':
        # NOTE(review): this pads (p + 1) // 2 before and p // 2 after each
        # spatial dim. That totals p, half of the 2p that zero padding
        # applies, so the output is smaller than with padding_mode='zeros'.
        # Confirm intent.
        expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                            (self.padding[0] + 1) // 2, self.padding[0] // 2)
        input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular')
        # Padding was applied explicitly above, so the conv itself uses none.
        padding = _pair(0)
    else:
        input_padded = input_fixed_point
        padding = self.padding
    if self.use_kernel:
        return Conv2dShiftFunction.apply(input_padded, self.shift, self.sign, bias_fixed_point, self.conc_weight, self.stride, padding, self.dilation, self.groups, self.use_kernel, self.use_cuda)
    else:
        return torch.nn.functional.conv2d(input_padded, weight_ps, bias_fixed_point, self.stride, padding, self.dilation, self.groups)
    # NOTE(review): the triple quote below opens or closes a string literal
    # whose other end lies outside the visible region; kept verbatim.
    '''