Example #1
    def forward(self, input):
        if self.sparsity != 0:
            # Apply the pruning mask so pruned weights stay at zero.
            self.weight.data = self.weight.data * self.weight_mask.data

        if self.quantize:
            # Quantize the input activations and the weights before the convolution.
            input_q = self.quantize_input_fw(input, self.weight_bits)
            weight_qparams = calculate_qparams(self.weight,
                                               num_bits=self.weight_bits,
                                               flatten_dims=(1, -1),
                                               reduce_dim=None)
            self.qweight = quantize(self.weight, qparams=weight_qparams)
            output = F.conv2d(input_q, self.qweight, None, self.stride,
                              self.padding)
            # Quantize the gradient in the backward pass as well.
            output = quantize_grad(output,
                                   num_bits=self.weight_bits,
                                   flatten_dims=(1, -1))
        else:
            output = F.conv2d(input, self.weight, None, self.stride,
                              self.padding)
        # The bias (stored in self.b) is added separately, broadcast over
        # the batch and spatial dimensions.
        if self.bias:
            output += self.b.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        return output
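A minimal usage sketch for the layer above, assuming the forward() belongs to a Conv2d subclass here called QuantConv2d; the class name, the constructor arguments, and the origin of the quantize/calculate_qparams/quantize_grad helpers are assumptions, since the example only shows the method body:

    import torch

    # Hypothetical module exposing the forward() above (name and signature assumed).
    layer = QuantConv2d(in_channels=3, out_channels=16, kernel_size=3,
                        stride=1, padding=1, weight_bits=8,
                        quantize=True, sparsity=0.5)

    x = torch.randn(8, 3, 32, 32)   # batch of 8 RGB 32x32 inputs
    y = layer(x)                    # masked, quantized convolution
    print(y.shape)                  # torch.Size([8, 16, 32, 32])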
Example #2
    def forward(self, input):
        # Keep the learned shifts inside the allowed range.
        self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
        shift_rounded = ste.round(self.shift, self.rounding)
        if self.threshold is None:
            sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
        else:
            sign_rounded_signed = ste.sign(
                ste.myround(self.sign, self.threshold))
        # Power-of-two weights: sign * 2**shift, with a straight-through gradient.
        weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)
        if self.quant_bits > 0:
            input_fixed_point = ste.round_fixed_point(
                input, quant_bits=self.quant_bits)
        else:
            # No activation quantization requested; pass the input through unquantized.
            input_fixed_point = input
        if self.bias is not None:
            bias_fixed_point = ste.round_fixed_point(self.bias)
        else:
            bias_fixed_point = None

        if self.padding_mode == 'circular':
            # F.pad pads (left, right, top, bottom): width values first, then height.
            expanded_padding = ((self.padding[1] + 1) // 2,
                                self.padding[1] // 2,
                                (self.padding[0] + 1) // 2,
                                self.padding[0] // 2)

            input_padded = F.pad(input_fixed_point,
                                 expanded_padding,
                                 mode='circular')
            padding = _pair(0)
        else:
            input_padded = input_fixed_point
            padding = self.padding

        if self.use_kernel:
            # Use the custom shift-convolution kernel instead of F.conv2d.
            output = Conv2dShiftFunction.apply(
                input_padded, self.shift, self.sign, bias_fixed_point,
                self.conc_weight, self.stride, padding, self.dilation,
                self.groups, self.use_kernel, self.use_cuda, self.rounding,
                self.shift_range)
        else:
            output = torch.nn.functional.conv2d(input_padded, weight_ps,
                                                bias_fixed_point, self.stride,
                                                padding, self.dilation,
                                                self.groups)

        # Quantize the gradients in the backward pass.
        if self.quant_bits > 0:
            output = quantize_grad(output,
                                   num_bits=self.quant_bits,
                                   flatten_dims=(1, -1))

        return output
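A minimal usage sketch, assuming the forward() above belongs to a DeepShift-style Conv2dShift module whose effective weights are sign * 2**shift; the constructor signature below is an assumption inferred from the attributes the method reads:

    import torch

    # Hypothetical shift-convolution layer (name and arguments assumed).
    conv = Conv2dShift(in_channels=3, out_channels=16, kernel_size=3,
                       stride=1, padding=1, use_kernel=False,
                       rounding='deterministic', quant_bits=16)

    x = torch.randn(4, 3, 32, 32)
    y = conv(x)        # fixed-point input convolved with power-of-two weights
    print(y.shape)     # torch.Size([4, 16, 32, 32])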