def forward(ctx, input, shift, sign, bias=None, conc_weight=None,
            stride=1, padding=0, dilation=1, groups=1,
            use_kernel=False, use_cuda=False,
            rounding='deterministic', shift_range=(-14, 0)):
    """Run the forward pass of a shift-based 2D convolution.

    Two execution paths are selected by ``use_kernel``:

    * Kernel path: ``input`` and ``bias`` are multiplied by
      ``2 ** 16`` and converted with ``.int()``, handed to
      ``deepshift.kernels.conv2d`` together with ``shift``, ``sign`` and
      ``conc_weight``, and the result is converted back to float and
      divided by ``2 ** 16``. Nothing is stored on ``ctx`` on this path.
    * Reference path: ``shift`` is clamped to ``shift_range`` and ``sign``
      to ``[-1, 1]``; ``input.data`` (and ``bias.data`` when a bias is
      given) are overwritten with the result of ``utils.round_to_fixed``;
      the weight ``2 ** round(shift) * sign(round(sign))`` is built and
      passed to ``F.conv2d``. The tensors and convolution settings are
      saved on ``ctx``.

    Args:
        ctx: Context object; receives saved tensors and the ``stride``,
            ``padding``, ``dilation`` and ``groups`` attributes on the
            reference path.
        input: Input tensor. Its ``.data`` is replaced in place on the
            reference path.
        shift: Tensor of shift exponents.
        sign: Tensor of signs.
        bias: Optional bias tensor. Its ``.data`` is replaced in place on
            the reference path.
        conc_weight: Forwarded unchanged to ``deepshift.kernels.conv2d``;
            not read on the reference path.
        stride, padding, dilation, groups: Forwarded to whichever
            convolution is called.
        use_kernel: Select the ``deepshift.kernels.conv2d`` path.
        use_cuda: Forwarded to ``deepshift.kernels.conv2d``.
        rounding: Accepted but not read by this function.
        shift_range: Pair unpacked into ``shift.clamp`` as its bounds.

    Returns:
        The convolution output tensor.
    """
    fraction_bits = 16
    integer_bits = 16
    fixed_point_scale = 2 ** fraction_bits

    if use_kernel:
        # Scale to fixed point before handing off to the custom kernel.
        input_fixed_point = (input * fixed_point_scale).int()
        bias_fixed_point = (
            None if bias is None else (bias * fixed_point_scale).int()
        )
        kernel_out = deepshift.kernels.conv2d(
            input_fixed_point, shift, sign, bias_fixed_point, conc_weight,
            stride, padding, dilation, groups, use_cuda)
        # Undo the fixed-point scaling on the way out.
        return kernel_out.float() / fixed_point_scale

    clamped_shift = shift.clamp(*shift_range)
    clamped_sign = sign.clamp(-1, 1)

    # These assignments overwrite the caller's tensors' data in place.
    input.data = utils.round_to_fixed(input.data, fraction_bits, integer_bits)
    if bias is not None:
        bias.data = utils.round_to_fixed(bias.data, fraction_bits, integer_bits)

    # NOTE(review): `rounding` is never consulted here; rounding is requested
    # via `stochastic=False`, whereas other callers pass a rounding mode
    # positionally to utils.round -- confirm both call styles are supported.
    exponent = utils.round(clamped_shift, stochastic=False)
    polarity = torch.sign(utils.round(clamped_sign, stochastic=False))
    weight = 2 ** exponent * polarity

    out = F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    # The effective weight exists only on this path, so the state saved on
    # ctx is recorded here.
    ctx.save_for_backward(input, clamped_shift, clamped_sign, bias, weight)
    ctx.stride, ctx.padding = stride, padding
    ctx.dilation, ctx.groups = dilation, groups
    return out
def forward(ctx, input, rounding='deterministic'):
    """Round ``input`` by delegating to ``utils.round``.

    Args:
        ctx: Context object; not used by this function.
        input: Value to round; passed through as the first argument.
        rounding: Rounding mode, passed positionally as the second
            argument of ``utils.round``.

    Returns:
        Whatever ``utils.round(input, rounding)`` returns.
    """
    rounded = utils.round(input, rounding)
    return rounded