Example #1
import torch
import torch.nn as nn

import gap_quantization.layers
# Concat, first_element, roundnorm_reg and module_classes are helpers from the
# gap_quantization package; their exact import paths are omitted in this excerpt.


# Backward hook: a Concat layer emits its per-branch norm values as fake
# gradients; intermediate modules pass the value through, and the Conv2d that
# produced the branch absorbs the shift into its own norm and bias.
def shift_concat_input(module, grad_input, grad_output):
    if isinstance(module, nn.Conv2d) and first_element(grad_output[0]):
        shift = first_element(grad_output[0])
        print('shifted module {} by {} bits'.format(module.__class__.__name__, shift))
        module.norm += shift
        module.bias = nn.Parameter(roundnorm_reg(module.bias, shift))
        grad_input = tuple(torch.zeros_like(tensor) for tensor in grad_input)
        module.out_frac_bits[0] -= shift
        module.b_frac_bits[0] -= shift
    elif grad_output[0].sum() != 0:
        tmp = []
        for tensor in grad_input:
            if tensor is not None:
                tmp.append(torch.empty_like(tensor).fill_(first_element(grad_output[0])))
            else:
                tmp.append(None)
        grad_input = tuple(tmp)
        print('propagated through {}'.format(module.__class__.__name__))
    elif isinstance(module, Concat):
        print(module.norm)
        tmp = []
        for curr_norm, tensor in zip(module.norm, grad_input):
            tmp.append(torch.empty_like(tensor).fill_(curr_norm))
        module.norm = nn.Parameter(torch.Tensor([0 for _ in module.norm]))
        grad_input = tuple(tmp)
    elif module.__class__ in module_classes(nn) \
            and not isinstance(module, (nn.Sequential, nn.ModuleList)) \
            or module.__class__ in module_classes(gap_quantization.layers):
        tmp = []
        for tensor in grad_input:
            if tensor is not None:
                tmp.append(torch.zeros_like(tensor))
            else:
                tmp.append(None)  # for convolutions without biases
        grad_input = tuple(tmp)  # replace the incoming gradients with zeros so nothing propagates further
    return grad_input
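
The hook above is only a fragment; how it is attached and driven is not shown. A minimal usage sketch, under the assumption that it is registered on every submodule and triggered by a dummy backward pass (`model` and the input shape are placeholders, not part of the original code):

import torch

# `model` is assumed to be a network already prepared by gap_quantization (placeholder)
handles = [m.register_backward_hook(shift_concat_input) for m in model.modules()]
out = model(torch.zeros(1, 3, 224, 224))   # dummy forward pass, shape is a placeholder
out.backward(torch.zeros_like(out))        # runs the hooks during backpropagation
for handle in handles:
    handle.remove()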
Example #2
def forward(self, inputs):
    # standard floating-point average pooling
    inputs = F.avg_pool2d(inputs, self.kernel_size, self.stride, self.padding,
                          self.ceil_mode, self.count_include_pad)
    # undo the averaging to recover the integer window sum (the +0.1 guards
    # against float round-off such as 3.9999 flooring down to 3)
    inputs = torch.floor_(inputs * self.kernel_size * self.kernel_size + 0.1)
    # reciprocal of the window area, scaled by 2**16
    pool_factor = math.pow(2, 16) // math.pow(self.kernel_size, 2)
    # saturate to the signed `bits`-bit range
    bound = math.pow(2.0, self.bits - 1)
    min_val = -bound
    max_val = bound - 1
    return torch.clamp(roundnorm_reg(inputs * pool_factor, self.bits), min_val, max_val)
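
roundnorm_reg is used throughout these examples but never defined in them. The way it is applied here (scale by pool_factor = 2**16 // kernel_size**2, then normalize by self.bits) reads as a rounding right-shift: assuming self.bits is 16, kernel_size = 2 gives pool_factor = 16384, so sum * 16384 / 2**16 = sum / 4, i.e. the average, computed without a division. A hypothetical sketch of the helper under that reading (the real implementation in gap_quantization may differ):

import torch

# Hypothetical helper, shown only to make the arithmetic above concrete.
def roundnorm_reg(tensor, n):
    # round(tensor / 2**n): add half of the divisor, then floor
    return torch.floor((tensor + 2 ** (n - 1)) / 2 ** n)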
Example #3
def forward(self, inputs):
    # number of input cells averaged into each output cell; this must be read
    # off the spatial shape before pooling changes it
    mult = inputs.shape[2] * inputs.shape[3] // self.output_size[0] // self.output_size[1]
    inputs = F.adaptive_avg_pool2d(inputs, self.output_size)
    # recover the integer window sum (+0.1 guards against float round-off)
    inputs = torch.floor_(inputs * mult + 0.1)
    # reciprocal of the window size, scaled by 2**16
    pool_factor = math.pow(2, 16) // mult
    # saturate to the signed `bits`-bit range
    bound = math.pow(2.0, self.bits - 1)
    min_val = -bound
    max_val = bound - 1
    return torch.clamp(roundnorm_reg(inputs * pool_factor, self.bits),
                       min_val, max_val)
Example #4
def forward(self, inputs):
    # split the weight tensor into one single-channel filter bank per input channel
    self.weights = nn.ParameterList(  # pylint: disable=attribute-defined-outside-init
        [nn.Parameter(self.weight.data[:, i, :, :].unsqueeze_(1)) for i in range(self.weight.shape[1])])
    out = None
    # convolve each input channel separately and accumulate the partial sums
    for i in range(inputs.shape[1]):
        conv_res = F.conv2d(inputs[:, i, :, :].unsqueeze_(1), self.weights[i], None, self.stride,
                            self.padding, self.dilation, self.groups)
        if out is None:
            out = conv_res
        else:
            out += conv_res
    # add the bias, shifted up by `norm` bits to match the accumulator scale
    out += (self.bias * math.pow(2, self.norm)).view(1, -1, 1, 1).expand_as(out)
    # shift the accumulator back down by `norm` bits and saturate to `bits` bits
    out = gap8_clip(roundnorm_reg(out, self.norm), self.bits)
    return out
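
gap8_clip is the other undefined helper in this listing. Judging from the explicit clamp used in Examples 2 and 3, it saturates a tensor to the signed `bits`-bit range; a hypothetical sketch under that assumption (the real helper lives in gap_quantization and may be implemented differently):

import torch

# Hypothetical sketch, mirroring the clamp pattern from Examples 2 and 3.
def gap8_clip(tensor, bits):
    bound = 2.0 ** (bits - 1)
    return torch.clamp(tensor, -bound, bound - 1)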
Example #5
def forward(self, inp1, inp2):
    inputs = [inp1, inp2]
    # bring both operands to a common fixed-point scale before adding
    for idx, _ in enumerate(inputs):
        inputs[idx] = roundnorm_reg(inputs[idx], self.norm[idx])
    return torch.stack(inputs, dim=0).sum(dim=0)
Example #6
def forward(self, inputs):
    # align every branch to a common fixed-point scale before concatenating
    for idx, _ in enumerate(inputs):
        inputs[idx] = roundnorm_reg(inputs[idx], self.norm[idx])
    return torch.cat(inputs, self.dim)
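
This is evidently the forward pass of the Concat module that the hook in Example 1 manipulates: each branch arrives with its own norm (extra fractional bits) and is shifted down to a common scale before torch.cat. A small numeric illustration with made-up values, using the rounding-shift reading of roundnorm_reg sketched after Example 2:

import torch

# Two branches carrying the real value 3.0 at different fixed-point scales.
a = torch.full((1, 1, 2, 2), 12.0)   # 3.0 with 2 extra fractional bits (3.0 * 2**2)
b = torch.full((1, 1, 2, 2), 3.0)    # 3.0 already at the target scale

# With norm = [2, 0], both branches end up at the same scale before the cat:
aligned = [roundnorm_reg(a, 2), roundnorm_reg(b, 0)]   # both now hold 3.0
print(torch.cat(aligned, dim=1).shape)                 # torch.Size([1, 2, 2, 2])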