Beispiel #1
0
    def forward(self, input):
        """Apply the shift-based 2D convolution to *input*.

        Clamps the learned shift in-place, forms emulated power-of-two
        weights from the rounded shift/sign, quantizes input and bias to
        fixed point, and dispatches either to the custom kernel or to a
        plain ``conv2d`` over the emulated weights.
        """
        # Keep the shift parameter inside its permitted range (in-place).
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        rounded_shift = ste.round(self.shift, self.rounding)
        signed_sign = ste.sign(ste.round(self.sign, self.rounding))
        # Emulated weight: base**shift * sign, with asymmetric gradient.
        weight_ps = ste.unsym_grad_mul(self.shift_base**rounded_shift, signed_sign)

        input_fp = ste.round_fixed_point(input)
        bias_fp = ste.round_fixed_point(self.bias) if self.bias is not None else None

        if self.padding_mode == 'circular':
            # Circular mode: pad manually, then convolve with zero padding.
            pad_lrtb = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                        (self.padding[0] + 1) // 2, self.padding[0] // 2)
            conv_input = F.pad(input_fp, pad_lrtb, mode='circular')
            conv_padding = _pair(0)
        else:
            conv_input = input_fp
            conv_padding = self.padding

        if not self.use_kernel:
            return torch.nn.functional.conv2d(conv_input, weight_ps, bias_fp,
                                              self.stride, conv_padding,
                                              self.dilation, self.groups)
        return Conv2dShiftFunction.apply(
            conv_input, self.shift, self.sign, self.shift_base, bias_fp,
            self.conc_weight, self.stride, conv_padding, self.dilation,
            self.groups, self.use_kernel, self.use_cuda, self.rounding,
            self.shift_range)
Beispiel #2
0
    def forward(self, input):
        """Shift-based conv2d forward with optional activation quantization.

        Clamps the shift parameter in-place, builds emulated power-of-two
        weights from the rounded shift/sign, quantizes the input and bias
        to fixed point, convolves (custom kernel or plain ``conv2d``), and
        optionally quantizes the gradient flowing back through the output.
        """
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        shift_rounded = ste.round(self.shift, self.rounding)
        # Sign rounding: STE default, or the threshold-driven myround variant.
        if self.threshold is None:
            sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
        else:
            sign_rounded_signed = ste.sign(
                ste.myround(self.sign, self.threshold))
        weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)

        # NOTE(review): the original `if self.quant_bits > 0` had two
        # byte-identical branches, so the conditional was dead code;
        # collapsed into a single call. If the quant_bits <= 0 case was
        # meant to skip activation quantization, that intent was never
        # implemented — confirm with the original authors.
        input_fixed_point = ste.round_fixed_point(
            input, quant_bits=self.quant_bits)

        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.padding_mode == 'circular':
            # Circular mode: pre-pad manually, then convolve with zero padding.
            expanded_padding = ((self.padding[1] + 1) // 2,
                                self.padding[1] // 2,
                                (self.padding[0] + 1) // 2,
                                self.padding[0] // 2)
            input_padded = F.pad(input_fixed_point,
                                 expanded_padding,
                                 mode='circular')
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            output = Conv2dShiftFunction.apply(
                input_padded, self.shift, self.sign, bias_fixed_point,
                self.conc_weight, self.stride, padding, self.dilation,
                self.groups, self.use_kernel, self.use_cuda, self.rounding,
                self.shift_range)
        else:
            output = torch.nn.functional.conv2d(input_padded, weight_ps,
                                                bias_fixed_point, self.stride,
                                                padding, self.dilation,
                                                self.groups)

        # Quantize the gradient during backpropagation when enabled.
        if self.quant_bits > 0:
            output = quantize_grad(output,
                                   num_bits=self.quant_bits,
                                   flatten_dims=(1, -1))

        return output
Beispiel #3
0
 def forward(self, input):
     """Compute the shift-based linear transform of *input*.

     Clamps the shift parameter in-place, then either invokes the custom
     LinearShiftFunction kernel or falls back to a standard linear layer
     built from emulated power-of-two weights.
     """
     # Constrain shift values to the permitted range (in-place update).
     self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
     q_shift = ste.round(self.shift, rounding=self.rounding)
     q_sign = ste.sign(ste.round(self.sign, rounding=self.rounding))
     # Emulated weight: base**shift scaled by the rounded sign.
     weight_ps = ste.unsym_grad_mul(self.shift_base**q_shift, q_sign)
     if not self.use_kernel:
         return torch.nn.functional.linear(input, weight_ps, self.bias)
     return LinearShiftFunction.apply(input, self.shift, self.sign,
                                      self.bias, self.shift_base,
                                      self.conc_weight, self.use_kernel,
                                      self.use_cuda, self.rounding,
                                      self.shift_range)
Beispiel #4
0
    def forward(self, input):
        """Shift-based linear forward pass with activation bit settings.

        Clamps the shift in-place, then dispatches to the custom
        LinearShiftFunction kernel or to a plain linear layer using
        emulated power-of-two weights.
        """
        # Keep the shift parameter inside its allowed range (in-place).
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        q_shift = ste.round(self.shift, rounding=self.rounding)
        q_sign = ste.sign(ste.round(self.sign, rounding=self.rounding))
        # Emulated weight: 2**shift scaled by the rounded sign.
        weight_ps = ste.unsym_grad_mul(2**q_shift, q_sign)

        # TODO: round bias and input to fixed point

        if not self.use_kernel:
            return torch.nn.functional.linear(input, weight_ps, self.bias)
        return LinearShiftFunction.apply(input, self.shift, self.sign,
                                         self.bias, self.conc_weight,
                                         self.use_kernel, self.use_cuda,
                                         self.rounding, self.shift_range,
                                         self.act_integer_bits,
                                         self.act_fraction_bits)
Beispiel #5
0
    def forward(self, input):
        """Shift-convolution forward pass using default rounding.

        Quantizes the input and bias to fixed point, forms power-of-two
        weights from the rounded shift and sign, and runs either the
        custom kernel or a plain ``conv2d`` over the emulated weights.
        """
        q_shift = ste.round(self.shift)
        q_sign = ste.sign(ste.round(self.sign))
        # 2**round(shift) * sign(round(sign)), with asymmetric gradient.
        weight_ps = ste.unsym_grad_mul(2**q_shift, q_sign)
        input_fp = ste.round_fixed_point(input)
        bias_fp = ste.round_fixed_point(self.bias) if self.bias is not None else None

        if self.padding_mode == 'circular':
            # Circular mode: pre-pad manually, then convolve with zero padding.
            pad_amounts = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                           (self.padding[0] + 1) // 2, self.padding[0] // 2)
            conv_input = F.pad(input_fp, pad_amounts, mode='circular')
            conv_padding = _pair(0)
        else:
            conv_input = input_fp
            conv_padding = self.padding

        if not self.use_kernel:
            return torch.nn.functional.conv2d(conv_input, weight_ps, bias_fp,
                                              self.stride, conv_padding,
                                              self.dilation, self.groups)
        return Conv2dShiftFunction.apply(conv_input, self.shift, self.sign,
                                         bias_fp, self.conc_weight,
                                         self.stride, conv_padding,
                                         self.dilation, self.groups,
                                         self.use_kernel, self.use_cuda)
        '''