Example #1
    def forward(self, input):
        # Round weights to signed powers of two and inputs to fixed point.
        weight_q = ste.round_power_of_2(self.weight)
        input_fixed_point = ste.round_fixed_point(input)
        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.padding_mode == 'circular':
            # Pad manually for circular mode, then run the conv with zero padding.
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)

            input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular')
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            # Dispatch to the custom shift kernel when enabled; otherwise fall
            # back to a standard convolution over the quantized tensors.
            return Conv2dShiftQFunction.apply(input_padded, weight_q, bias_fixed_point, self.conc_weight, 
                                              self.stride, padding, self.dilation, self.groups, 
                                              self.use_kernel, self.use_cuda)
        else:
            return torch.nn.functional.conv2d(input_padded, weight_q, bias_fixed_point, 
                                              self.stride, padding, self.dilation, self.groups)
Example #2
    def forward(self, input):
        if self.shift_or_full == 'shift':
            # Clamp weight magnitudes into the representable shift range,
            # then round them to signed powers of two.
            self.weight.data = ste.clampabs(self.weight.data,
                                            2**self.shift_range[0],
                                            2**self.shift_range[1])
            weight_q = ste.round_power_of_2(self.weight, self.rounding)
            input_fixed_point = ste.round_fixed_point(input)
            if self.bias is not None:
                bias_fixed_point = ste.round_fixed_point(self.bias)
            else:
                bias_fixed_point = None
        elif self.shift_or_full == 'full':  # full precision training
            weight_q = self.weight
            input_fixed_point = ste.round_fixed_point(input)
            if self.bias is not None:
                bias_fixed_point = self.bias
            else:
                bias_fixed_point = None
        else:
            raise NotImplementedError(
                "Unsupported shift_or_full mode: '{}'".format(self.shift_or_full))

        if self.padding_mode == 'circular':
            expanded_padding = ((self.padding[1] + 1) // 2,
                                self.padding[1] // 2,
                                (self.padding[0] + 1) // 2,
                                self.padding[0] // 2)

            input_padded = F.pad(input_fixed_point,
                                 expanded_padding,
                                 mode='circular')
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            return Conv2dShiftQFunction.apply(input_padded, weight_q,
                                              bias_fixed_point,
                                              self.conc_weight, self.stride,
                                              padding, self.dilation,
                                              self.groups, self.use_kernel,
                                              self.use_cuda)
        else:
            return torch.nn.functional.conv2d(input_padded, weight_q,
                                              bias_fixed_point, self.stride,
                                              padding, self.dilation,
                                              self.groups)
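
The ste.round_* calls above rely on a straight-through estimator (STE), so the non-differentiable rounding still lets gradients flow during training. A minimal sketch of that pattern, assuming a simple fixed-point scheme (the class name and bit-width here are illustrative, not the library's actual API):

    import torch

    class RoundFixedPointSTE(torch.autograd.Function):
        # Illustrative straight-through estimator: quantize to fixed point in
        # the forward pass, pass the gradient through unchanged in backward.
        @staticmethod
        def forward(ctx, x, fraction_bits):
            scale = 2 ** fraction_bits
            return torch.round(x * scale) / scale

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through: treat the rounding as identity for gradients.
            return grad_output, None

    x = torch.randn(4, requires_grad=True)
    y = RoundFixedPointSTE.apply(x, 8)
    y.sum().backward()
    print(x.grad)  # tensor([1., 1., 1., 1.]): gradients pass straight through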
Example #3
    def forward(self, input):
        # Round weights to signed powers of two and inputs/bias to fixed point.
        weight_q = ste.round_power_of_2(self.weight)
        input_fixed_point = ste.round_fixed_point(input)
        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.use_kernel:
            return LinearShiftQFunction.apply(input_fixed_point, weight_q,
                                              bias_fixed_point, self.conc_weight,
                                              self.use_kernel, self.use_cuda)
        else:
            out = input_fixed_point.mm(weight_q.t())
            if self.bias is not None:
                # Use the quantized bias, consistent with the kernel path.
                out += bias_fixed_point.unsqueeze(0).expand_as(out)

            return out
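
The reason a dedicated kernel such as LinearShiftQFunction can pay off: once every weight is a signed power of two, each multiply in input_fixed_point.mm(weight_q.t()) reduces to a bit shift on the fixed-point representation. A small plain-Python illustration of that equivalence:

    # Multiplying a fixed-point integer by 2**p equals shifting left by p bits.
    x = 12                       # fixed-point integer activation
    p = 3                        # weight exponent, i.e. weight = 2**3 = 8
    assert x * (2 ** p) == (x << p) == 96

    # A negative exponent (fractional weight) becomes a right shift.
    p = -2                       # weight = 2**-2 = 0.25
    assert (x >> -p) == x * (2 ** p) == 3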
Example #4
    def forward(self, input):
        # Clamp weight magnitudes into the representable shift range for the
        # configured base, then round weights to signed powers of that base.
        self.weight.data = ste.clampabs(self.weight.data,
                                        self.base**self.shift_range[0],
                                        self.base**self.shift_range[1])
        weight_q = ste.round_power_of_2(self.weight, self.base, self.rounding)
        input_fixed_point = ste.round_fixed_point(input)
        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.use_kernel:
            return LinearShiftQFunction.apply(input_fixed_point, weight_q, self.base,
                                              bias_fixed_point, self.conc_weight,
                                              self.use_kernel, self.use_cuda)
        else:
            out = input_fixed_point.mm(weight_q.t())
            if self.bias is not None:
                # Use the quantized bias, consistent with the kernel path.
                out += bias_fixed_point.unsqueeze(0).expand_as(out)

            return out
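
ste.clampabs is likewise external to this snippet; judging from the call site, it clamps each weight's magnitude into [base**shift_range[0], base**shift_range[1]] while keeping its sign, so every weight stays representable as a shift. A plausible sketch under that assumption:

    import torch

    def clampabs_sketch(x: torch.Tensor, min_abs: float, max_abs: float) -> torch.Tensor:
        # Hypothetical stand-in for ste.clampabs: preserve each element's sign
        # but force its magnitude into [min_abs, max_abs] (zeros stay zero here).
        return torch.sign(x) * torch.clamp(x.abs(), min_abs, max_abs)

    w = torch.tensor([0.001, -3.0, 0.5])
    print(clampabs_sketch(w, 2**-4, 2**1))  # tensor([ 0.0625, -2.0000,  0.5000])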