import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import gap_quantization.layers

# first_element, roundnorm_reg, gap8_clip, module_classes and Concat are
# project-local helpers; their import paths are omitted here.


def shift_concat_input(module, grad_input, grad_output):
    """Backward hook that propagates fixed-point norm shifts back through the graph."""
    if isinstance(module, nn.Conv2d) and first_element(grad_output[0]):
        # A non-zero incoming "gradient" encodes how many bits this conv must shift.
        shift = first_element(grad_output[0])
        print('shifted module {} by {} bits'.format(module.__class__.__name__, shift))
        module.norm += shift
        module.bias = nn.Parameter(roundnorm_reg(module.bias, shift))
        grad_input = tuple(torch.zeros_like(tensor) for tensor in grad_input)
        module.out_frac_bits[0] -= shift
        module.b_frac_bits[0] -= shift
    elif grad_output[0].sum() != 0:
        # Pass the incoming shift through unchanged for modules that do not absorb it.
        tmp = []
        for tensor in grad_input:
            if tensor is not None:
                tmp.append(torch.empty_like(tensor).fill_(first_element(grad_output[0])))
            else:
                tmp.append(None)
        grad_input = tuple(tmp)
        print('propagated through {}'.format(module.__class__.__name__))
    elif isinstance(module, Concat):
        # Emit each input's norm as the shift to propagate and reset the concat norms.
        print(module.norm)
        tmp = []
        for curr_norm, tensor in zip(module.norm, grad_input):
            tmp.append(torch.empty_like(tensor).fill_(curr_norm))
        module.norm = nn.Parameter(torch.Tensor([0 for _ in module.norm]))
        grad_input = tuple(tmp)
    elif module.__class__ in module_classes(nn) \
            and not isinstance(module, (nn.Sequential, nn.ModuleList)) \
            or module.__class__ in module_classes(gap_quantization.layers):
        # No shift to forward: zero the outgoing "gradients" for ordinary layers.
        tmp = []
        for tensor in grad_input:
            if tensor is not None:
                tmp.append(torch.zeros_like(tensor))
            else:
                tmp.append(None)  # for convolutions without biases
        grad_input = tuple(tmp)
    return grad_input

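# Illustrative sketch (not part of the library): shift_concat_input follows PyTorch's
# (module, grad_input, grad_output) backward-hook signature, so it is meant to be
# attached to the leaf modules of a quantized network and driven by a dummy backward
# pass. The toy model and trivial hook below only demonstrate that registration
# mechanism; the real hook additionally expects attributes such as norm, out_frac_bits
# and b_frac_bits that only the quantized layers provide. register_full_backward_hook
# assumes PyTorch >= 1.8 (older code used register_backward_hook).
def _demo_backward_hook_registration():
    def print_grad_shapes(module, grad_input, grad_output):
        # Same signature as shift_concat_input above.
        print(module.__class__.__name__,
              [None if g is None else tuple(g.shape) for g in grad_input],
              [tuple(g.shape) for g in grad_output])

    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
    for m in model.modules():
        if not list(m.children()):  # leaf modules only
            m.register_full_backward_hook(print_grad_shapes)

    model(torch.randn(1, 3, 16, 16)).sum().backward()
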
def forward(self, inputs):
    inputs = F.avg_pool2d(inputs, self.kernel_size, self.stride, self.padding,
                          self.ceil_mode, self.count_include_pad)
    # Recover the integer window sum from the average; +0.1 guards against the
    # float result landing just below the true integer.
    inputs = torch.floor_(inputs * self.kernel_size * self.kernel_size + 0.1)
    # Fixed-point reciprocal of the window area, scaled by 2**16.
    pool_factor = math.pow(2, 16) // math.pow(self.kernel_size, 2)
    bound = math.pow(2.0, self.bits - 1)
    min_val = -bound
    max_val = bound - 1
    return torch.clamp(roundnorm_reg(inputs * pool_factor, self.bits), min_val, max_val)

def forward(self, inputs):
    # Number of input pixels averaged into each output pixel (window area);
    # computed before pooling so the integer sum can be recovered afterwards.
    mult = inputs.shape[2] * inputs.shape[3] // self.output_size[0] // self.output_size[1]
    inputs = F.adaptive_avg_pool2d(inputs, self.output_size)
    # Recover the integer window sum from the average (+0.1 guards against rounding down).
    inputs = torch.floor_(inputs * mult + 0.1)
    # Fixed-point reciprocal of the window area, scaled by 2**16.
    pool_factor = math.pow(2, 16) // mult
    bound = math.pow(2.0, self.bits - 1)
    min_val = -bound
    max_val = bound - 1
    return torch.clamp(roundnorm_reg(inputs * pool_factor, self.bits), min_val, max_val)

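# Worked check of the fixed-point averaging used by both pooling forwards above:
# the division by the window area is emulated by multiplying with floor(2**16 / area)
# and then applying a rounded right shift by 16 bits (assumed here to be what
# roundnorm_reg does when self.bits == 16).
def _demo_fixed_point_average():
    k = 3
    window_sum = 1234                                       # integer sum of a k*k window
    pool_factor = int(math.pow(2, 16) // math.pow(k, 2))    # 65536 // 9 == 7281
    fixed_point_avg = (window_sum * pool_factor + (1 << 15)) >> 16
    print(fixed_point_avg, window_sum / k ** 2)             # 137 vs 137.11...
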
def forward(self, inputs):
    # Split the kernel into one single-input-channel filter bank per input channel.
    self.weights = nn.ParameterList(  # pylint: disable=attribute-defined-outside-init
        [nn.Parameter(self.weight.data[:, i, :, :].unsqueeze_(1))
         for i in range(self.weight.shape[1])])
    out = None
    for i in range(inputs.shape[1]):
        # Convolve each input channel separately and accumulate the partial sums.
        conv_res = F.conv2d(inputs[:, i, :, :].unsqueeze_(1), self.weights[i], None,
                            self.stride, self.padding, self.dilation, self.groups)
        if out is None:
            out = conv_res
        else:
            out += conv_res
    # Add the bias in the accumulator's scale, then round and clip to the target bit width.
    out += (self.bias * math.pow(2, self.norm)).view(1, -1, 1, 1).expand_as(out)
    out = gap8_clip(roundnorm_reg(out, self.norm), self.bits)
    return out

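# Sketch of the identity the per-channel accumulation above relies on: a convolution
# over several input channels equals the sum of single-input-channel convolutions.
# In floating point the decomposition is exact; the quantized forward only adds
# rounding and clipping on top of it. Shapes below are arbitrary.
def _demo_per_channel_conv_equivalence():
    x = torch.randn(1, 4, 8, 8)
    w = torch.randn(6, 4, 3, 3)
    full = F.conv2d(x, w, padding=1)
    split = sum(F.conv2d(x[:, i:i + 1], w[:, i:i + 1], padding=1) for i in range(4))
    print(torch.allclose(full, split, atol=1e-5))           # True
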
def forward(self, inp1, inp2):
    inputs = [inp1, inp2]
    # Shift each operand by its own norm so both land on a common fixed-point scale.
    for idx, _ in enumerate(inputs):
        inputs[idx] = roundnorm_reg(inputs[idx], self.norm[idx])
    return torch.stack(inputs, dim=0).sum(dim=0)

def forward(self, inputs):
    # Align every input to the same fixed-point scale before concatenation.
    for idx, _ in enumerate(inputs):
        inputs[idx] = roundnorm_reg(inputs[idx], self.norm[idx])
    return torch.cat(inputs, self.dim)
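
# Minimal illustration of the norm alignment performed by the element-wise add and
# Concat forwards above: each operand is shifted by its own norm so that all inputs
# share the same fixed-point scale before they are summed or concatenated.
# rounded_shift is an assumed stand-in for roundnorm_reg; the Q-formats are made up.
def _demo_norm_alignment():
    def rounded_shift(x, n):
        return torch.floor((x + (1 << (n - 1))) / (1 << n)) if n > 0 else x

    a = torch.tensor([4096., 8192.])    # 1.0 and 2.0 stored with 12 fractional bits
    b = torch.tensor([1024., 3072.])    # 1.0 and 3.0 stored with 10 fractional bits
    aligned = [rounded_shift(a, 2), rounded_shift(b, 0)]    # norms [2, 0] -> 10 bits each
    print(torch.cat(aligned))           # tensor([1024., 2048., 1024., 3072.])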