# Shared imports assumed by the layer implementations below; project-specific
# helpers referenced later (e.g. RoundFn_LLSQ, SparseFCLayer, TensorType, nni,
# BetaCDFBatchShapingLoss, normest, fuse_linear_bn_weights) are expected to be
# provided elsewhere in the codebase.
import math

import numpy as np
import torch
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter


class Batch_norm(torch.nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, training=True):
        super(Batch_norm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.training = training
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, X):
        assert len(X.shape) in (2, 4)
        # dense layer batch norm
        if len(X.shape) == 2:
            mean = torch.mean(X, 0)
            variance = torch.mean((X - mean) ** 2, 0)
            if self.training:
                X_norm = (X - mean) / torch.sqrt(variance + self.eps)
                # detach so the running statistics do not keep the autograd graph alive
                self.running_mean = self.running_mean * self.momentum + mean.detach() * (1.0 - self.momentum)
                self.running_var = self.running_var * self.momentum + variance.detach() * (1.0 - self.momentum)
            else:
                X_norm = (X - self.running_mean) / torch.sqrt(self.running_var + self.eps)
            out = self.weight * X_norm + self.bias
        # conv layer batch norm
        else:
            B, C, H, W = X.shape
            mean = torch.mean(X, (0, 2, 3))
            variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, (0, 2, 3))
            if self.training:
                X_norm = (X - mean.reshape((1, C, 1, 1))) / torch.sqrt(
                    variance.reshape((1, C, 1, 1)) + self.eps)
                # detach so the running statistics do not keep the autograd graph alive
                self.running_mean = self.running_mean * self.momentum + mean.detach() * (1.0 - self.momentum)
                self.running_var = self.running_var * self.momentum + variance.detach() * (1.0 - self.momentum)
            else:
                X_norm = (X - self.running_mean.reshape((1, C, 1, 1))) / torch.sqrt(
                    self.running_var.reshape((1, C, 1, 1)) + self.eps)
            out = self.weight.reshape((1, C, 1, 1)) * X_norm + self.bias.reshape((1, C, 1, 1))
        return out
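# Usage sketch (added for illustration, not part of the original source): runs the
# Batch_norm layer above on a random 4-D activation and checks, in training mode,
# that it matches torch.nn.BatchNorm2d once the affine parameters are copied over.
# The tensor sizes and tolerance are arbitrary assumptions.
def _demo_batch_norm():
    torch.manual_seed(0)
    x = torch.randn(8, 3, 16, 16)
    bn = Batch_norm(3)
    ref = torch.nn.BatchNorm2d(3, eps=1e-5)
    with torch.no_grad():
        ref.weight.copy_(bn.weight)  # copy affine params so the outputs are comparable
        ref.bias.copy_(bn.bias)
    print(torch.allclose(bn(x), ref(x), atol=1e-4))  # expected: True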
class NormedConv2D(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1):
        super(NormedConv2D, self).__init__(in_channels, out_channels, kernel_size,
                                           stride, padding, dilation, groups, False)
        self.register_parameter('_normed_weight', None)
        self.register_forward_pre_hook(self._get_normed_weight)
        self._scale_weight = Parameter(torch.full((self.out_channels, ), 0.1))
        self._scale_bias = Parameter(torch.zeros(self.out_channels))
        self.register_parameter('fused_weight', None)
        self.register_parameter('fused_bias', None)
        self.register_forward_hook(self._get_fused_weights)

    @staticmethod
    def _get_normed_weight(self, *_):
        if self.training:
            weight = self.weight
            reshaped_size = (-1, ) + (1, ) * (len(weight.shape) - 1)
            weight_mean = weight.view(weight.size(0), -1).mean(1).view(*reshaped_size)
            weight = weight - weight_mean
            weight_std = weight.view(weight.size(0), -1).std(1).view(*reshaped_size)
            weight_std = torch.clamp(weight_std, 1e-3)
            self._normed_weight = Parameter(weight / weight_std.expand_as(weight))

    @staticmethod
    def _get_fused_weights(self, *_):
        if self.training:
            reshaped_size = (-1, ) + (1, ) * (len(self._normed_weight.shape) - 1)
            self.fused_weight = Parameter(
                self._scale_weight.view(*reshaped_size) * self._normed_weight)
            self.fused_bias = Parameter(self._scale_bias)

    def forward(self, x):
        if self.training:
            x = F.conv2d(x, self._normed_weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
            x = x * self._scale_weight.reshape(-1, 1, 1) + self._scale_bias.reshape(-1, 1, 1)
        else:
            x = F.conv2d(x, self.fused_weight, self.fused_bias, self.stride,
                         self.padding, self.dilation, self.groups)
        return x
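# Usage sketch (added for illustration): after one training-mode forward pass the
# hooks of NormedConv2D cache the fused weights, so a subsequent eval-mode pass
# should reproduce the same output. Input sizes here are arbitrary assumptions.
def _demo_normed_conv2d():
    conv = NormedConv2D(3, 8, 3, padding=1)
    x = torch.randn(2, 3, 16, 16)
    y_train = conv(x)   # pre/forward hooks build _normed_weight and the fused weights
    conv.eval()
    y_eval = conv(x)    # eval path uses the cached fused_weight / fused_bias
    print(torch.allclose(y_train, y_eval, atol=1e-5))  # expected: True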
class Scale(nn.Module):
    def __init__(self, num_features):
        super(Scale, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.ones(self.num_features))
        self.bias = Parameter(torch.zeros(self.num_features))

    def forward(self, x):
        reshaped_size = (-1, ) + (1, ) * (len(x.shape) - 2)
        return x * self.weight.reshape(*reshaped_size) + self.bias.reshape(*reshaped_size)
class WScaleLayer(nn.Module):
    """
    Applies equalized learning rate to the preceding layer.
    """

    def __init__(self, incoming):
        super(WScaleLayer, self).__init__()
        self.incoming = incoming
        self.scale = (torch.mean(self.incoming.weight.data ** 2)) ** 0.5
        self.incoming.weight.data.copy_(self.incoming.weight.data / self.scale)
        self.bias = None
        if self.incoming.bias is not None:
            self.bias = self.incoming.bias
            self.incoming.bias = None
        self.scale = Parameter(self.scale.reshape(-1))  # needed to move it to the gpu / FIX: Multi-GPU support

    def forward(self, x):
        x = self.scale * x
        if self.bias is not None:
            x += self.bias.view(1, self.bias.size()[0], 1, 1)
        return x

    def __repr__(self):
        param_str = '(incoming = %s)' % (
            self.incoming.__class__.__name__
        ) + f'(scale.is_cuda = {self.scale.is_cuda})'
        return self.__class__.__name__ + param_str
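# Usage sketch (added for illustration): wrap an ordinary Conv2d with WScaleLayer,
# which rescales the convolution weights in place and takes over the bias. The
# layer sizes below are arbitrary assumptions.
def _demo_wscale():
    conv = nn.Conv2d(3, 8, 3, padding=1, bias=True)
    ws = WScaleLayer(conv)
    x = torch.randn(2, 3, 16, 16)
    y = ws(conv(x))     # conv now runs bias-free; WScaleLayer re-applies scale and bias
    print(y.shape)      # torch.Size([2, 8, 16, 16])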
class BatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches', torch.tensor(0, dtype=torch.long))
        nn.init.uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _, C, _, _ = x.shape
        if self.training:
            mean = torch.mean(x, dim=(0, 2, 3))
            variance = torch.mean((x - mean.reshape((1, C, 1, 1))) ** 2, dim=(0, 2, 3))
            self._update_running_stats(mean, variance, self.momentum)
        else:
            mean, variance = self.running_mean, self.running_var
        mean, variance = mean.reshape((1, C, 1, 1)), variance.reshape((1, C, 1, 1))
        x = (x - mean) / ((variance + self.eps) ** 0.5)
        x = self.weight.reshape((1, C, 1, 1)) * x + self.bias.reshape((1, C, 1, 1))
        return x

    def _update_running_stats(self, mean, variance, momentum):
        # detach so the running statistics do not keep the autograd graph alive
        mean, variance = mean.detach(), variance.detach()
        if self.num_batches == 0:
            self.running_mean = mean
            self.running_var = variance
        else:
            self.running_mean = (1 - momentum) * self.running_mean + momentum * mean
            self.running_var = (1 - momentum) * self.running_var + momentum * variance
        self.num_batches += 1
class GraphConvolution(nn.Module):
    r"""
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(pt.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(pt.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input: pt.Tensor, adj: pt.Tensor):
        if input.dim() > 2:
            weight_size = self.weight.size()
            new_weight = self.weight.reshape(1, weight_size[0], -1).expand(
                input.size(0), weight_size[0], -1)
            support = pt.bmm(input, new_weight)
            output = pt.bmm(adj, support)
        else:
            support = pt.mm(input, self.weight)
            output = pt.mm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
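# Usage sketch (added for illustration): a single GraphConvolution applied to a toy
# 5-node graph. With an identity adjacency the layer reduces to a per-node linear map.
# Feature sizes are arbitrary assumptions.
def _demo_graph_convolution():
    gc = GraphConvolution(16, 8)
    feats = pt.randn(5, 16)
    adj = pt.eye(5)
    out = gc(feats, adj)
    print(out.shape)    # torch.Size([5, 8])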
class DiverseRegDCConv2d(nn.Module): def __init__(self, embedding_in, num, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, alpha=0.5, beta=0.5, lamda=0.1): super(DiverseRegDCConv2d, self).__init__() self.num = num self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups self.register_buffer('alpha', torch.tensor(alpha)) self.register_buffer('beta', torch.tensor(beta)) self.register_buffer('lamda', torch.tensor(lamda)) self.batch_shape_loss = BetaCDFBatchShapingLoss(self.alpha, self.beta) self.routing_fc = nn.Linear(embedding_in, num) self.weight = Parameter(torch.Tensor(out_channels * in_channels * kernel_size * kernel_size // groups, num)) if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self._initialize_weights() self.register_buffer('inputs_se', None) def _initialize_weights(self): for name, m in self.named_modules(): if isinstance(m, DiverseRegDCConv2d): c, h, w = m.in_channels, m.kernel_size, m.kernel_size nn.init.normal_(m.weight, 0, math.sqrt(1.0 / (c * h * w))) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) nn.init.constant_(m.running_mean, 1) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) nn.init.constant_(m.running_mean, 1) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) def forward(self, inputs): ''' if inputs.shape[-1] == 7: print('inputs_se:', inputs_se) batchsize, channel, height, width = inputs.shape weight = F.linear(inputs_se, self.weight) weight = weight.reshape(batchsize * self.out_channels, self.in_channels // self.groups, self.kernel_size, self.kernel_size) inputs = inputs.reshape(1, batchsize * channel, height, width) outputs = F.conv2d(inputs, weight, None, self.stride, self.padding, self.dilation, groups=self.groups * batchsize) height, width = outputs.shape[2:] outputs = outputs.reshape(batchsize, self.out_channels, height, width) if self.bias is not None: outputs = outputs + self.bias.reshape(1, -1, 1, 1) ''' batchsize, channel, height, width = inputs.shape inputs_se = self.inputs_se # we need a pre_forward_hook to get this # add batch shaped loss if self.training: batch_shaped_loss = self.batch_shape_loss(inputs_se) * self.lamda batch_shaped_loss.backward(retain_graph=True) #rint(self.routing_fc.weight.grad.sum()) weight = F.linear(inputs_se, self.weight) weight = weight.reshape(batchsize * self.out_channels, self.in_channels // self.groups, self.kernel_size, self.kernel_size) inputs = inputs.reshape(1, batchsize * channel, height, width) outputs = F.conv2d(inputs, weight, None, self.stride, self.padding, self.dilation, groups=self.groups * batchsize) height, width = outputs.shape[2:] outputs = outputs.reshape(batchsize, self.out_channels, height, width) if self.bias is not None: outputs = outputs + self.bias.reshape(1, -1, 1, 1) return outputs
class Quantization(nn.Module): def __init__(self, network_controller: NetworkQuantizationController, is_signed: bool, alpha: float = 0.9, weights_values=None, efficient=True): """ HMQ Block :param network_controller: The network controller :param is_signed: is this tensor signed :param alpha: the thresholds I.I.R value :param weights_values: In the case of weights quantized this is the tensors values :param efficient: Boolean flag stating to use the memory efficient """ super(Quantization, self).__init__() self.weights_values = weights_values if weights_values is None: self.tensor_type = TensorType.ACTIVATION self.tensor_size = None else: self.tensor_type = TensorType.COEFFICIENT self.tensor_size = np.prod(weights_values.shape) self.network_controller = network_controller self.alpha = alpha self.is_signed_tensor = torch.Tensor([float(is_signed)]).cuda() if efficient: self.base_q = EfficientBaseQuantization() else: self.base_q = BaseQuantization() self.gumbel_softmax = GumbelSoftmax(ste=network_controller.ste) self.bits_vector = None self.mv_shifts = None self.base_thresholds = None self.nb_shifts_points_div = None self.search_matrix = None def init_quantization_coefficients(self): """ This function initlized the HMQ parameters :return: None """ init_threshold = 0 n_bits_list, thresholds_shifts = self.network_controller.quantization_config.get_thresholds_bitwidth_lists(self) if self.is_coefficient(): init_threshold = torch.pow(2.0, self.weights_values.abs().max().log2().ceil() + 1).item() if self.is_activation(): n_bits_list = [8] self._init_quantization_params(n_bits_list, thresholds_shifts, init_threshold) self._init_search_matrix(self.network_controller.p, n_bits_list, len(thresholds_shifts)) def _init_quantization_params(self, bit_list, thresholds_shifts, init_thresholds): self.update_bits_list(bit_list) self.mv_shifts = Parameter(torch.Tensor(thresholds_shifts), requires_grad=False) self.thresholds_shifts_points_div = Parameter(torch.pow(2.0, self.mv_shifts), requires_grad=False) self.base_thresholds = Parameter(torch.Tensor(1), requires_grad=False) init.constant_(self.base_thresholds, init_thresholds) def _init_search_matrix(self, p, n_bits_list, n_thresholds_options): n_channels = 1 sm = -np.random.rand(n_channels, len(n_bits_list), n_thresholds_options, 1) n = np.prod(sm.shape) sm[:, 0, 0, 0] = np.log(p * n / (1 - p)) # for single channels self.search_matrix = Parameter(torch.Tensor(sm)) def _get_quantization_probability_matrix(self, batch_size=1, noise_disable=False): return self.gumbel_softmax(self.search_matrix, self.network_controller.temperature, batch_size=batch_size, noise_disable=noise_disable) def _get_bits_probability(self, batch_size=1, noise_disable=False): p = self._get_quantization_probability_matrix(batch_size=batch_size, noise_disable=noise_disable) return p.sum(dim=4).sum(dim=3).sum(dim=1) def _update_iir(self, x): # update scale using statistics if self.is_activation(): if self.tensor_size is None: self.tensor_size = np.prod(x.shape[1:]) # Remove batch axis max_value = x.abs().max() self.base_thresholds.data.add_(self.alpha * (max_value - self.base_thresholds)) def _calculate_expected_delta(self, p, max_scale): max_scales = max_scale / (self.thresholds_shifts_points_div.reshape(1, -1)) max_scales = max_scales.reshape(1, 1, 1, -1, 1) nb_shifts = self.nb_shifts_points_div.reshape(1, 1, -1, 1, 1) * torch.pow(2.0, -self.is_signed_tensor) delta = (max_scales / nb_shifts) * p return delta.sum(dim=-1).sum(dim=-1).sum(dim=-1).sum(dim=-1) def 
_calculate_expected_threshold(self, p, max_threshold): p_t = p.sum(dim=4).sum(dim=2).sum(dim=1) thresholds = max_threshold / (self.thresholds_shifts_points_div.reshape(1, -1)) return (p_t * thresholds).sum(dim=-1) def _calculate_expected_q_point(self, p, max_threshold, expected_delta, param_shape): t = self._calculate_expected_threshold(p, max_threshold=max_threshold).reshape(*param_shape) return t / expected_delta def _built_param_shape(self, x): random_size = x.shape[0] if self.is_activation() else x.shape[1] # select random if len(x.shape) == 4: param_shape = [random_size, -1, 1, 1] if self.is_activation() else [-1, random_size, 1, 1] elif len(x.shape) == 2: param_shape = [random_size, -1] if self.is_activation() else [-1, random_size] else: raise NotImplemented return random_size, param_shape def forward(self, x): """ The forward function of the HMQ module :param x: Input tensor x :return: A tensor after quantization """ if self.network_controller.statistics_update: self._update_iir(x) max_threshold = torch.pow(2.0, torch.ceil(torch.log2(self.base_thresholds.detach().abs()))).detach() # read scale if self.training and self.network_controller.temperature > 0: random_size, param_shape = self._built_param_shape(x) # axis according to tensor type (activation randomization is done over the batch axis, # coeff the randomization is done over the input channel axis) p = self._get_quantization_probability_matrix(batch_size=random_size) delta = self._calculate_expected_delta(p, max_threshold).reshape(*param_shape) q_points = self._calculate_expected_q_point(p, max_threshold, delta, param_shape).reshape(*param_shape) return self.base_q(x, delta, q_points, self.is_signed_tensor) else: # negative temperature/ infernce p = self._get_quantization_probability_matrix(batch_size=1, noise_disable=True).squeeze(dim=0) bits_index = torch.argmax(self._get_bits_probability(batch_size=1, noise_disable=True).squeeze(dim=0)) max_index = torch.argmax(p[:, bits_index, :, 0], dim=-1) q_points = self.nb_shifts_points_div[bits_index] * torch.pow(2.0, -self.is_signed_tensor) max_scales = (max_threshold / self.thresholds_shifts_points_div.reshape(1, -1)).detach() delta = torch.stack( [(max_scales[i, mv] / q_points) for i, mv in enumerate(max_index)]).flatten().detach() return self.base_q(x, delta, q_points, self.is_signed_tensor) def get_bit_width(self): """ This function return the selected bit-width :return: the bit-width of the HMQ """ return self.bits_vector[torch.argmax(self._get_bits_probability(noise_disable=True).flatten())].item() def get_expected_bits(self): """ This function return the expected bit-width :return: the expected bit-width of the HMQ """ return (self.bits_vector * self._get_bits_probability(noise_disable=True)).sum() def get_float_size(self): """ This function return the size of floating point tensor in bits Note: we assume 32 bits for floating point values :return: the floating point tensor size """ return 32 * self.tensor_size def get_fxp_size(self): """ This function return the size of quantized tensor in bits :return: the quantized tensor size """ return self.get_bit_width() * self.tensor_size def is_activation(self): """ This function return the boolean stating if this module quantize activation :return: a boolean flag stating if this activation quantization """ return self.tensor_type == TensorType.ACTIVATION def is_coefficient(self): """ This function return the boolean stating if this module quantize coefficient :return: a boolean flag stating if this coefficient quantization """ return 
self.tensor_type == TensorType.COEFFICIENT def get_expected_tensor_size(self): """ This function return the expected size of quantized tensor in bits :return: the expected size of quantized tensor """ return torch.Tensor([self.tensor_size]).cuda() def update_bits_list(self, bits_list): """ This function update the HMQ bit-width list :param bits_list: A list of new bit-widths :return: None """ if self.bits_vector is None: self.bits_vector = Parameter(torch.Tensor(bits_list), requires_grad=False) self.nb_shifts_points_div = Parameter( torch.pow(2.0, self.bits_vector), # - int(q_node.is_signed) requires_grad=False) # move to init else: self.bits_vector.add_(torch.Tensor(bits_list).cuda() - self.bits_vector) self.nb_shifts_points_div.add_(torch.pow(2.0, self.bits_vector) - self.nb_shifts_points_div)
class _BatchNorm(Module): _version = 2 __constants__ = [ 'track_running_stats', 'momentum', 'eps', 'weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked' ] def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_BatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) else: self.register_parameter('running_mean', None) self.register_parameter('running_var', None) self.register_parameter('num_batches_tracked', None) self.reset_parameters() def reset_running_stats(self): if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def reset_parameters(self): self.reset_running_stats() if self.affine: init.uniform_(self.weight) init.zeros_(self.bias) def _check_input_dim(self, input): raise NotImplementedError def forward(self, x): self._check_input_dim(x) return_shape = x.shape y = x.transpose(0, 1) y = y.contiguous().view(x.size(1), -1) mu = y.mean(dim=1) sigma2 = y.var(dim=1, unbiased=False) if self.training is not True: y = y - self.running_mean.view(-1, 1) y = y / (self.running_var.view(-1, 1)**.5 + self.eps) else: if self.track_running_stats is True: with torch.no_grad(): self.running_mean = ( 1 - self.momentum) * self.running_mean + self.momentum * mu self.running_var = ( 1 - self.momentum ) * self.running_var + self.momentum * sigma2 y = y - mu.view(-1, 1) y = y / (sigma2.view(-1, 1)**.5 + self.eps) y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1) self.mu = mu self.var = sigma2 self.input = x.transpose(0, 1).reshape(x.size(1), -1) self.return_shape = return_shape N, C, H, W = return_shape return y.reshape(C, N, H, W).transpose(0, 1) def _grad_cam(self, grad_output, requires_activation=False): ''' In grad_cam, the only thing we need is the gradient w.r.t. x, and we don't need to calculate the gradient of mu, var, gamma, and beta. ''' X, mu, var, gamma = self.input, self.mu.reshape( -1, 1), self.var.reshape(-1, 1), self.weight.reshape(-1, 1) N, C, H, W = self.return_shape n = N * W * H grad_output = grad_output.transpose(0, 1).reshape(C, -1) if self.training is not True: dX = grad_output * self.weight.reshape( -1, 1) / (self.running_var.reshape(-1, 1) + 1e-8) else: X_mu = X - mu std_inv = 1. / torch.sqrt(var + self.eps) dX_norm = grad_output * gamma dvar = torch.sum(dX_norm * X_mu, dim=1, keepdim=True) * -.5 * std_inv**3 dmu = torch.sum(dX_norm * -std_inv, dim=1, keepdim=True) + dvar * torch.mean( -2. * X_mu, dim=1, keepdim=True) dX = (dX_norm * std_inv) + (dvar * 2 * X_mu / n) + (dmu / n) return dX.reshape(C, N, H, W).transpose(0, 1), self.input.reshape( C, N, H, W).transpose(0, 1) def _simple_lrp(self, R, labels): return R def _epsilon_lrp(self, R, epsilon): ''' Since there is only one (or several equally strong) dominant activations, default to _simple_lrp ''' return self._simple_lrp(R) def _ww_lrp(self, R): ''' There are no weights to use. 
default to _flat_lrp(R) ''' return self._flat_lrp(R) def _flat_lrp(self, R): ''' distribute relevance for each output evenly to the output neurons' receptive fields. ''' return self._simple_lrp(R) def _alphabeta_lrp(self, R, labels): ''' Since there is only one (or several equally strong) dominant activations, default to _simple_lrp ''' return self._simple_lrp(R, labels) def _composite_lrp(self, R, labels): ''' Since there is only one (or several equally strong) dominant activations, default to _simple_lrp ''' return self._simple_lrp(R, labels) def extra_repr(self): return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 'track_running_stats={track_running_stats}'.format(**self.__dict__) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get('version', None) if (version is None or version < 2) and self.track_running_stats: # at version 2: added num_batches_tracked buffer # this should have a default value of 0 num_batches_tracked_key = prefix + 'num_batches_tracked' if num_batches_tracked_key not in state_dict: state_dict[num_batches_tracked_key] = torch.tensor( 0, dtype=torch.long) super(_BatchNorm, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class DenseFCLayer(torch.nn.Module): def __init__(self, n_inputs=None, n_outputs=None, weights: torch.Tensor = None, use_biases=True, activation=None): super(DenseFCLayer, self).__init__() if n_inputs is not None and n_outputs is not None: self.n_inputs = n_inputs self.n_outputs = n_outputs self._activation = activation self._initial_weights = None self._weights = Parameter(torch.Tensor(n_inputs, n_outputs)) self._init_weights() self._mask = torch.ones_like(self._weights) self._initial_weights = self._weights.clone() self.use_biases = use_biases if self.use_biases: self._biases = Parameter(torch.Tensor(n_outputs)) self._init_biases() elif weights is not None: self.n_inputs = weights.size(0) self.n_outputs = weights.size(1) self._activation = activation self._initial_weights = weights self._weights = Parameter(weights) self._mask = torch.ones_like(self._weights) self._biases = Parameter(torch.Tensor(self.n_outputs)) self._init_biases() else: raise ValueError( "DenseFClayer class accepts either n_inputs/n_outputs or weights" ) def _init_weights(self): # Note the difference between init functions # torch.nn.init.xavier_normal_(self._weights) # torch.nn.init.xavier_uniform_(self._weights) # torch.nn.init.kaiming_normal_(self._weights) torch.nn.init.kaiming_uniform_(self._weights) def _init_biases(self): torch.nn.init.zeros_(self._biases) def prune_by_threshold(self, thr): self._mask *= (torch.abs(self._weights) >= thr).float() def prune_by_rank(self, rank): weights_val = self._weights[self._mask == 1] sorted_abs_weights = torch.sort(torch.abs(weights_val))[0] thr = sorted_abs_weights[rank] self.prune_by_threshold(thr) def prune_by_pct(self, pct): prune_idx = int(self.n_weights * pct) self.prune_by_rank(prune_idx) def prune_by_pct_taylor(self, pct): prune_idx = int(self.n_weights * pct) # by abs val wg = torch.abs(self._weights[self._mask == 1] * self._weights.grad[self._mask == 1]) sorted_wg = torch.sort(wg)[0] thr = sorted_wg[prune_idx] print(thr) self._mask *= (torch.abs(self._weights * self._weights.grad) > thr).float() # by val # wg = self._weights[self._mask == 1] * self._weights.grad[self._mask == 1] # sorted_wg = torch.sort(wg)[0] # thr = sorted_wg[prune_idx] # self._mask *= (self._weights * self._weights.grad >= thr).float() def random_prune_by_pct(self, pct): prune_idx = int(self.n_weights * pct) rand = torch.rand(size=self._mask.size(), device=self._mask.device) rand_val = rand[self._mask == 1] sorted_abs_rand = torch.sort(rand_val)[0] thr = sorted_abs_rand[prune_idx] self._mask *= (rand >= thr).float() def reinitialize(self): self._weights = Parameter(self._initial_weights) self._init_biases() # biases are reinitialized def to_sparse(self) -> SparseFCLayer: return SparseFCLayer((self._weights * self._mask).t().to_sparse(), self._biases.reshape((-1, 1)), self._activation) @classmethod def from_sparse(cls, s_layer: SparseFCLayer): return cls(weights=s_layer.weights.t().to_dense(), activation=s_layer.activation) def to_device(self, device: torch.device): self._initial_weights = self._initial_weights.to(device) self._mask = self._mask.to(device) def forward(self, inputs: torch.Tensor, use_mask=True): masked_weights = self._weights if use_mask: masked_weights = self._weights * self._mask if self.use_biases: ret = torch.addmm(self._biases, inputs, masked_weights) else: ret = torch.mm(inputs, masked_weights) return ret if self._activation is None else self._activation(ret) @property def mask(self): return self._mask @property def weights(self): return self._weights @property def 
activation(self): return self._activation @property def n_weights(self): return torch.nonzero(self._mask).size(0) @property def biases(self): if self.use_biases: return self._biases else: return None def __str__(self): return "DenseFClayer with size {} and activation {}".format( (self.n_inputs, self.n_outputs), self._activation)
class QuantConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, bit=32,
                 extern_init=False, init_model=nn.Sequential()):
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                          stride, padding, dilation, groups, bias)
        self.bit = bit
        self.pwr_coef = 2 ** (bit - 1)
        self.Round_w = RoundFn_LLSQ.apply
        self.Round_b = RoundFn_Bias.apply
        self.bias_flag = bias
        if bit < 0:
            self.alpha_w = None
        else:
            self.alpha_w = Parameter(torch.rand(out_channels))
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
        if extern_init:
            param = list(init_model.parameters())
            self.weight = Parameter(param[0])
            if bias:
                self.bias = Parameter(param[1])
        if bit < 0:
            self.init_state = 0
        else:
            self.register_buffer('init_state', torch.zeros(1))

    def forward(self, x):
        if self.bit == 32:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        w_reshape = self.weight.reshape([self.weight.shape[0], -1]).transpose(0, 1)
        if self.training and self.init_state == 0:
            # initialize the per-channel scale from the current weights on the first pass
            self.alpha_w.data.copy_(
                w_reshape.detach().abs().max(dim=0)[0] / self.pwr_coef)
            self.init_state.fill_(1)
        alpha_w = self.alpha_w
        wq = self.Round_w(w_reshape, alpha_w, self.pwr_coef, self.bit)
        w_q = wq.transpose(0, 1).reshape(self.weight.shape)
        if self.bias_flag:
            LLSQ_b = self.Round_b(self.bias, alpha_w, self.pwr_coef, self.bit)
        else:
            LLSQ_b = self.bias
        return F.conv2d(x, w_q, LLSQ_b, self.stride, self.padding,
                        self.dilation, self.groups)

    def extra_repr(self):
        s_prefix = super(QuantConv2d, self).extra_repr()
        if self.alpha_w is None:
            return '{}, fake'.format(s_prefix)
        return '{}'.format(s_prefix)
class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule): r""" A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached with FakeQuantize modules for weight, used in quantization aware training. We combined the interface of :class:`torch.nn.Linear` and :class:torch.nn.BatchNorm1d`. Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized to default. Attributes: freeze_bn: weight_fake_quant: fake quant module for weight """ def __init__( self, # Linear args in_features, out_features, bias=True, # BatchNorm1d args # num_features: out_features eps=1e-05, momentum=0.1, # affine: True # track_running_stats: True # Args for this module freeze_bn=False, qconfig=None): nn.modules.linear.Linear.__init__(self, in_features, out_features, bias) assert qconfig, 'qconfig must be provded for QAT module' self.qconfig = qconfig self.freeze_bn = freeze_bn if self.training else True self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) self.weight_fake_quant = self.qconfig.weight() if bias: self.bias = Parameter(torch.empty(out_features)) else: self.register_parameter('bias', None) self.reset_bn_parameters() # this needs to be called after reset_bn_parameters, # as they modify the same state if self.training: if freeze_bn: self.freeze_bn_stats() else: self.update_bn_stats() else: self.freeze_bn_stats() def reset_running_stats(self): self.bn.reset_running_stats() def reset_bn_parameters(self): self.bn.reset_running_stats() init.uniform_(self.bn.weight) init.zeros_(self.bn.bias) def reset_parameters(self): super(LinearBn1d, self).reset_parameters() def update_bn_stats(self): self.freeze_bn = False self.bn.training = True return self def freeze_bn_stats(self): self.freeze_bn = True self.bn.training = False return self def forward(self, input): assert self.bn.running_var is not None # Scale the linear weights by BN's running statistics to reduce # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18 # for motivation. # # Instead of # # x1 = F.linear(x0, fq(w), b) # x2 = self.bn(x1) # # We have # # # scale the weight by previous batch's running statistics # scale_factor = bn.w / bn.running_std_from_prev_batch # # do the linear transformation without bias # x1_scaled = F.linear(x0, fq(w * scale_factor), 0) # # reverse the scaling and add original bias # x1_orig = x1_scaled / scale_factor + b # x2 = self.bn(x1_orig) running_std = torch.sqrt(self.bn.running_var + self.bn.eps) scale_factor = self.bn.weight / running_std weight_shape = [1] * len(self.weight.shape) weight_shape[0] = -1 bias_shape = [1] * len(self.weight.shape) bias_shape[1] = -1 scaled_weight = self.weight_fake_quant( self.weight * scale_factor.reshape(weight_shape)) if self.bias is not None: zero_bias = torch.zeros_like(self.bias) else: zero_bias = torch.zeros(self.out_features, device=scaled_weight.device) linear_out = F.linear(input, scaled_weight, zero_bias) linear_out_orig = linear_out / scale_factor.reshape(bias_shape) if self.bias is not None: linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape) bn_out = self.bn(linear_out_orig) return bn_out def train(self, mode=True): """ Batchnorm's training behavior is using the self.training flag. Prevent changing it if BN is frozen. This makes sure that calling `model.train()` on a model with a frozen BN will behave properly. 
""" self.training = mode if not self.freeze_bn: for module in self.children(): module.train(mode) return self @classmethod def from_float(cls, mod): r"""Create a qat module from a float module or qparams_dict Args: `mod' a float module, either produced by torch.ao.quantization utilities or directly from user """ assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \ '.from_float only works for ' + nni.LinearBn1d.__name__ assert hasattr( mod, 'qconfig'), 'Input float module must have qconfig defined' assert mod.qconfig, 'Input float module must have a valid config' qconfig = mod.qconfig linear, bn = mod[0], mod[1] qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None, bn.eps, bn.momentum, False, qconfig) qat_linearbn.weight = linear.weight qat_linearbn.bias = linear.bias qat_linearbn.bn.weight = bn.weight qat_linearbn.bn.bias = bn.bias qat_linearbn.bn.running_mean = bn.running_mean qat_linearbn.bn.running_var = bn.running_var qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked return qat_linearbn def to_float(self): linear = torch.nn.Linear(self.in_features, self.out_features) linear.weight, linear.bias = fuse_linear_bn_weights( self.weight, self.bias, self.bn.running_mean, self.bn.running_var, self.bn.eps, self.bn.weight, self.bn.bias) return linear
class DenseLinear(nn.Module): __constants__ = ['in_features', 'out_features'] def __init__(self, in_features, out_features, use_bias=True, use_mask=True, **kwargs): super(DenseLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) if use_bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters(**kwargs) # self._initial_weight = self.weight.data.clone() # self._initial_bias = self.bias.data.clone() if use_bias else None self.use_mask = use_mask self.mask = torch.ones_like(self.weight, dtype=torch.bool) def reset_parameters(self, **kwargs): if len(kwargs.keys()) == 0: # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear init.kaiming_uniform_(self.weight, a=math.sqrt(5)) else: init.kaiming_uniform_(self.weight, **kwargs) if self.bias is not None: # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def forward(self, inp: torch.Tensor): masked_weight = self.weight * self.mask if self.use_mask else self.weight return nn.functional.linear(inp, masked_weight, self.bias) def extra_repr(self): return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None) def prune_by_threshold(self, thr): self.mask *= (self.weight.abs() >= thr) def prune_by_rank(self, rank): if rank == 0: return weight_val = self.weight[self.mask == 1.] sorted_abs_weight = weight_val.abs().sort()[0] thr = sorted_abs_weight[rank] self.prune_by_threshold(thr) def prune_by_pct(self, pct): prune_idx = int(self.num_weight * pct) self.prune_by_rank(prune_idx) def retain_by_threshold(self, thr): self.mask *= (self.weight.abs() >= thr) def retain_by_rank(self, rank): weights_val = self.weight[self.mask == 1.] sorted_abs_weights = weights_val.abs().sort(descending=True)[0] thr = sorted_abs_weights[rank] self.retain_by_threshold(thr) def random_prune_by_pct(self, pct): prune_idx = int(self.num_weight * pct) rand = torch.rand(size=self.mask.size(), device=self.mask.device) rand_val = rand[self.mask == 1] sorted_abs_rand = rand_val.sort()[0] thr = sorted_abs_rand[prune_idx] self.mask *= (rand >= thr) # def reinitialize(self): # self.weight = Parameter(self._initial_weight) # if self._initial_bias is not None: # self.bias = Parameter(self._initial_bias) def to_sparse(self, transpose=False) -> SparseLinear: """ by chance, some entries with mask = 1 can have a 0 value. Thus, the to_sparse methods give a different size there's no efficient way to solve it yet """ sparse_bias = None if self.bias is None else self.bias.reshape((-1, 1)) sparse_linear = SparseLinear((self.weight * self.mask).to_sparse(), sparse_bias, self.mask) if transpose: sparse_linear.transpose = True return sparse_linear def move_data(self, device: torch.device): self.mask = self.mask.to(device) def to(self, *args, **kwargs): device = torch._C._nn._parse_to(*args, **kwargs)[0] if device is not None: self.move_data(device) return super(DenseLinear, self).to(*args, **kwargs) @property def num_weight(self) -> int: return self.mask.sum().item()
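# Usage sketch (added for illustration): magnitude pruning with DenseLinear. The mask
# removes the smallest-magnitude weights and the forward pass ignores them. Layer
# sizes and the pruning fraction are arbitrary assumptions.
def _demo_dense_linear_pruning():
    layer = DenseLinear(20, 10)
    print(layer.num_weight)           # 200 active weights before pruning
    layer.prune_by_pct(0.5)           # mask the 50% smallest-magnitude weights
    print(layer.num_weight)           # roughly 100 active weights remain
    out = layer(torch.randn(4, 20))
    print(out.shape)                  # torch.Size([4, 10])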
class CConv2d(nn.Module):
    def __init__(self, num, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(CConv2d, self).__init__()
        self.num = num
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.routing_fc = nn.Linear(in_channels, num)
        self.weight = Parameter(torch.Tensor(
            out_channels * in_channels * kernel_size * kernel_size // groups, num))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self._initialize_weights()

    def _initialize_weights(self):
        for name, m in self.named_modules():
            if isinstance(m, CConv2d):
                c, h, w = m.in_channels, m.kernel_size, m.kernel_size
                nn.init.normal_(m.weight, 0, math.sqrt(1.0 / (c * h * w)))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 1)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 1)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        x = inputs
        # global average pool, then per-sample routing coefficients in (0, 1)
        inputs_se = x.reshape(x.shape[0], x.shape[1], -1).mean(dim=-1, keepdim=False)
        inputs_se = torch.sigmoid(self.routing_fc(inputs_se))
        batchsize, channel, height, width = inputs.shape
        # mix the `num` weight templates per sample, then run one grouped convolution
        weight = F.linear(inputs_se, self.weight)
        weight = weight.reshape(batchsize * self.out_channels,
                                self.in_channels // self.groups,
                                self.kernel_size, self.kernel_size)
        inputs = inputs.reshape(1, batchsize * channel, height, width)
        outputs = F.conv2d(inputs, weight, None, self.stride, self.padding,
                           self.dilation, groups=self.groups * batchsize)
        height, width = outputs.shape[2:]
        outputs = outputs.reshape(batchsize, self.out_channels, height, width)
        if self.bias is not None:
            outputs = outputs + self.bias.reshape(1, -1, 1, 1)
        return outputs
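# Usage sketch (added for illustration): CConv2d mixes `num` weight templates per
# sample, with mixing coefficients produced from a global-average-pooled descriptor.
# Sizes below are arbitrary assumptions.
def _demo_cconv2d():
    conv = CConv2d(num=4, in_channels=3, out_channels=8, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 16, 16)
    print(conv(x).shape)    # torch.Size([2, 8, 16, 16])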
class _ConvBnNd(nn.modules.conv._ConvNd): _version = 2 def __init__( self, # ConvNd args in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, # BatchNormNd args # num_features: out_channels eps=1e-05, momentum=0.1, # affine: True # track_running_stats: True # Args for this module freeze_bn=False, qconfig=None): nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, False, padding_mode) assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.freeze_bn = freeze_bn if self.training else True self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True) self.weight_fake_quant = self.qconfig.weight() if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_bn_parameters() # this needs to be called after reset_bn_parameters, # as they modify the same state if self.training: if freeze_bn: self.freeze_bn_stats() else: self.update_bn_stats() else: self.freeze_bn_stats() def reset_running_stats(self): self.bn.reset_running_stats() def reset_bn_parameters(self): self.bn.reset_running_stats() init.uniform_(self.bn.weight) init.zeros_(self.bn.bias) # note: below is actully for conv, not BN if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def reset_parameters(self): super(_ConvBnNd, self).reset_parameters() def update_bn_stats(self): self.freeze_bn = False self.bn.training = True return self def freeze_bn_stats(self): self.freeze_bn = True self.bn.training = False return self def _forward(self, input): running_std = torch.sqrt(self.bn.running_var + self.bn.eps) scale_factor = self.bn.weight / running_std scaled_weight = self.weight_fake_quant( self.weight * scale_factor.reshape([-1, 1, 1, 1])) # this does not include the conv bias conv = self._conv_forward(input, scaled_weight) conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) if self.bias is not None: conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1]) conv = self.bn(conv_orig) return conv def extra_repr(self): # TODO(jerryzh): extend return super(_ConvBnNd, self).extra_repr() def forward(self, input): return self._forward(input) def train(self, mode=True): """ Batchnorm's training behavior is using the self.training flag. Prevent changing it if BN is frozen. This makes sure that calling `model.train()` on a model with a frozen BN will behave properly. 
""" self.training = mode if not self.freeze_bn: for module in self.children(): module.train(mode) return self # ===== Serialization version history ===== # # Version 1/None # self # |--- weight : Tensor # |--- bias : Tensor # |--- gamma : Tensor # |--- beta : Tensor # |--- running_mean : Tensor # |--- running_var : Tensor # |--- num_batches_tracked : Tensor # # Version 2 # self # |--- weight : Tensor # |--- bias : Tensor # |--- bn : Module # |--- weight : Tensor (moved from v1.self.gamma) # |--- bias : Tensor (moved from v1.self.beta) # |--- running_mean : Tensor (moved from v1.self.running_mean) # |--- running_var : Tensor (moved from v1.self.running_var) # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get('version', None) if version is None or version == 1: # BN related parameters and buffers were moved into the BN module for v2 v2_to_v1_names = { 'bn.weight': 'gamma', 'bn.bias': 'beta', 'bn.running_mean': 'running_mean', 'bn.running_var': 'running_var', 'bn.num_batches_tracked': 'num_batches_tracked', } for v2_name, v1_name in v2_to_v1_names.items(): if prefix + v1_name in state_dict: state_dict[prefix + v2_name] = state_dict[prefix + v1_name] state_dict.pop(prefix + v1_name) elif prefix + v2_name in state_dict: # there was a brief period where forward compatibility # for this module was broken (between # https://github.com/pytorch/pytorch/pull/38478 # and https://github.com/pytorch/pytorch/pull/38820) # and modules emitted the v2 state_dict format while # specifying that version == 1. This patches the forward # compatibility issue by allowing the v2 style entries to # be used. pass elif strict: missing_keys.append(prefix + v2_name) super(_ConvBnNd, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @classmethod def from_float(cls, mod): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.quantization utilities or directly from user """ assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ assert hasattr( mod, 'qconfig'), 'Input float module must have qconfig defined' assert mod.qconfig, 'Input float module must have a valid qconfig' qconfig = mod.qconfig conv, bn = mod[0], mod[1] qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode, bn.eps, bn.momentum, False, qconfig) qat_convbn.weight = conv.weight qat_convbn.bias = conv.bias qat_convbn.bn.weight = bn.weight qat_convbn.bn.bias = bn.bias qat_convbn.bn.running_mean = bn.running_mean qat_convbn.bn.running_var = bn.running_var qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked return qat_convbn
class _ConvBnNd(nn.modules.conv._ConvNd): def __init__(self, # ConvNd args in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, # BatchNormNd args # num_features: out_channels eps=1e-05, momentum=0.1, # affine: True # track_running_stats: True # Args for this module freeze_bn=False, qconfig=None): nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, False, padding_mode) assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.freeze_bn = freeze_bn if self.training else True self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True) self.activation_post_process = self.qconfig.activation() self.weight_fake_quant = self.qconfig.weight() if bias: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self.reset_bn_parameters() # this needs to be called after reset_bn_parameters, # as they modify the same state if self.training: if freeze_bn: self.freeze_bn_stats() else: self.update_bn_stats() else: self.freeze_bn_stats() def reset_running_stats(self): self.bn.reset_running_stats() def reset_bn_parameters(self): self.bn.reset_running_stats() init.uniform_(self.bn.weight) init.zeros_(self.bn.bias) # note: below is actully for conv, not BN if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) def reset_parameters(self): super(_ConvBnNd, self).reset_parameters() def update_bn_stats(self): self.freeze_bn = False self.bn.training = True return self def freeze_bn_stats(self): self.freeze_bn = True self.bn.training = False return self def _forward(self, input): running_std = torch.sqrt(self.bn.running_var + self.bn.eps) scale_factor = self.bn.weight / running_std scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape([-1, 1, 1, 1])) # this does not include the conv bias conv = self._conv_forward(input, scaled_weight) conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) if self.bias is not None: conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1]) conv = self.bn(conv_orig) return conv def extra_repr(self): # TODO(jerryzh): extend return super(_ConvBnNd, self).extra_repr() def forward(self, input): return self.activation_post_process(self._forward(input)) def train(self, mode=True): """ Batchnorm's training behavior is using the self.training flag. Prevent changing it if BN is frozen. This makes sure that calling `model.train()` on a model with a frozen BN will behave properly. """ self.training = mode if not self.freeze_bn: for module in self.children(): module.train(mode) return self @classmethod def from_float(cls, mod, qconfig=None): r"""Create a qat module from a float module or qparams_dict Args: `mod` a float module, either produced by torch.quantization utilities or directly from user """ assert type(mod) == cls._FLOAT_MODULE, 'qat.' 
+ cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ if not qconfig: assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' assert mod.qconfig, 'Input float module must have a valid qconfig' qconfig = mod.qconfig conv, bn = mod[0], mod[1] qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode, bn.eps, bn.momentum, False, qconfig) qat_convbn.weight = conv.weight qat_convbn.bias = conv.bias qat_convbn.bn.weight = bn.weight qat_convbn.bn.bias = bn.bias qat_convbn.bn.running_mean = bn.running_mean qat_convbn.bn.running_var = bn.running_var qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked return qat_convbn
class DCTConvolution(torch.nn.Conv2d):
    """
    Instantiate a learnable convolution operator - the trainable weights are the
    coefficients of the DCT decomposition of the convolution operator
    (i.e. this change is bijective for convex training procedures!)
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False,
                 mean=False, dilation=1):
        self.mean = mean
        padding = int(np.floor(kernel_size / 2))
        super().__init__(in_channels, out_channels, kernel_size, padding=padding,
                         bias=bias, dilation=dilation)
        self.initialize_DCT()
        torch.nn.init.orthogonal_(self.weight, gain=1)

    def initialize_DCT(self):
        num_basis_functions = self.kernel_size[0] * self.kernel_size[1]
        dct_basis = self.weight.new_zeros(num_basis_functions, 1, *self.kernel_size)
        b = 0  # enumerate basis functions
        for b1 in range(self.kernel_size[0]):
            for b2 in range(self.kernel_size[1]):
                for i in range(self.kernel_size[0]):
                    for j in range(self.kernel_size[1]):
                        dct_basis[b, 0, i, j] = (
                            np.cos(np.pi / self.kernel_size[0] * (i + 0.5) * b1) *
                            np.cos(np.pi / self.kernel_size[1] * (j + 0.5) * b2))
                b += 1
        if not self.mean:
            dct_basis = dct_basis[1:, 0:1, :, :]
            num_weights = (num_basis_functions - 1) * self.in_channels
            self.weight = Parameter(
                self.weight.new_zeros(self.out_channels, num_weights, 1, 1))
        else:
            num_weights = num_basis_functions * self.in_channels
            self.weight = Parameter(
                self.weight.new_zeros(self.out_channels, num_weights, 1, 1))
        self.register_buffer(
            'dct_basis',
            torch.cat([dct_basis] * self.in_channels, 0) / torch.norm(dct_basis))
        init.kaiming_uniform_(self.weight, a=np.sqrt(5))

    def forward(self, input, direction='op'):
        if direction == 'op':
            dct_response = F.conv2d(input, self.dct_basis, None, self.stride,
                                    self.padding, self.dilation,
                                    groups=self.in_channels)
            return F.conv2d(dct_response, self.weight, self.bias, self.stride,
                            padding=0, dilation=self.dilation, groups=self.groups)
        elif direction == 't':
            input_weighted = F.conv_transpose2d(input, self.weight, None, self.stride,
                                                output_padding=0, padding=0,
                                                dilation=self.dilation,
                                                groups=self.groups)
            return F.conv_transpose2d(input_weighted, self.dct_basis, None, self.stride,
                                      output_padding=0, padding=self.padding,
                                      dilation=self.dilation,
                                      groups=self.in_channels)
        else:
            raise ValueError('Invalid Direction')

    def return_filters(self):
        with torch.no_grad():
            num_basis_functions = self.dct_basis.shape[0] // self.in_channels
            color_weights = self.weight.reshape(self.out_channels, num_basis_functions,
                                                self.in_channels, 1, 1)
            return (self.dct_basis[0:num_basis_functions, :, :, :].unsqueeze(0) *
                    color_weights).sum(1)

    def normest(self):
        # return self.weight.norm(dim=[2, 3]).sum() / self.groups
        return normest(self, self.in_channels, verbose=False)
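# Usage sketch (added for illustration): DCTConvolution applies the fixed DCT basis
# depthwise and then a learned 1x1 mixing; return_filters() recovers the equivalent
# spatial filters. Sizes are arbitrary assumptions.
def _demo_dct_convolution():
    conv = DCTConvolution(3, 8, kernel_size=3)
    x = torch.randn(2, 3, 16, 16)
    print(conv(x, direction='op').shape)    # torch.Size([2, 8, 16, 16])
    print(conv.return_filters().shape)      # torch.Size([8, 3, 3, 3])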