示例#1
0
    def forward(self, input):
        """Apply the shift-based 2-D convolution to ``input``.

        The stored shift values are clamped in place to ``self.shift_range``.
        A dense weight is then built from ``self.shift_base`` raised to the
        rounded shift, combined with the sign of the rounded ``self.sign``
        via ``ste.unsym_grad_mul``. The input and the bias (when present)
        are passed through ``ste.round_fixed_point``.

        Args:
            input: Tensor handed to the convolution.

        Returns:
            The output of ``Conv2dShiftFunction.apply`` when
            ``self.use_kernel`` is truthy, otherwise the output of
            ``torch.nn.functional.conv2d`` using the dense weight.
        """
        # Keep the stored shift values inside the configured range.
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)

        exponent = ste.round(self.shift, self.rounding)
        polarity = ste.sign(ste.round(self.sign, self.rounding))
        weight_ps = ste.unsym_grad_mul(self.shift_base**exponent, polarity)

        x = ste.round_fixed_point(input)
        bias = None if self.bias is None else ste.round_fixed_point(self.bias)

        if self.padding_mode == 'circular':
            # Pad explicitly: padding[1] is split into halves for the last
            # dimension and padding[0] for the one before it, so the
            # convolution itself runs with zero padding.
            pad_last = self.padding[1]
            pad_prev = self.padding[0]
            halves = ((pad_last + 1) // 2, pad_last // 2,
                      (pad_prev + 1) // 2, pad_prev // 2)
            x = F.pad(x, halves, mode='circular')
            padding = _pair(0)
        else:
            padding = self.padding

        if not self.use_kernel:
            return torch.nn.functional.conv2d(
                x, weight_ps, bias,
                self.stride, padding, self.dilation, self.groups)

        # The custom function receives the raw shift/sign parameters rather
        # than the dense weight built above.
        return Conv2dShiftFunction.apply(
            x, self.shift, self.sign, self.shift_base, bias, self.conc_weight,
            self.stride, padding, self.dilation, self.groups,
            self.use_kernel, self.use_cuda, self.rounding, self.shift_range)
示例#2
0
    def forward(self, input):
        """Apply the 2-D convolution with power-of-two rounded weights.

        ``self.weight`` goes through ``ste.round_power_of_2``; the input and
        the bias (when present) go through ``ste.round_fixed_point``.

        Args:
            input: Tensor handed to the convolution.

        Returns:
            The output of ``Conv2dShiftQFunction.apply`` when
            ``self.use_kernel`` is truthy, otherwise the output of
            ``torch.nn.functional.conv2d``.
        """
        weight_q = ste.round_power_of_2(self.weight)
        x = ste.round_fixed_point(input)
        bias = None if self.bias is None else ste.round_fixed_point(self.bias)

        if self.padding_mode == 'circular':
            # Pad explicitly: padding[1] is split into halves for the last
            # dimension and padding[0] for the one before it, so the
            # convolution itself runs with zero padding.
            pad_last = self.padding[1]
            pad_prev = self.padding[0]
            halves = ((pad_last + 1) // 2, pad_last // 2,
                      (pad_prev + 1) // 2, pad_prev // 2)
            x = F.pad(x, halves, mode='circular')
            padding = _pair(0)
        else:
            padding = self.padding

        if not self.use_kernel:
            return torch.nn.functional.conv2d(
                x, weight_q, bias,
                self.stride, padding, self.dilation, self.groups)

        return Conv2dShiftQFunction.apply(
            x, weight_q, bias, self.conc_weight,
            self.stride, padding, self.dilation, self.groups,
            self.use_kernel, self.use_cuda)
示例#3
0
    def forward(self, input):
        """Apply the shift-based 2-D convolution to ``input``.

        The stored shift values are clamped in place to ``self.shift_range``.
        A dense weight is built as ``2 ** round(shift)`` combined with the
        sign of the rounded ``self.sign`` via ``ste.unsym_grad_mul``. The
        sign is rounded with ``ste.round`` when ``self.threshold`` is
        ``None`` and with ``ste.myround(self.sign, self.threshold)``
        otherwise.

        Args:
            input: Tensor handed to the convolution.

        Returns:
            The convolution output. When ``self.quant_bits > 0`` it is
            additionally passed through ``quantize_grad``.
        """
        # Keep the stored shift values inside the configured range.
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        shift_rounded = ste.round(self.shift, self.rounding)

        if self.threshold is None:
            sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
        else:
            sign_rounded_signed = ste.sign(
                ste.myround(self.sign, self.threshold))
        weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)

        # quant_bits is forwarded unconditionally: the previous
        # `if self.quant_bits > 0 / else` had two identical branches.
        input_fixed_point = ste.round_fixed_point(
            input, quant_bits=self.quant_bits)

        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.padding_mode == 'circular':
            # Pad explicitly: padding[1] is split into halves for the last
            # dimension and padding[0] for the one before it, so the
            # convolution itself runs with zero padding.
            expanded_padding = ((self.padding[1] + 1) // 2,
                                self.padding[1] // 2,
                                (self.padding[0] + 1) // 2,
                                self.padding[0] // 2)

            input_padded = F.pad(input_fixed_point,
                                 expanded_padding,
                                 mode='circular')
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            # The custom function receives the raw shift/sign parameters
            # rather than the dense weight built above.
            output = Conv2dShiftFunction.apply(
                input_padded, self.shift, self.sign, bias_fixed_point,
                self.conc_weight, self.stride, padding, self.dilation,
                self.groups, self.use_kernel, self.use_cuda, self.rounding,
                self.shift_range)
        else:
            output = torch.nn.functional.conv2d(input_padded, weight_ps,
                                                bias_fixed_point, self.stride,
                                                padding, self.dilation,
                                                self.groups)

        # Quantize backpropagation.
        if self.quant_bits > 0:
            output = quantize_grad(output,
                                   num_bits=self.quant_bits,
                                   flatten_dims=(1, -1))

        return output
示例#4
0
    def forward(self, input):
        """Apply the 2-D convolution in ``'shift'`` or ``'full'`` mode.

        In ``'shift'`` mode the weight magnitudes are clamped in place to
        ``[2**shift_range[0], 2**shift_range[1]]`` with ``ste.clampabs`` and
        rounded with ``ste.round_power_of_2``; the bias is passed through
        ``ste.round_fixed_point``. In ``'full'`` mode the weight and bias
        are used as stored. The input goes through
        ``ste.round_fixed_point`` in both modes.

        Args:
            input: Tensor handed to the convolution.

        Returns:
            The output of ``Conv2dShiftQFunction.apply`` when
            ``self.use_kernel`` is truthy, otherwise the output of
            ``torch.nn.functional.conv2d``.

        Raises:
            NotImplementedError: If ``self.shift_or_full`` is neither
                ``'shift'`` nor ``'full'``.
        """
        mode = self.shift_or_full
        if mode not in ('shift', 'full'):
            raise NotImplementedError('No such type!')
        shift_mode = mode == 'shift'

        if shift_mode:
            lower = 2**self.shift_range[0]
            upper = 2**self.shift_range[1]
            self.weight.data = ste.clampabs(self.weight.data, lower, upper)
            weight_q = ste.round_power_of_2(self.weight, self.rounding)
        else:
            # Full precision training: the weight is used as stored.
            weight_q = self.weight

        x = ste.round_fixed_point(input)

        bias = self.bias
        if shift_mode and bias is not None:
            bias = ste.round_fixed_point(bias)

        if self.padding_mode == 'circular':
            # Pad explicitly: padding[1] is split into halves for the last
            # dimension and padding[0] for the one before it, so the
            # convolution itself runs with zero padding.
            pad_last = self.padding[1]
            pad_prev = self.padding[0]
            halves = ((pad_last + 1) // 2, pad_last // 2,
                      (pad_prev + 1) // 2, pad_prev // 2)
            x = F.pad(x, halves, mode='circular')
            padding = _pair(0)
        else:
            padding = self.padding

        if not self.use_kernel:
            return torch.nn.functional.conv2d(
                x, weight_q, bias,
                self.stride, padding, self.dilation, self.groups)

        return Conv2dShiftQFunction.apply(
            x, weight_q, bias, self.conc_weight,
            self.stride, padding, self.dilation, self.groups,
            self.use_kernel, self.use_cuda)
示例#5
0
 def forward(self, input):
     """Apply the linear layer with power-of-two rounded weights.

     ``self.weight`` goes through ``ste.round_power_of_2``; the input and
     the bias (when present) go through ``ste.round_fixed_point``.

     Args:
         input: Tensor multiplied by the transposed weight. The fallback
             path uses ``Tensor.mm``, which requires a 2-D tensor.

     Returns:
         The output of ``LinearShiftQFunction.apply`` when
         ``self.use_kernel`` is truthy, otherwise
         ``input_fixed_point.mm(weight_q.t())`` plus the rounded bias.
     """
     weight_q = ste.round_power_of_2(self.weight)
     input_fixed_point = ste.round_fixed_point(input)
     if self.bias is not None:
         bias_fixed_point = ste.round_fixed_point(self.bias)
     else:
         bias_fixed_point = None

     if self.use_kernel:
         return LinearShiftQFunction.apply(input_fixed_point, weight_q, bias_fixed_point, self.conc_weight, self.use_kernel, self.use_cuda)

     out = input_fixed_point.mm(weight_q.t())
     if bias_fixed_point is not None:
         # Add the same rounded bias the kernel path receives; the raw
         # self.bias was used here before, so the two paths disagreed.
         out += bias_fixed_point.unsqueeze(0).expand_as(out)

     return out
示例#6
0
 def forward(self, input):
     """Apply the linear layer with weights rounded to powers of ``self.base``.

     The weight magnitudes are clamped in place to
     ``[base**shift_range[0], base**shift_range[1]]`` with ``ste.clampabs``
     and rounded with ``ste.round_power_of_2``. The input and the bias
     (when present) go through ``ste.round_fixed_point``.

     Args:
         input: Tensor multiplied by the transposed weight. The fallback
             path uses ``Tensor.mm``, which requires a 2-D tensor.

     Returns:
         The output of ``LinearShiftQFunction.apply`` when
         ``self.use_kernel`` is truthy, otherwise
         ``input_fixed_point.mm(weight_q.t())`` plus the rounded bias.
     """
     self.weight.data = ste.clampabs(self.weight.data, self.base**self.shift_range[0], self.base**self.shift_range[1])
     weight_q = ste.round_power_of_2(self.weight, self.base, self.rounding)
     input_fixed_point = ste.round_fixed_point(input)
     if self.bias is not None:
         bias_fixed_point = ste.round_fixed_point(self.bias)
     else:
         bias_fixed_point = None

     if self.use_kernel:
         return LinearShiftQFunction.apply(input_fixed_point, weight_q, self.base, bias_fixed_point, self.conc_weight, self.use_kernel, self.use_cuda)

     out = input_fixed_point.mm(weight_q.t())
     if bias_fixed_point is not None:
         # Add the same rounded bias the kernel path receives; the raw
         # self.bias was used here before, so the two paths disagreed.
         out += bias_fixed_point.unsqueeze(0).expand_as(out)

     return out
示例#7
0
    def forward(self, input):
        """Apply the shift-based 2-D convolution to ``input``.

        A dense weight is built as ``2 ** ste.round(self.shift)`` combined
        with ``ste.sign(ste.round(self.sign))`` via ``ste.unsym_grad_mul``.
        The input and the bias (when present) go through
        ``ste.round_fixed_point``.

        Args:
            input: Tensor handed to the convolution.

        Returns:
            The output of ``Conv2dShiftFunction.apply`` when
            ``self.use_kernel`` is truthy, otherwise the output of
            ``torch.nn.functional.conv2d`` using the dense weight.
        """
        exponent = ste.round(self.shift)
        polarity = ste.sign(ste.round(self.sign))
        # NOTE(review): an earlier comment described this product as
        # 2**utils.stochastic_rounding(shift) * sign.round().sign() -- confirm.
        weight_ps = ste.unsym_grad_mul(2**exponent, polarity)

        x = ste.round_fixed_point(input)
        bias = None if self.bias is None else ste.round_fixed_point(self.bias)

        if self.padding_mode == 'circular':
            # Pad explicitly: padding[1] is split into halves for the last
            # dimension and padding[0] for the one before it, so the
            # convolution itself runs with zero padding.
            pad_last = self.padding[1]
            pad_prev = self.padding[0]
            halves = ((pad_last + 1) // 2, pad_last // 2,
                      (pad_prev + 1) // 2, pad_prev // 2)
            x = F.pad(x, halves, mode='circular')
            padding = _pair(0)
        else:
            padding = self.padding

        if not self.use_kernel:
            return torch.nn.functional.conv2d(
                x, weight_ps, bias,
                self.stride, padding, self.dilation, self.groups)

        # The custom function receives the raw shift/sign parameters rather
        # than the dense weight built above.
        return Conv2dShiftFunction.apply(
            x, self.shift, self.sign, bias, self.conc_weight,
            self.stride, padding, self.dilation, self.groups,
            self.use_kernel, self.use_cuda)
        '''