Esempio n. 1
0
    def reset_parameters(self):
        """(Re)initialise the layer parameters according to ``self.distribution``.

        ``'kaiming_normal'``:
            ``self.weight`` is drawn with Kaiming-normal init (``fan_out``,
            ``relu``), then ``self.set_mask()`` is called. ``self.s`` is then
            drawn uniformly between -1 and 1, and the weight is multiplied
            in place by ``abs(sign)``, where ``sign`` is derived from ``self.s``.

        any other value:
            ``self.p`` is drawn for ``'uniform'`` / ``'normal'``. Other names
            skip the draw but still run the following steps. ``self.p`` is
            clamped to ``self.shift_range`` and rounded deterministically.
            ``self.s`` is drawn uniformly between -1 and 1, and the weight is
            set to ``sign * 2 ** p``.

        When a bias is present, it is drawn uniformly from
        ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        def draw_signs():
            # Re-sample the sign logits, then turn them into signs. NOTE(review):
            # ``round`` here is whatever name is in scope in this module (it is
            # called with a threshold argument) -- confirm it is not the builtin.
            self.s.data.uniform_(-1, 1)
            return ste.sign(round(self.s, self.sign_threshold))

        # Samplers for the shift exponents; evaluated lazily so only the
        # selected distribution touches the RNG.
        p_samplers = {
            'uniform': lambda: self.p.data.uniform_(-self.min_p - 0.5, -0.5),
            'normal': lambda: self.p.data.normal_(-self.min_p / 2, 1),
        }

        dist = self.distribution
        if dist == 'kaiming_normal':
            init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
            self.set_mask()  # quantize
            signs = draw_signs()
            # Multiplying by |sign| keeps entries where sign is nonzero and
            # zeroes the rest (assuming ste.sign yields -1/0/1 -- confirm).
            self.weight.data *= abs(signs)
        else:
            sampler = p_samplers.get(dist)
            if sampler is not None:
                sampler()
            self.p.data = ste.clamp(self.p.data, *self.shift_range)
            self.p.data = ste.round(self.p.data, 'deterministic')
            signs = draw_signs()
            self.weight.data = signs * (2**self.p.data)

        if self.bias is None:
            return
        print('use bias')
        fan_in = init._calculate_fan_in_and_fan_out(self.weight)[0]
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)
Esempio n. 2
0
    def forward(self, input):
        """Apply the shift convolution to ``input``.

        The stored shift exponents are clamped in place to
        ``self.shift_range``. Shifts and signs are rounded with
        ``self.rounding`` to build the weight ``shift_base ** shift * sign``,
        which the pure-PyTorch path uses. Input and bias are rounded to fixed
        point before convolving.

        Args:
            input: input tensor for the convolution (presumably
                (N, C, H, W) -- confirm against callers).

        Returns:
            The result of ``Conv2dShiftFunction.apply`` when
            ``self.use_kernel`` is truthy, otherwise the result of
            ``torch.nn.functional.conv2d``.
        """
        # Side effect: keep the stored shift parameters inside the allowed range.
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        shift_rounded = ste.round(self.shift, self.rounding)
        sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
        # NOTE(review): only the non-kernel path uses this weight. It is still
        # built unconditionally so the rounding calls run identically on both
        # paths (relevant if rounding is stochastic).
        weight_ps = ste.unsym_grad_mul(self.shift_base**shift_rounded,
                                       sign_rounded_signed)
        input_fixed_point = ste.round_fixed_point(input)
        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.padding_mode == 'circular':
            # F.pad expects (left, right, top, bottom). Conv2d's ``padding`` is
            # a per-side amount, so each side gets the full value, which
            # reproduces the output size of zero padding. Splitting it as
            # ((p + 1) // 2, p // 2) would pad only p in total and shrink the
            # output.
            expanded_padding = (self.padding[1], self.padding[1],
                                self.padding[0], self.padding[0])
            input_padded = F.pad(input_fixed_point, expanded_padding,
                                 mode='circular')
            # Padding has already been applied explicitly above.
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            return Conv2dShiftFunction.apply(
                input_padded, self.shift, self.sign, self.shift_base,
                bias_fixed_point, self.conc_weight, self.stride, padding,
                self.dilation, self.groups, self.use_kernel, self.use_cuda,
                self.rounding, self.shift_range)
        return torch.nn.functional.conv2d(input_padded, weight_ps,
                                          bias_fixed_point, self.stride,
                                          padding, self.dilation, self.groups)
Esempio n. 3
0
    def forward(self, input):
        """Apply the shift convolution, with optional fixed-point/gradient quantization.

        Steps:

        1. Clamp the stored shift exponents in place to ``self.shift_range``.
        2. Round the shifts with ``self.rounding``. Round the signs with
           ``self.rounding`` when ``self.threshold`` is None, otherwise with
           ``ste.myround(..., self.threshold)``.
        3. Build the weight ``2 ** shift * sign``.
        4. Round ``input`` to fixed point with ``self.quant_bits``, and round
           the bias with ``ste.round_fixed_point``'s default setting.
        5. Convolve.
        6. If ``self.quant_bits > 0``, pass the output through
           ``quantize_grad`` to quantize the backward pass.

        Args:
            input: input tensor for the convolution (presumably
                (N, C, H, W) -- confirm against callers).

        Returns:
            The convolution output, possibly wrapped by ``quantize_grad``.
        """
        # Side effect: keep the stored shift parameters inside the allowed range.
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        shift_rounded = ste.round(self.shift, self.rounding)
        if self.threshold is None:
            sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
        else:
            sign_rounded_signed = ste.sign(
                ste.myround(self.sign, self.threshold))
        # NOTE(review): only the non-kernel path uses this weight. It is kept
        # unconditional so the rounding calls run identically on both paths.
        weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)

        # The same call applies whatever the value of quant_bits.
        input_fixed_point = ste.round_fixed_point(input,
                                                  quant_bits=self.quant_bits)
        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.padding_mode == 'circular':
            # F.pad expects (left, right, top, bottom). Conv2d's ``padding`` is
            # a per-side amount, so each side gets the full value, which
            # reproduces the output size of zero padding. Splitting it as
            # ((p + 1) // 2, p // 2) would pad only p in total and shrink the
            # output.
            expanded_padding = (self.padding[1], self.padding[1],
                                self.padding[0], self.padding[0])
            input_padded = F.pad(input_fixed_point,
                                 expanded_padding,
                                 mode='circular')
            # Padding has already been applied explicitly above.
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            # NOTE(review): the kernel receives self.rounding but not
            # self.threshold -- confirm that ignoring the threshold here is
            # intended.
            output = Conv2dShiftFunction.apply(
                input_padded, self.shift, self.sign, bias_fixed_point,
                self.conc_weight, self.stride, padding, self.dilation,
                self.groups, self.use_kernel, self.use_cuda, self.rounding,
                self.shift_range)
        else:
            output = torch.nn.functional.conv2d(input_padded, weight_ps,
                                                bias_fixed_point, self.stride,
                                                padding, self.dilation,
                                                self.groups)

        # Quantize backpropagation.
        if self.quant_bits > 0:
            output = quantize_grad(output,
                                   num_bits=self.quant_bits,
                                   flatten_dims=(1, -1))

        return output
Esempio n. 4
0
 def forward(self, input):
     """Apply the shift-based linear layer to ``input``.

     The stored shift exponents are clamped in place to ``self.shift_range``.
     Shifts and signs are rounded with ``self.rounding`` to form the weight
     ``shift_base ** shift * sign``. When ``self.use_kernel`` is falsy, that
     weight is used with ``torch.nn.functional.linear``. Otherwise the raw
     parameters are handed to ``LinearShiftFunction.apply``.
     """
     # Side effect: keep the stored shift parameters inside the allowed range.
     self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
     rounded_shift = ste.round(self.shift, rounding=self.rounding)
     signs = ste.sign(ste.round(self.sign, rounding=self.rounding))
     powers = self.shift_base**rounded_shift
     weight_ps = ste.unsym_grad_mul(powers, signs)

     if not self.use_kernel:
         return torch.nn.functional.linear(input, weight_ps, self.bias)
     return LinearShiftFunction.apply(
         input, self.shift, self.sign, self.bias, self.shift_base,
         self.conc_weight, self.use_kernel, self.use_cuda, self.rounding,
         self.shift_range)
Esempio n. 5
0
    def forward(self, input):
        """Apply the shift-based linear layer to ``input``.

        The stored shift exponents are clamped in place to ``self.shift_range``.
        Shifts and signs are rounded with ``self.rounding`` to form the weight
        ``2 ** shift * sign``. Without the custom kernel, that weight is used
        with ``torch.nn.functional.linear``. With it, the raw parameters and
        the activation bit widths (``act_integer_bits`` /
        ``act_fraction_bits``) go to ``LinearShiftFunction.apply``.
        """
        # Side effect: keep the stored shift parameters inside the allowed range.
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        rounded_shift = ste.round(self.shift, rounding=self.rounding)
        rounded_sign = ste.round(self.sign, rounding=self.rounding)
        signs = ste.sign(rounded_sign)
        weight_ps = ste.unsym_grad_mul(2**rounded_shift, signs)

        # TODO: round bias and input to fixed point

        if not self.use_kernel:
            return torch.nn.functional.linear(input, weight_ps, self.bias)
        return LinearShiftFunction.apply(
            input, self.shift, self.sign, self.bias, self.conc_weight,
            self.use_kernel, self.use_cuda, self.rounding, self.shift_range,
            self.act_integer_bits, self.act_fraction_bits)