def forward(self, input):
    # See the autograd section for explanation of what happens here.
    # The sign-magnitude bit is excluded from the effective data widths.
    self.rshift_i, self.rshift_w, self.rshift_o = \
        rshift_offset(input, self.weight,
                      self.hwcfg["widthi"] - self.hwcfg["signmag"],
                      self.hwcfg["widthw"] - self.hwcfg["signmag"],
                      self.hwcfg["rounding"],
                      self.hwcfg["quantilei"], self.hwcfg["quantilew"])

    with torch.no_grad():
        # All data are in NCHW layout.
        output_size = conv2d_output_shape(
            (input.size()[2], input.size()[3]),
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            pad=self.padding,
            stride=self.stride)

    # im2col: extract the sliding patches as columns, (N, C*kH*kW, L).
    input_im2col = torch.nn.functional.unfold(
        input, self.kernel_size, self.dilation, self.padding, self.stride)
    input_transpose = input_im2col.transpose(1, 2)
    input_reshape = input_transpose.reshape(-1, input_transpose.size()[-1])

    # Flatten the kernels so the convolution becomes a single GEMM.
    weight = self.weight.view(self.weight.size()[0], -1)
    mm_out = HUBLinearFunction.apply(input_reshape, weight, None,
                                     self.rshift_i, self.rshift_w, self.rshift_o,
                                     self.cycle_act, self.mapctlee)

    # col2im: restore the NCHW layout; a 1x1 fold is a pure reshape.
    mm_out_reshape = mm_out.reshape(input.size()[0], -1, mm_out.size()[-1])
    mm_out_transpose = mm_out_reshape.transpose(1, 2)
    output = torch.nn.functional.fold(mm_out_transpose, output_size, (1, 1))

    if self.bias is None:
        return output
    else:
        return output + self.bias.view([1, self.bias.size()[0], 1, 1])
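# --- Illustrative sketch (not from the repo) ---
# The forward pass above lowers the convolution to a single GEMM:
# unfold extracts the sliding patches (im2col), HUBLinearFunction
# multiplies them against the flattened kernels, and a 1x1 fold
# restores the NCHW layout. The sketch below replays the same
# pipeline with plain torch.matmul standing in for HUBLinearFunction,
# so the result can be checked against F.conv2d directly; shapes and
# hyperparameters here are illustrative.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)   # NCHW input
w = torch.randn(4, 3, 3, 3)   # (out_ch, in_ch, kH, kW)

cols = F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1)
rows = cols.transpose(1, 2).reshape(-1, cols.size(1))      # (N*L, C*kH*kW)
out = torch.matmul(rows, w.view(w.size(0), -1).t())        # the GEMM step
out = out.reshape(x.size(0), -1, w.size(0)).transpose(1, 2)
out = F.fold(out, output_size=(8, 8), kernel_size=(1, 1))  # col2im reshape
assert torch.allclose(out, F.conv2d(x, w, padding=1), atol=1e-4)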
def forward(self, input):
    # See the autograd section for explanation of what happens here.
    # Same shift computation as the conv layer; the sign-magnitude bit
    # is excluded from the effective data widths.
    self.rshift_i, self.rshift_w, self.rshift_o = \
        rshift_offset(input, self.weight,
                      self.hwcfg["widthi"] - self.hwcfg["signmag"],
                      self.hwcfg["widthw"] - self.hwcfg["signmag"],
                      self.hwcfg["rounding"],
                      self.hwcfg["quantilei"], self.hwcfg["quantilew"])
    return HUBLinearFunction.apply(input, self.weight, self.bias,
                                   self.rshift_i, self.rshift_w, self.rshift_o,
                                   self.cycle_act, self.mapctlee)
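# --- Hypothetical sketch (the real rshift_offset lives elsewhere in
# the repo and may differ) ---
# Both HUB layers call rshift_offset to pick right-shift amounts so
# that a chosen quantile of the input/weight magnitudes fits into the
# available integer bit widths, with the output shift undoing both
# operand shifts. A minimal version of that idea might look like the
# following; the rounding mode is ignored here for brevity.
import math
import torch

def rshift_offset_sketch(input, weight, widthi, widthw,
                         rounding="round", quantilei=1.0, quantilew=1.0):
    # Magnitude at the requested quantile (1.0 means the max |value|).
    max_i = torch.quantile(input.abs().flatten(), float(quantilei)).item()
    max_w = torch.quantile(weight.abs().flatten(), float(quantilew)).item()
    # Shift so the quantile magnitude fits into widthi/widthw integer bits.
    rshift_i = math.ceil(math.log2(max_i)) - widthi
    rshift_w = math.ceil(math.log2(max_w)) - widthw
    rshift_o = 0 - rshift_i - rshift_w   # undo both scalings at the output
    return rshift_i, rshift_w, rshift_o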
def forward(self, input):
    # See the autograd section for explanation of what happens here.
    # One bit of each width is reserved for the sign.
    self.rshift_i, self.rshift_w, _ = \
        rshift_offset(input, self.weight,
                      self.hwcfg["widthi"] - 1,
                      self.hwcfg["widthw"] - 1,
                      self.hwcfg["rounding"],
                      self.hwcfg["quantilei"], self.hwcfg["quantilew"])
    # The output shift undoes both operand shifts, restoring the scale.
    self.rshift_o = 0 - self.rshift_i - self.rshift_w
    return FXPLinearFunction.apply(input, self.weight, self.bias,
                                   self.rshift_i, self.rshift_w, self.rshift_o,
                                   self.max_abs_i, self.max_abs_w)
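# --- Numeric check (illustrative values, not from the repo) ---
# FXPLinear derives the output shift from the identity that quantizing
# the input by 2**rshift_i and the weight by 2**rshift_w scales their
# product by 2**(rshift_i + rshift_w); shifting the output by
# rshift_o = 0 - rshift_i - rshift_w therefore restores the original scale.
rshift_i, rshift_w = -4, -5          # e.g. scale small fractional data up
x, w = 0.75, -0.5
xq = round(x / 2**rshift_i)          # 12,  the integer mantissa of x
wq = round(w / 2**rshift_w)          # -16, the integer mantissa of w
rshift_o = 0 - rshift_i - rshift_w   # 9
y = (xq * wq) * 2**(-rshift_o)       # shift the integer product back down
assert y == x * w                    # exact here because x and w are dyadic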
def forward(ctx, input, weight, bias=None,
            temporal="i", width=8, widtht=4, degree=2,
            delta=0, cycle_pos=16, cycle_neg=-16,
            rounding="round", quantilei=1, quantilew=1):
    ctx.save_for_backward(input, weight, bias)

    # Work on float copies; note that on float tensors the >> and <<
    # operators act as scaling by powers of two, not bitwise shifts.
    input_fp32 = input.detach().clone().to(torch.float)
    weight_fp32 = weight.detach().clone().to(torch.float)
    rshift_i, rshift_w, _ = rshift_offset(input_fp32, weight_fp32,
                                          width, width, rounding,
                                          quantilei, quantilew)

    if temporal in ["i", "input"]:
        # Decompose the input into `degree` digits of `widtht` bits each,
        # one per cycle; the weight is left untouched.
        input_new = torch.zeros_like(input_fp32)
        frac = torch.zeros_like(input_fp32)
        torch.trunc((input_fp32 >> rshift_i).clamp(-2**width + 1, 2**width - 1),
                    out=input_fp32)
        for _ in range(degree):
            input_fp32 = input_fp32 >> widtht
            torch.frac(input_fp32, out=frac)           # next widtht-bit digit
            torch.trunc(input_fp32, out=input_fp32)    # drop that digit
            # Clamp the digit to the available cycle budget.
            torch.clamp(frac << widtht, cycle_neg + 1, cycle_pos - 1, out=frac)
            torch.add(frac >> widtht, input_new >> widtht, out=input_new)
        # Rescale the recombined digits back to the original data range.
        input_new = (input_new << (delta + width + rshift_i)).type(weight.type())
        weight_new = weight
    elif temporal in ["w", "weight"]:
        # Same digit-serial decomposition, applied to the weight instead.
        weight_new = torch.zeros_like(weight_fp32)
        frac = torch.zeros_like(weight_fp32)
        torch.trunc((weight_fp32 >> rshift_w).clamp(-2**width + 1, 2**width - 1),
                    out=weight_fp32)
        for _ in range(degree):
            weight_fp32 = weight_fp32 >> widtht
            torch.frac(weight_fp32, out=frac)
            torch.trunc(weight_fp32, out=weight_fp32)
            torch.clamp(frac << widtht, cycle_neg + 1, cycle_pos - 1, out=frac)
            torch.add(frac >> widtht, weight_new >> widtht, out=weight_new)
        input_new = input
        weight_new = (weight_new << (delta + width + rshift_w)).type(input.type())

    output = torch.matmul(input_new, weight_new.t())
    if bias is not None:
        output += bias.unsqueeze(0).expand_as(output)
    return output
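# --- Simplified sketch (plain integers, not from the repo) ---
# The loop above performs a digit-serial decomposition: the width-bit
# operand is split LSB-first into `degree` digits of `widtht` bits via
# the float >>/frac/trunc tricks, each digit is clamped to the cycle
# budget [cycle_neg + 1, cycle_pos - 1], and the digits are recombined.
# The same idea with explicit integer arithmetic and a sign-magnitude
# treatment of negatives:
def digit_serial(x, width=8, widtht=4, degree=2, cycle_pos=16, cycle_neg=-16):
    sign = -1 if x < 0 else 1
    mag = min(abs(x), 2**width - 1)          # mirrors the clamp before the loop
    out = 0
    for j in range(degree):
        digit = mag % (2**widtht)            # extract the next widtht-bit digit
        mag //= 2**widtht
        digit = max(cycle_neg + 1, min(cycle_pos - 1, digit))
        out += digit * 2**(widtht * j)       # recombine at its binary weight
    return sign * out

# With the defaults, every 4-bit digit fits the cycle budget (max 15),
# so the decomposition is lossless:
assert digit_serial(0b10110101) == 0b10110101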