示例#1
0
    def forward(ctx,
                input,
                shift,
                sign,
                bias=None,
                conc_weight=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                use_kernel=False,
                use_cuda=False,
                rounding='deterministic',
                shift_range=(-14, 0)):
        """Convolve ``input`` with a weight built from power-of-two shifts and signs.

        Two execution paths exist:

        * ``use_kernel=True``: ``input`` (and ``bias``, if given) are multiplied
          by ``2**16`` and cast with ``.int()``. They are then passed to
          ``deepshift.kernels.conv2d``. The kernel's result is converted with
          ``.float()`` and divided by ``2**16``. Nothing is saved on ``ctx``
          on this path.
        * otherwise: ``shift`` is clamped to ``shift_range`` and ``sign`` to
          ``[-1, 1]``. ``input`` and ``bias`` are rounded via
          ``utils.round_to_fixed``, overwriting their ``.data`` in place. The
          weight ``2**round(shift) * sign(round(sign))`` is convolved with
          ``input`` using ``F.conv2d``. The tensors and the conv
          hyperparameters are then stored on ``ctx``.

        Args:
            ctx: autograd context object.
            input: activations to convolve.
            shift: per-weight shift exponents.
            sign: per-weight sign values.
            bias: optional bias tensor.
            conc_weight: forwarded to the kernel path only.
            stride, padding, dilation, groups: forwarded to the convolution.
            use_kernel: select the custom fixed-point kernel path.
            use_cuda: forwarded to the kernel path only.
            rounding: accepted but not used by this method.
            shift_range: ``(low, high)`` bounds used to clamp ``shift``.

        Returns:
            The convolution output.
        """
        fraction_bits = 16
        integer_bits = 16
        scale = 2 ** fraction_bits  # fixed-point scaling factor

        if use_kernel:
            # NOTE(review): nothing is saved on ctx here, so a backward pass
            # presumably cannot run on this path -- confirm it is inference-only.
            input_fp = (input * scale).int()
            bias_fp = None if bias is None else (bias * scale).int()
            result = deepshift.kernels.conv2d(input_fp, shift, sign, bias_fp,
                                              conc_weight, stride, padding,
                                              dilation, groups, use_cuda)
            return result.float() / scale

        sign = sign.clamp(-1, 1)
        shift = shift.clamp(*shift_range)

        # Overwrite the caller's tensors in place; the rounded input is also
        # what gets saved for backward below.
        input.data = utils.round_to_fixed(input.data, fraction_bits,
                                          integer_bits)
        if bias is not None:
            bias.data = utils.round_to_fixed(bias.data, fraction_bits,
                                             integer_bits)

        # NOTE(review): `rounding` is ignored; deterministic rounding is
        # hard-coded here -- confirm whether it should be forwarded.
        exponent = utils.round(shift, stochastic=False)
        polarity = torch.sign(utils.round(sign, stochastic=False))
        weight = 2 ** exponent * polarity
        out = F.conv2d(input, weight, bias, stride, padding, dilation, groups)

        ctx.save_for_backward(input, shift, sign, bias, weight)
        ctx.stride, ctx.padding = stride, padding
        ctx.dilation, ctx.groups = dilation, groups
        return out
示例#2
0
 def forward(ctx, input, rounding='deterministic'):
     """Return ``utils.round(input, rounding)``; ``ctx`` is not used."""
     rounded = utils.round(input, rounding)
     return rounded