def forward(self, input):
    """Apply the shift-quantized 2D convolution to ``input``.

    The weight is passed through ``ste.round_power_of_2`` and the input
    and the optional bias through ``ste.round_fixed_point``. The
    convolution is then evaluated by ``Conv2dShiftQFunction`` when
    ``self.use_kernel`` is truthy, otherwise by
    ``torch.nn.functional.conv2d``.

    Args:
        input: Tensor handed to the convolution.

    Returns:
        The result of whichever convolution backend was selected.
    """
    quantized_weight = ste.round_power_of_2(self.weight)
    activations = ste.round_fixed_point(input)
    quantized_bias = (
        ste.round_fixed_point(self.bias) if self.bias is not None else None
    )

    conv_padding = self.padding
    if self.padding_mode == 'circular':
        # Circular padding is applied explicitly here, so the convolution
        # itself must run with zero padding.
        pad_first, pad_second = self.padding[0], self.padding[1]
        wrap_amounts = (
            (pad_second + 1) // 2,
            pad_second // 2,
            (pad_first + 1) // 2,
            pad_first // 2,
        )
        activations = F.pad(activations, wrap_amounts, mode='circular')
        conv_padding = _pair(0)

    if not self.use_kernel:
        return torch.nn.functional.conv2d(
            activations,
            quantized_weight,
            quantized_bias,
            self.stride,
            conv_padding,
            self.dilation,
            self.groups,
        )

    return Conv2dShiftQFunction.apply(
        activations,
        quantized_weight,
        quantized_bias,
        self.conc_weight,
        self.stride,
        conv_padding,
        self.dilation,
        self.groups,
        self.use_kernel,
        self.use_cuda,
    )
def forward(self, input):
    """Apply the 2D convolution in either 'shift' or 'full' mode.

    In ``'shift'`` mode the stored weight data is first overwritten with
    the result of ``ste.clampabs`` (bounds ``2**shift_range[0]`` and
    ``2**shift_range[1]``), the weight is passed through
    ``ste.round_power_of_2`` and the bias through
    ``ste.round_fixed_point``. In ``'full'`` mode weight and bias are used
    as they are. The input goes through ``ste.round_fixed_point`` in both
    modes.

    Args:
        input: Tensor handed to the convolution.

    Returns:
        The output of ``Conv2dShiftQFunction`` when ``self.use_kernel`` is
        truthy, otherwise of ``torch.nn.functional.conv2d``.

    Raises:
        NotImplementedError: If ``self.shift_or_full`` is neither
            ``'shift'`` nor ``'full'``.
    """
    mode = self.shift_or_full
    if mode == 'shift':
        lower_bound = 2 ** self.shift_range[0]
        upper_bound = 2 ** self.shift_range[1]
        self.weight.data = ste.clampabs(
            self.weight.data, lower_bound, upper_bound
        )
        conv_weight = ste.round_power_of_2(self.weight, self.rounding)
        round_bias = True
    elif mode == 'full':
        # Full-precision training: weight and bias are left untouched.
        conv_weight = self.weight
        round_bias = False
    else:
        raise NotImplementedError('No such type!')

    activations = ste.round_fixed_point(input)

    conv_bias = self.bias
    if round_bias and conv_bias is not None:
        conv_bias = ste.round_fixed_point(conv_bias)

    conv_padding = self.padding
    if self.padding_mode == 'circular':
        # Circular padding is applied explicitly here, so the convolution
        # itself must run with zero padding.
        pad_first, pad_second = self.padding[0], self.padding[1]
        wrap_amounts = (
            (pad_second + 1) // 2,
            pad_second // 2,
            (pad_first + 1) // 2,
            pad_first // 2,
        )
        activations = F.pad(activations, wrap_amounts, mode='circular')
        conv_padding = _pair(0)

    if not self.use_kernel:
        return torch.nn.functional.conv2d(
            activations,
            conv_weight,
            conv_bias,
            self.stride,
            conv_padding,
            self.dilation,
            self.groups,
        )

    return Conv2dShiftQFunction.apply(
        activations,
        conv_weight,
        conv_bias,
        self.conc_weight,
        self.stride,
        conv_padding,
        self.dilation,
        self.groups,
        self.use_kernel,
        self.use_cuda,
    )
def forward(self, input):
    """Apply the shift-quantized linear transform to ``input``.

    The weight is passed through ``ste.round_power_of_2`` and the input
    and the optional bias through ``ste.round_fixed_point``. The product
    is then computed by ``LinearShiftQFunction`` when ``self.use_kernel``
    is truthy, otherwise by a plain matrix multiply.

    Args:
        input: Tensor to transform. The non-kernel path uses ``Tensor.mm``
            and therefore needs a 2-D tensor.

    Returns:
        The transformed tensor, with the rounded bias added when a bias is
        present.
    """
    weight_q = ste.round_power_of_2(self.weight)
    input_fixed_point = ste.round_fixed_point(input)
    if self.bias is not None:
        bias_fixed_point = ste.round_fixed_point(self.bias)
    else:
        bias_fixed_point = None

    if self.use_kernel:
        return LinearShiftQFunction.apply(
            input_fixed_point,
            weight_q,
            bias_fixed_point,
            self.conc_weight,
            self.use_kernel,
            self.use_cuda,
        )

    out = input_fixed_point.mm(weight_q.t())
    if bias_fixed_point is not None:
        # Add the rounded bias, not the raw parameter, so this path uses
        # the same bias value that is handed to the kernel path.
        out += bias_fixed_point.unsqueeze(0).expand_as(out)
    return out
def forward(self, input):
    """Apply the shift-quantized linear transform to ``input``.

    The stored weight data is first overwritten with the result of
    ``ste.clampabs`` (bounds ``base**shift_range[0]`` and
    ``base**shift_range[1]``). The weight is then passed through
    ``ste.round_power_of_2`` and the input and the optional bias through
    ``ste.round_fixed_point``. The product is computed by
    ``LinearShiftQFunction`` when ``self.use_kernel`` is truthy, otherwise
    by a plain matrix multiply.

    Args:
        input: Tensor to transform. The non-kernel path uses ``Tensor.mm``
            and therefore needs a 2-D tensor.

    Returns:
        The transformed tensor, with the rounded bias added when a bias is
        present.
    """
    lower_bound = self.base ** self.shift_range[0]
    upper_bound = self.base ** self.shift_range[1]
    self.weight.data = ste.clampabs(self.weight.data, lower_bound, upper_bound)

    weight_q = ste.round_power_of_2(self.weight, self.base, self.rounding)
    input_fixed_point = ste.round_fixed_point(input)
    if self.bias is not None:
        bias_fixed_point = ste.round_fixed_point(self.bias)
    else:
        bias_fixed_point = None

    if self.use_kernel:
        return LinearShiftQFunction.apply(
            input_fixed_point,
            weight_q,
            self.base,
            bias_fixed_point,
            self.conc_weight,
            self.use_kernel,
            self.use_cuda,
        )

    out = input_fixed_point.mm(weight_q.t())
    if bias_fixed_point is not None:
        # Add the rounded bias, not the raw parameter, so this path uses
        # the same bias value that is handed to the kernel path.
        out += bias_fixed_point.unsqueeze(0).expand_as(out)
    return out