Example #1
class Batch_norm(torch.nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, training=True):
        super(Batch_norm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.training = training

        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, X):
        assert len(X.shape) in (2, 4)

        # dense layer batch norm
        if len(X.shape) == 2:
            mean = torch.mean(X, 0)
            variance = torch.mean((X - mean)**2, 0)
            if self.training:
                X_norm = (X - mean) * 1.0 / torch.sqrt(variance + self.eps)
                self.running_mean = self.running_mean * self.momentum + mean * (
                    1.0 - self.momentum)
                self.running_var = self.running_var * self.momentum + variance * (
                    1.0 - self.momentum)
            else:
                X_norm = (X - self.running_mean
                          ) * 1.0 / torch.sqrt(self.running_var + self.eps)
            out = self.weight * X_norm + self.bias
        # conv layer batch norm
        else:
            B, C, H, W = X.shape

            mean = torch.mean(X, (0, 2, 3))
            variance = torch.mean((X - mean.reshape((1, C, 1, 1)))**2,
                                  (0, 2, 3))

            if self.training:
                X_norm = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(
                    variance.reshape((1, C, 1, 1)) + self.eps)
                self.running_mean = self.running_mean * self.momentum + mean * (
                    1.0 - self.momentum)
                self.running_var = self.running_var * self.momentum + variance * (
                    1.0 - self.momentum)
            else:
                X_norm = (X - self.running_mean.reshape(
                    (1, C, 1, 1))) * 1.0 / torch.sqrt(
                        self.running_var.reshape((1, C, 1, 1)) + self.eps)

            out = self.weight.reshape(
                (1, C, 1, 1)) * X_norm + self.bias.reshape((1, C, 1, 1))

        return out
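
A minimal smoke test for the module above (a sketch, assuming Batch_norm is defined in scope together with torch, torch.nn.Parameter, and torch.nn.init): in training mode its output should match torch.nn.BatchNorm2d once the affine parameters are copied over, although the running statistics will differ because this implementation applies momentum to the old running value rather than to the new batch statistic.

import torch

bn_custom = Batch_norm(num_features=8)
bn_ref = torch.nn.BatchNorm2d(8, eps=1e-5, momentum=0.1)

# Copy the affine parameters so both modules apply the same scale and shift.
with torch.no_grad():
    bn_ref.weight.copy_(bn_custom.weight)
    bn_ref.bias.copy_(bn_custom.bias)

x = torch.randn(4, 8, 16, 16)
out_custom = bn_custom(x)   # batch statistics (training=True by default)
out_ref = bn_ref(x)         # nn.BatchNorm2d in training mode
print(torch.allclose(out_custom, out_ref, atol=1e-5))  # expected: True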
Example #2
class NormedConv2D(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
        super(NormedConv2D,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, False)

        self.register_parameter('_normed_weight', None)
        self.register_forward_pre_hook(self._get_normed_weight)

        self._scale_weight = Parameter(torch.full((self.out_channels, ), 0.1))
        self._scale_bias = Parameter(torch.zeros(self.out_channels))

        self.register_parameter('fused_weight', None)
        self.register_parameter('fused_bias', None)
        self.register_forward_hook(self._get_fused_weights)

    @staticmethod
    def _get_normed_weight(self, *_):
        if self.training:
            weight = self.weight
            reshaped_size = (-1, ) + (1, ) * (len(weight.shape) - 1)
            weight_mean = weight.view(weight.size(0),
                                      -1).mean(1).view(*reshaped_size)
            weight = weight - weight_mean
            weight_std = weight.view(weight.size(0),
                                     -1).std(1).view(*reshaped_size)
            weight_std = torch.clamp(weight_std, 1e-3)
            self._normed_weight = Parameter(weight /
                                            weight_std.expand_as(weight))

    @staticmethod
    def _get_fused_weights(self, *_):
        if self.training:
            reshaped_size = (
                -1, ) + (1, ) * (len(self._normed_weight.shape) - 1)
            self.fused_weight = Parameter(
                self._scale_weight.view(*reshaped_size) * self._normed_weight)
            self.fused_bias = Parameter(self._scale_bias)

    def forward(self, x):
        if self.training:
            x = F.conv2d(x, self._normed_weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
            x = x * self._scale_weight.reshape(
                -1, 1, 1) + self._scale_bias.reshape(-1, 1, 1)
        else:
            x = F.conv2d(x, self.fused_weight, self.fused_bias, self.stride,
                         self.padding, self.dilation, self.groups)
        return x
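
A usage sketch (assuming the class above plus torch, torch.nn as nn, torch.nn.functional as F, and Parameter are in scope): in training mode the pre-hook standardizes the weights and the post-hook caches fused weights, so a subsequent eval-mode call should reproduce the same output.

import torch

conv = NormedConv2D(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)

conv.train()
y_train = conv(x)   # pre-hook builds _normed_weight, post-hook builds fused_weight / fused_bias

conv.eval()
y_eval = conv(x)    # uses the fused weights captured during the last training step
print(torch.allclose(y_train, y_eval, atol=1e-4))  # expected: True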
Example #3
class Scale(nn.Module):
    def __init__(self, num_features):
        super(Scale, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.ones(self.num_features))
        self.bias = Parameter(torch.zeros(self.num_features))

    def forward(self, x):
        reshaped_size = (-1, ) + (1, ) * (len(x.shape) - 2)
        return x * self.weight.reshape(*reshaped_size) + self.bias.reshape(
            *reshaped_size)
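
A small usage sketch (assuming Scale, torch, torch.nn as nn, and Parameter are in scope); the weight and bias broadcast over all trailing spatial dimensions.

import torch

scale = Scale(num_features=16)
x = torch.randn(4, 16, 8, 8)
y = scale(x)       # per-channel affine: weight and bias broadcast over (H, W)
print(y.shape)     # torch.Size([4, 16, 8, 8])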
Example #4
class WScaleLayer(nn.Module):
    """
    Applies equalized learning rate to the preceding layer.
    """
    def __init__(self, incoming):
        super(WScaleLayer, self).__init__()
        self.incoming = incoming
        self.scale = (torch.mean(self.incoming.weight.data**2))**0.5
        self.incoming.weight.data.copy_(self.incoming.weight.data / self.scale)
        self.bias = None
        if self.incoming.bias is not None:
            self.bias = self.incoming.bias
            self.incoming.bias = None
        self.scale = Parameter(self.scale.reshape(
            -1))  # needed to move it to the gpu / FIX: Multi-GPU support

    def forward(self, x):
        x = self.scale * x
        if self.bias is not None:
            x += self.bias.view(1, self.bias.size()[0], 1, 1)
        return x

    def __repr__(self):
        param_str = '(incoming = %s)' % (
            self.incoming.__class__.__name__
        ) + f'(scale.is_cuda = {self.scale.is_cuda})'
        return self.__class__.__name__ + param_str
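
A usage sketch (assuming WScaleLayer, torch, torch.nn as nn, and Parameter are in scope), wrapping a convolution so that the explicit scale and bias are applied after the unscaled convolution, in the style of progressive-GAN equalized learning rate.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
wscale = WScaleLayer(conv)          # rescales conv.weight in place and takes over its bias
block = nn.Sequential(conv, wscale)

x = torch.randn(2, 3, 16, 16)
print(block(x).shape)               # torch.Size([2, 8, 16, 16])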
Example #5
class BatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches', torch.tensor(0, dtype=torch.long))
        nn.init.uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        _, C, _, _ = x.shape
        if self.training:
            mean = torch.mean(x, dim=(0, 2, 3))
            variance = torch.mean((x - mean.reshape((1, C, 1, 1)))**2,
                                  dim=(0, 2, 3))
            self._update_running_stats(mean, variance, self.momentum)
        else:
            mean, variance = self.running_mean, self.running_var
        mean, variance = mean.reshape((1, C, 1, 1)), variance.reshape(
            (1, C, 1, 1))
        x = (x - mean) / ((variance + self.eps)**0.5)
        x = self.weight.reshape((1, C, 1, 1)) * x + self.bias.reshape(
            (1, C, 1, 1))
        return x

    def _update_running_stats(self, mean, variance, momentum):
        if self.num_batches == 0:
            self.running_mean = mean
            self.running_var = variance
        else:
            self.running_mean = (
                1 - momentum) * self.running_mean + momentum * mean
            self.running_var = (
                1 - momentum) * self.running_var + momentum * variance
        self.num_batches += 1
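
A usage sketch (assuming BatchNorm, torch, torch.nn as nn, and Parameter are in scope) illustrating the running-statistics behaviour: the first training batch copies the batch statistics directly, later batches use the exponential moving average, and eval mode normalizes with the stored statistics.

import torch

bn = BatchNorm(num_features=4)
x = torch.randn(8, 4, 5, 5)

bn.train()
_ = bn(x)                          # first batch: running stats copied directly (num_batches == 0)
_ = bn(torch.randn(8, 4, 5, 5))    # later batches: exponential moving average update

bn.eval()
y = bn(x)                          # normalizes with running_mean / running_var
print(y.shape, bn.num_batches.item())   # torch.Size([8, 4, 5, 5]) 2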
Example #6
class GraphConvolution(nn.Module):
	r"""
	Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
	"""
	def __init__(self, in_features, out_features, bias=True):
		super(GraphConvolution, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.weight = Parameter(pt.FloatTensor(in_features, out_features))
		if bias:
			self.bias = Parameter(pt.FloatTensor(out_features))
		else:
			self.register_parameter('bias', None)
		self.reset_parameters()

	def reset_parameters(self):
		stdv = 1. / math.sqrt(self.weight.size(1))
		self.weight.data.uniform_(-stdv, stdv)
		if self.bias is not None:
			self.bias.data.uniform_(-stdv, stdv)

	def forward(self, input:pt.Tensor, adj:pt.Tensor):
		if input.dim()>2:
			weight_size = self.weight.size()
			new_weight = self.weight.reshape(1,weight_size[0],-1).expand(input.size(0),weight_size[0],-1)
			support = pt.bmm(input, new_weight)
			output = pt.bmm(adj, support)
		else:
			support = pt.mm(input, self.weight)
			output = pt.mm(adj, support)
		if self.bias is not None:
			return output + self.bias
		else:
			return output

	def __repr__(self):
		return self.__class__.__name__ + ' (' \
			   + str(self.in_features) + ' -> ' \
			   + str(self.out_features) + ')'
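
A usage sketch for the layer above (assuming it is in scope along with torch as pt, math, and Parameter); the second call exercises the batched branch, which expands the weight across the batch dimension.

import torch as pt

gcn = GraphConvolution(in_features=16, out_features=8)

n_nodes = 5
features = pt.randn(n_nodes, 16)
adj = pt.eye(n_nodes)              # stands in for a normalized adjacency matrix
print(gcn(features, adj).shape)    # torch.Size([5, 8])

# Batched inputs: (batch, nodes, features) with one adjacency per sample.
batched = gcn(pt.randn(3, n_nodes, 16), pt.eye(n_nodes).expand(3, n_nodes, n_nodes))
print(batched.shape)               # torch.Size([3, 5, 8])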
Example #7
class DiverseRegDCConv2d(nn.Module):
    def __init__(self, embedding_in, num, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 alpha=0.5, beta=0.5, lamda=0.1):
        super(DiverseRegDCConv2d, self).__init__()
        self.num = num
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.register_buffer('alpha', torch.tensor(alpha))
        self.register_buffer('beta', torch.tensor(beta))
        self.register_buffer('lamda', torch.tensor(lamda))

        self.batch_shape_loss = BetaCDFBatchShapingLoss(self.alpha,
                                                        self.beta)

        self.routing_fc = nn.Linear(embedding_in, num)
        self.weight = Parameter(torch.Tensor(out_channels * in_channels * kernel_size * kernel_size // groups, num))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self._initialize_weights()
        self.register_buffer('inputs_se', None)
    def _initialize_weights(self):
        for name, m in self.named_modules():

            if isinstance(m, DiverseRegDCConv2d):
                c, h, w = m.in_channels, m.kernel_size, m.kernel_size
                nn.init.normal_(m.weight, 0, math.sqrt(1.0 / (c * h * w)))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 1)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 1)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    def forward(self, inputs):
        '''
        if inputs.shape[-1] == 7:
            print('inputs_se:', inputs_se)
        batchsize, channel, height, width = inputs.shape
        weight = F.linear(inputs_se, self.weight)
        weight = weight.reshape(batchsize * self.out_channels, self.in_channels // self.groups, self.kernel_size, self.kernel_size)
        inputs = inputs.reshape(1, batchsize * channel, height, width)
        outputs = F.conv2d(inputs, weight, None, self.stride, self.padding, self.dilation, groups=self.groups * batchsize)
        height, width = outputs.shape[2:]
        outputs = outputs.reshape(batchsize, self.out_channels, height, width)
        if self.bias is not None:
            outputs = outputs + self.bias.reshape(1, -1, 1, 1)
        '''

        batchsize, channel, height, width = inputs.shape

        inputs_se = self.inputs_se # we need a pre_forward_hook to get this

        # add batch shaped loss
        if self.training:
            batch_shaped_loss = self.batch_shape_loss(inputs_se) * self.lamda
            batch_shaped_loss.backward(retain_graph=True)
            # print(self.routing_fc.weight.grad.sum())

        weight = F.linear(inputs_se, self.weight)
        weight = weight.reshape(batchsize * self.out_channels, self.in_channels // self.groups, self.kernel_size,
                                self.kernel_size)

        inputs = inputs.reshape(1, batchsize * channel, height, width)
        outputs = F.conv2d(inputs, weight, None, self.stride, self.padding, self.dilation,
                           groups=self.groups * batchsize)
        height, width = outputs.shape[2:]
        outputs = outputs.reshape(batchsize, self.out_channels, height, width)
        if self.bias is not None:
            outputs = outputs + self.bias.reshape(1, -1, 1, 1)

        return outputs
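
The forward pass above reads self.inputs_se, which the comment says should be filled by a forward pre-hook. One possible hook is sketched below, mirroring the routing used in CConv2d (Example #14: global average pool, routing_fc, sigmoid); this is an assumption about how the surrounding training code wires things up, not part of the original module.

import torch

def _set_inputs_se(module, inputs):
    # Hypothetical pre-hook: derive the routing embedding from the conv input itself.
    # Only valid when embedding_in == in_channels, as in the commented construction below.
    x = inputs[0]
    se = x.reshape(x.shape[0], x.shape[1], -1).mean(dim=-1)   # (B, C) global average pool
    module.inputs_se = torch.sigmoid(module.routing_fc(se))   # (B, num) routing weights

# layer = DiverseRegDCConv2d(embedding_in=64, num=4, in_channels=64, out_channels=128,
#                            kernel_size=3, padding=1)   # also requires BetaCDFBatchShapingLoss
# layer.register_forward_pre_hook(_set_inputs_se)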
Example #8
class Quantization(nn.Module):
    def __init__(self, network_controller: NetworkQuantizationController, is_signed: bool,
                 alpha: float = 0.9, weights_values=None, efficient=True):
        """
        HMQ Block
        :param network_controller: The network controller
        :param is_signed: whether the quantized tensor is signed
        :param alpha: the IIR smoothing factor used to track the thresholds
        :param weights_values: when quantizing weights, the values of the tensor to be quantized
        :param efficient: whether to use the memory-efficient base quantization
        """
        super(Quantization, self).__init__()
        self.weights_values = weights_values
        if weights_values is None:
            self.tensor_type = TensorType.ACTIVATION
            self.tensor_size = None
        else:
            self.tensor_type = TensorType.COEFFICIENT
            self.tensor_size = np.prod(weights_values.shape)

        self.network_controller = network_controller
        self.alpha = alpha
        self.is_signed_tensor = torch.Tensor([float(is_signed)]).cuda()

        if efficient:
            self.base_q = EfficientBaseQuantization()
        else:
            self.base_q = BaseQuantization()
        self.gumbel_softmax = GumbelSoftmax(ste=network_controller.ste)

        self.bits_vector = None
        self.mv_shifts = None
        self.base_thresholds = None
        self.nb_shifts_points_div = None
        self.search_matrix = None

    def init_quantization_coefficients(self):
        """
        Initializes the HMQ quantization parameters
        :return: None
        """
        init_threshold = 0
        n_bits_list, thresholds_shifts = self.network_controller.quantization_config.get_thresholds_bitwidth_lists(self)
        if self.is_coefficient():
            init_threshold = torch.pow(2.0, self.weights_values.abs().max().log2().ceil() + 1).item()
        if self.is_activation():
            n_bits_list = [8]

        self._init_quantization_params(n_bits_list, thresholds_shifts, init_threshold)
        self._init_search_matrix(self.network_controller.p, n_bits_list, len(thresholds_shifts))

    def _init_quantization_params(self, bit_list, thresholds_shifts, init_thresholds):
        self.update_bits_list(bit_list)
        self.mv_shifts = Parameter(torch.Tensor(thresholds_shifts), requires_grad=False)
        self.thresholds_shifts_points_div = Parameter(torch.pow(2.0, self.mv_shifts), requires_grad=False)
        self.base_thresholds = Parameter(torch.Tensor(1), requires_grad=False)
        init.constant_(self.base_thresholds, init_thresholds)

    def _init_search_matrix(self, p, n_bits_list, n_thresholds_options):
        n_channels = 1
        sm = -np.random.rand(n_channels, len(n_bits_list), n_thresholds_options, 1)
        n = np.prod(sm.shape)
        sm[:, 0, 0, 0] = np.log(p * n / (1 - p))  # for single channels
        self.search_matrix = Parameter(torch.Tensor(sm))

    def _get_quantization_probability_matrix(self, batch_size=1, noise_disable=False):
        return self.gumbel_softmax(self.search_matrix, self.network_controller.temperature, batch_size=batch_size,
                                   noise_disable=noise_disable)

    def _get_bits_probability(self, batch_size=1, noise_disable=False):
        p = self._get_quantization_probability_matrix(batch_size=batch_size, noise_disable=noise_disable)
        return p.sum(dim=4).sum(dim=3).sum(dim=1)

    def _update_iir(self, x):  # update scale using statistics
        if self.is_activation():
            if self.tensor_size is None:
                self.tensor_size = np.prod(x.shape[1:])  # Remove batch axis
            max_value = x.abs().max()
            self.base_thresholds.data.add_(self.alpha * (max_value - self.base_thresholds))

    def _calculate_expected_delta(self, p, max_scale):
        max_scales = max_scale / (self.thresholds_shifts_points_div.reshape(1, -1))
        max_scales = max_scales.reshape(1, 1, 1, -1, 1)

        nb_shifts = self.nb_shifts_points_div.reshape(1, 1, -1, 1, 1) * torch.pow(2.0, -self.is_signed_tensor)
        delta = (max_scales / nb_shifts) * p
        return delta.sum(dim=-1).sum(dim=-1).sum(dim=-1).sum(dim=-1)

    def _calculate_expected_threshold(self, p, max_threshold):
        p_t = p.sum(dim=4).sum(dim=2).sum(dim=1)
        thresholds = max_threshold / (self.thresholds_shifts_points_div.reshape(1, -1))
        return (p_t * thresholds).sum(dim=-1)

    def _calculate_expected_q_point(self, p, max_threshold, expected_delta, param_shape):
        t = self._calculate_expected_threshold(p, max_threshold=max_threshold).reshape(*param_shape)
        return t / expected_delta

    def _built_param_shape(self, x):
        random_size = x.shape[0] if self.is_activation() else x.shape[1]  # select random
        if len(x.shape) == 4:
            param_shape = [random_size, -1, 1, 1] if self.is_activation() else [-1, random_size, 1, 1]
        elif len(x.shape) == 2:
            param_shape = [random_size, -1] if self.is_activation() else [-1, random_size]
        else:
            raise NotImplementedError
        return random_size, param_shape

    def forward(self, x):
        """
        The forward function of the HMQ module

        :param x: Input tensor x
        :return: A tensor after quantization
        """
        if self.network_controller.statistics_update:
            self._update_iir(x)
        max_threshold = torch.pow(2.0,
                                  torch.ceil(torch.log2(self.base_thresholds.detach().abs()))).detach()  # read scale
        if self.training and self.network_controller.temperature > 0:
            random_size, param_shape = self._built_param_shape(x)
            # axis according to tensor type (activation randomization is done over the batch axis,
            # coeff the randomization is done over the input channel axis)
            p = self._get_quantization_probability_matrix(batch_size=random_size)
            delta = self._calculate_expected_delta(p, max_threshold).reshape(*param_shape)
            q_points = self._calculate_expected_q_point(p, max_threshold, delta,
                                                        param_shape).reshape(*param_shape)
            return self.base_q(x, delta, q_points, self.is_signed_tensor)
        else:  # non-positive temperature / inference
            p = self._get_quantization_probability_matrix(batch_size=1, noise_disable=True).squeeze(dim=0)
            bits_index = torch.argmax(self._get_bits_probability(batch_size=1, noise_disable=True).squeeze(dim=0))
            max_index = torch.argmax(p[:, bits_index, :, 0], dim=-1)
            q_points = self.nb_shifts_points_div[bits_index] * torch.pow(2.0,
                                                                         -self.is_signed_tensor)
            max_scales = (max_threshold / self.thresholds_shifts_points_div.reshape(1, -1)).detach()
            delta = torch.stack(
                [(max_scales[i, mv] / q_points) for i, mv in enumerate(max_index)]).flatten().detach()
            return self.base_q(x, delta, q_points, self.is_signed_tensor)

    def get_bit_width(self):
        """
        Returns the selected bit-width
        :return: the bit-width of the HMQ
        """
        return self.bits_vector[torch.argmax(self._get_bits_probability(noise_disable=True).flatten())].item()

    def get_expected_bits(self):
        """
        Returns the expected bit-width
        :return: the expected bit-width of the HMQ
        """
        return (self.bits_vector * self._get_bits_probability(noise_disable=True)).sum()

    def get_float_size(self):
        """
        Returns the size of the floating-point tensor in bits
        Note: we assume 32 bits per floating-point value
        :return: the floating point tensor size
        """
        return 32 * self.tensor_size

    def get_fxp_size(self):
        """
        Returns the size of the quantized tensor in bits
        :return: the quantized tensor size
        """
        return self.get_bit_width() * self.tensor_size

    def is_activation(self):
        """
        Returns whether this module quantizes an activation tensor
        :return: a boolean flag that is True for activation quantization
        """
        return self.tensor_type == TensorType.ACTIVATION

    def is_coefficient(self):
        """
        Returns whether this module quantizes a coefficient (weight) tensor
        :return: a boolean flag that is True for coefficient quantization
        """
        return self.tensor_type == TensorType.COEFFICIENT

    def get_expected_tensor_size(self):
        """
         This function return the expected size of quantized tensor in bits
         :return: the expected size of quantized tensor
         """
        return torch.Tensor([self.tensor_size]).cuda()

    def update_bits_list(self, bits_list):
        """
        Updates the HMQ bit-width list
        :param bits_list: A list of new bit-widths
        :return: None
        """
        if self.bits_vector is None:
            self.bits_vector = Parameter(torch.Tensor(bits_list), requires_grad=False)
            self.nb_shifts_points_div = Parameter(
                torch.pow(2.0, self.bits_vector),  # - int(q_node.is_signed)
                requires_grad=False)  # move to init
        else:
            self.bits_vector.add_(torch.Tensor(bits_list).cuda() - self.bits_vector)
            self.nb_shifts_points_div.add_(torch.pow(2.0, self.bits_vector) - self.nb_shifts_points_div)
Example #9
class _BatchNorm(Module):
    _version = 2
    __constants__ = [
        'track_running_stats', 'momentum', 'eps', 'weight', 'bias',
        'running_mean', 'running_var', 'num_batches_tracked'
    ]

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.uniform_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, x):

        self._check_input_dim(x)
        return_shape = x.shape
        y = x.transpose(0, 1)
        y = y.contiguous().view(x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1, unbiased=False)
        if self.training is not True:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1) + self.eps)**.5
        else:
            if self.track_running_stats is True:
                with torch.no_grad():
                    self.running_mean = (
                        1 -
                        self.momentum) * self.running_mean + self.momentum * mu
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * sigma2

            y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1) + self.eps)**.5

        y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)

        self.mu = mu
        self.var = sigma2
        self.input = x.transpose(0, 1).reshape(x.size(1), -1)
        self.return_shape = return_shape

        N, C, H, W = return_shape
        return y.reshape(C, N, H, W).transpose(0, 1)

    def _grad_cam(self, grad_output, requires_activation=False):
        '''
        In grad_cam, the only thing we need is the gradient w.r.t. x, and we don't need to calculate the gradient of mu, var, gamma, and beta. 
        '''
        X, mu, var, gamma = self.input, self.mu.reshape(
            -1, 1), self.var.reshape(-1, 1), self.weight.reshape(-1, 1)
        N, C, H, W = self.return_shape
        n = N * W * H
        grad_output = grad_output.transpose(0, 1).reshape(C, -1)
        if self.training is not True:
            dX = grad_output * self.weight.reshape(
                -1, 1) / torch.sqrt(self.running_var.reshape(-1, 1) + self.eps)
        else:
            X_mu = X - mu
            std_inv = 1. / torch.sqrt(var + self.eps)

            dX_norm = grad_output * gamma
            dvar = torch.sum(dX_norm * X_mu, dim=1,
                             keepdim=True) * -.5 * std_inv**3
            dmu = torch.sum(dX_norm * -std_inv, dim=1,
                            keepdim=True) + dvar * torch.mean(
                                -2. * X_mu, dim=1, keepdim=True)

            dX = (dX_norm * std_inv) + (dvar * 2 * X_mu / n) + (dmu / n)
        return dX.reshape(C, N, H, W).transpose(0, 1), self.input.reshape(
            C, N, H, W).transpose(0, 1)

    def _simple_lrp(self, R, labels):
        return R

    def _epsilon_lrp(self, R, epsilon):
        '''
        Since there is only one (or several equally strong) dominant activations, default to _simple_lrp
        '''
        return self._simple_lrp(R, None)

    def _ww_lrp(self, R):
        '''
        There are no weights to use. default to _flat_lrp(R)
        '''
        return self._flat_lrp(R)

    def _flat_lrp(self, R):
        '''
        distribute relevance for each output evenly to the output neurons' receptive fields.
        '''
        return self._simple_lrp(R, None)

    def _alphabeta_lrp(self, R, labels):
        '''
        Since there is only one (or several equally strong) dominant activations, default to _simple_lrp
        '''
        return self._simple_lrp(R, labels)

    def _composite_lrp(self, R, labels):
        '''
        Since there is only one (or several equally strong) dominant activations, default to _simple_lrp
        '''
        return self._simple_lrp(R, labels)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(
                    0, dtype=torch.long)

        super(_BatchNorm,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)
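
Because _check_input_dim is abstract, the class above is used through a concrete subclass. A minimal sketch (assuming the snippet's imports torch, Module, Parameter, and init are in scope):

import torch

class MyBatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input, got {}D'.format(input.dim()))

bn = MyBatchNorm2d(6)
x = torch.randn(2, 6, 10, 10)
y = bn(x)                  # training mode: batch statistics, running stats updated
print(y.shape)             # torch.Size([2, 6, 10, 10])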
Example #10
class DenseFCLayer(torch.nn.Module):
    def __init__(self,
                 n_inputs=None,
                 n_outputs=None,
                 weights: torch.Tensor = None,
                 use_biases=True,
                 activation=None):
        super(DenseFCLayer, self).__init__()
        if n_inputs is not None and n_outputs is not None:
            self.n_inputs = n_inputs
            self.n_outputs = n_outputs
            self._activation = activation
            self._initial_weights = None

            self._weights = Parameter(torch.Tensor(n_inputs, n_outputs))
            self._init_weights()
            self._mask = torch.ones_like(self._weights)
            self._initial_weights = self._weights.clone()
            self.use_biases = use_biases

            if self.use_biases:
                self._biases = Parameter(torch.Tensor(n_outputs))
                self._init_biases()
        elif weights is not None:
            self.n_inputs = weights.size(0)
            self.n_outputs = weights.size(1)
            self._activation = activation
            self._initial_weights = weights

            self._weights = Parameter(weights)
            self._mask = torch.ones_like(self._weights)

            self._biases = Parameter(torch.Tensor(self.n_outputs))
            self._init_biases()
        else:
            raise ValueError(
                "DenseFClayer class accepts either n_inputs/n_outputs or weights"
            )

    def _init_weights(self):
        # Note the difference between init functions
        # torch.nn.init.xavier_normal_(self._weights)
        # torch.nn.init.xavier_uniform_(self._weights)
        # torch.nn.init.kaiming_normal_(self._weights)
        torch.nn.init.kaiming_uniform_(self._weights)

    def _init_biases(self):
        torch.nn.init.zeros_(self._biases)

    def prune_by_threshold(self, thr):
        self._mask *= (torch.abs(self._weights) >= thr).float()

    def prune_by_rank(self, rank):
        weights_val = self._weights[self._mask == 1]
        sorted_abs_weights = torch.sort(torch.abs(weights_val))[0]
        thr = sorted_abs_weights[rank]
        self.prune_by_threshold(thr)

    def prune_by_pct(self, pct):
        prune_idx = int(self.n_weights * pct)
        self.prune_by_rank(prune_idx)

    def prune_by_pct_taylor(self, pct):
        prune_idx = int(self.n_weights * pct)

        # by abs val
        wg = torch.abs(self._weights[self._mask == 1] *
                       self._weights.grad[self._mask == 1])
        sorted_wg = torch.sort(wg)[0]
        thr = sorted_wg[prune_idx]
        print(thr)
        self._mask *= (torch.abs(self._weights * self._weights.grad) >
                       thr).float()

        # by val
        # wg = self._weights[self._mask == 1] * self._weights.grad[self._mask == 1]
        # sorted_wg = torch.sort(wg)[0]
        # thr = sorted_wg[prune_idx]
        # self._mask *= (self._weights * self._weights.grad >= thr).float()

    def random_prune_by_pct(self, pct):
        prune_idx = int(self.n_weights * pct)
        rand = torch.rand(size=self._mask.size(), device=self._mask.device)
        rand_val = rand[self._mask == 1]
        sorted_abs_rand = torch.sort(rand_val)[0]
        thr = sorted_abs_rand[prune_idx]
        self._mask *= (rand >= thr).float()

    def reinitialize(self):
        self._weights = Parameter(self._initial_weights)
        self._init_biases()  # biases are reinitialized

    def to_sparse(self) -> SparseFCLayer:
        return SparseFCLayer((self._weights * self._mask).t().to_sparse(),
                             self._biases.reshape((-1, 1)), self._activation)

    @classmethod
    def from_sparse(cls, s_layer: SparseFCLayer):
        return cls(weights=s_layer.weights.t().to_dense(),
                   activation=s_layer.activation)

    def to_device(self, device: torch.device):
        self._initial_weights = self._initial_weights.to(device)
        self._mask = self._mask.to(device)

    def forward(self, inputs: torch.Tensor, use_mask=True):
        masked_weights = self._weights
        if use_mask:
            masked_weights = self._weights * self._mask
        if self.use_biases:
            ret = torch.addmm(self._biases, inputs, masked_weights)
        else:
            ret = torch.mm(inputs, masked_weights)
        return ret if self._activation is None else self._activation(ret)

    @property
    def mask(self):
        return self._mask

    @property
    def weights(self):
        return self._weights

    @property
    def activation(self):
        return self._activation

    @property
    def n_weights(self):
        return torch.nonzero(self._mask).size(0)

    @property
    def biases(self):
        if self.use_biases:
            return self._biases
        else:
            return None

    def __str__(self):
        return "DenseFClayer with size {} and activation {}".format(
            (self.n_inputs, self.n_outputs), self._activation)
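
A usage sketch (assuming DenseFCLayer, its SparseFCLayer companion, torch, and Parameter are in scope), showing a forward pass followed by magnitude pruning:

import torch

layer = DenseFCLayer(n_inputs=20, n_outputs=10, activation=torch.relu)
x = torch.randn(4, 20)
print(layer(x).shape)      # torch.Size([4, 10])

layer.prune_by_pct(0.5)    # zero the mask for the 50% smallest-magnitude weights
print(layer.n_weights)     # roughly 100 of the original 200 weights remain unmasked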
Example #11
class QuantConv2d(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 bit=32,
                 extern_init=False,
                 init_model=nn.Sequential()):
        super(QuantConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)
        self.bit = bit
        self.pwr_coef = 2**(bit - 1)
        self.Round_w = RoundFn_LLSQ.apply
        self.Round_b = RoundFn_Bias.apply
        self.bias_flag = bias
        #self.alpha_w = Variable(torch.rand( out_channels,1,1,1)).cuda()
        # self.alpha_w = Parameter(torch.rand( out_channels))
        if bit < 0:
            self.alpha_w = None
        else:
            self.alpha_w = Parameter(torch.rand(out_channels))
        #self.alpha_qfn = quan_fn_alpha()
        nn.init.kaiming_normal_(self.weight,
                                mode='fan_out',
                                nonlinearity='relu')
        if extern_init:
            param = list(init_model.parameters())
            self.weight = Parameter(param[0])
            if bias:
                self.bias = Parameter(param[1])
        if bit < 0:
            self.init_state = 0
        else:
            self.register_buffer('init_state', torch.zeros(1))
        # self.init_state = 0
    def forward(self, x):
        if self.bit == 32:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        else:
            w_reshape = self.weight.reshape([self.weight.shape[0],
                                             -1]).transpose(0, 1)
            if self.training and self.init_state == 0:
                self.alpha_w.data.copy_(
                    w_reshape.detach().abs().max(dim=0)[0] / self.pwr_coef)
                self.init_state.fill_(1)
                #self.init_state = 1

            #assert not torch.isnan(x).any(), "Conv2d Input should not be 'nan'"
            alpha_w = self.alpha_w  #self.alpha_qfn(self.alpha_w)
            #if torch.isnan(self.alpha_w).any() or torch.isinf(self.alpha_w).any():
            #    assert not torch.isnan(wq).any(), self.alpha_w
            #    assert not torch.isinf(wq).any(), self.alpha_w

            wq = self.Round_w(w_reshape, alpha_w, self.pwr_coef, self.bit)
            w_q = wq.transpose(0, 1).reshape(self.weight.shape)

            if self.bias_flag:
                LLSQ_b = self.Round_b(self.bias, alpha_w, self.pwr_coef,
                                      self.bit)
            else:
                LLSQ_b = self.bias

            # assert not torch.isnan(self.weight).any(), "Weight should not be 'nan'"
            # if torch.isnan(wq).any() or torch.isinf(wq).any():
            #     print(self.alpha_w)
            #     assert not torch.isnan(wq).any(), "Conv2d Weights should not be 'nan'"
            #     assert not torch.isinf(wq).any(), "Conv2d Weights should not be 'nan'"

            return F.conv2d(x, w_q, LLSQ_b, self.stride, self.padding,
                            self.dilation, self.groups)

    def extra_repr(self):
        s_prefix = super(QuantConv2d, self).extra_repr()
        if self.alpha_w is None:
            return '{}, fake'.format(s_prefix)
        return '{}'.format(s_prefix)
Example #12
class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with FakeQuantize modules for weight, used in quantization aware training.

    We combine the interface of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether the BatchNorm running statistics are frozen
        weight_fake_quant: fake quant module for weight

    """
    def __init__(
            self,
            # Linear args
            in_features,
            out_features,
            bias=True,
            # BatchNorm1d args
            # num_features: out_features
            eps=1e-05,
            momentum=0.1,
            # affine: True
            # track_running_stats: True
            # Args for this module
            freeze_bn=False,
            qconfig=None):
        nn.modules.linear.Linear.__init__(self, in_features, out_features,
                                          bias)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        super(LinearBn1d, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # We have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)

        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape))
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features,
                                    device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        BatchNorm's training behavior depends on the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN behaves properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
            '.from_float only works for ' + nni.LinearBn1d.__name__
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid config'
        qconfig = mod.qconfig
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias
                           is not None, bn.eps, bn.momentum, False, qconfig)
        qat_linearbn.weight = linear.weight
        qat_linearbn.bias = linear.bias
        qat_linearbn.bn.weight = bn.weight
        qat_linearbn.bn.bias = bn.bias
        qat_linearbn.bn.running_mean = bn.running_mean
        qat_linearbn.bn.running_var = bn.running_var
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_linearbn

    def to_float(self):
        linear = torch.nn.Linear(self.in_features, self.out_features)
        linear.weight, linear.bias = fuse_linear_bn_weights(
            self.weight, self.bias, self.bn.running_mean, self.bn.running_var,
            self.bn.eps, self.bn.weight, self.bn.bias)
        return linear
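
A QAT usage sketch (assuming the module above and the imports it relies on, e.g. nni, F, Parameter, init, and fuse_linear_bn_weights, are available); the qconfig choice below is illustrative:

import torch
from torch.ao.quantization import get_default_qat_qconfig

qconfig = get_default_qat_qconfig('fbgemm')
qat_linear = LinearBn1d(in_features=32, out_features=64, qconfig=qconfig)

x = torch.randn(8, 32)
y = qat_linear(x)          # fake-quantized, BN-scaled weights; scaling reversed before BN
print(y.shape)             # torch.Size([8, 64])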
Example #13
class DenseLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self,
                 in_features,
                 out_features,
                 use_bias=True,
                 use_mask=True,
                 **kwargs):
        super(DenseLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if use_bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters(**kwargs)

        # self._initial_weight = self.weight.data.clone()
        # self._initial_bias = self.bias.data.clone() if use_bias else None
        self.use_mask = use_mask
        self.mask = torch.ones_like(self.weight, dtype=torch.bool)

    def reset_parameters(self, **kwargs):
        if len(kwargs.keys()) == 0:
            # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        else:
            init.kaiming_uniform_(self.weight, **kwargs)

        if self.bias is not None:
            # default init, see https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, inp: torch.Tensor):
        masked_weight = self.weight * self.mask if self.use_mask else self.weight
        return nn.functional.linear(inp, masked_weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def prune_by_threshold(self, thr):
        self.mask *= (self.weight.abs() >= thr)

    def prune_by_rank(self, rank):
        if rank == 0:
            return
        weight_val = self.weight[self.mask == 1.]
        sorted_abs_weight = weight_val.abs().sort()[0]
        thr = sorted_abs_weight[rank]
        self.prune_by_threshold(thr)

    def prune_by_pct(self, pct):
        prune_idx = int(self.num_weight * pct)
        self.prune_by_rank(prune_idx)

    def retain_by_threshold(self, thr):
        self.mask *= (self.weight.abs() >= thr)

    def retain_by_rank(self, rank):
        weights_val = self.weight[self.mask == 1.]
        sorted_abs_weights = weights_val.abs().sort(descending=True)[0]
        thr = sorted_abs_weights[rank]
        self.retain_by_threshold(thr)

    def random_prune_by_pct(self, pct):
        prune_idx = int(self.num_weight * pct)
        rand = torch.rand(size=self.mask.size(), device=self.mask.device)
        rand_val = rand[self.mask == 1]
        sorted_abs_rand = rand_val.sort()[0]
        thr = sorted_abs_rand[prune_idx]
        self.mask *= (rand >= thr)

    # def reinitialize(self):
    #     self.weight = Parameter(self._initial_weight)
    #     if self._initial_bias is not None:
    #         self.bias = Parameter(self._initial_bias)

    def to_sparse(self, transpose=False) -> SparseLinear:
        """
        by chance, some entries with mask = 1 can have a 0 value. Thus, the to_sparse methods give a different size
        there's no efficient way to solve it yet
        """
        sparse_bias = None if self.bias is None else self.bias.reshape((-1, 1))
        sparse_linear = SparseLinear((self.weight * self.mask).to_sparse(),
                                     sparse_bias, self.mask)
        if transpose:
            sparse_linear.transpose = True
        return sparse_linear

    def move_data(self, device: torch.device):
        self.mask = self.mask.to(device)

    def to(self, *args, **kwargs):
        device = torch._C._nn._parse_to(*args, **kwargs)[0]

        if device is not None:
            self.move_data(device)

        return super(DenseLinear, self).to(*args, **kwargs)

    @property
    def num_weight(self) -> int:
        return self.mask.sum().item()
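
A usage sketch for the masked linear layer above (assuming it and its SparseLinear companion, torch, math, nn, init, and Parameter are in scope):

import torch

fc = DenseLinear(in_features=64, out_features=32)
x = torch.randn(16, 64)
print(fc(x).shape)         # torch.Size([16, 32])

fc.prune_by_pct(0.3)       # mask the ~30% smallest-magnitude weights
print(fc.num_weight)       # roughly 1434 of the 2048 weights remain unmasked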
Example #14
class CConv2d(nn.Module):
    def __init__(self, num, in_channels, out_channels, kernel_size, stride=1,
                                 padding=0, dilation=1, groups=1, bias=True):
        super(CConv2d, self).__init__()
        self.num = num
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.routing_fc = nn.Linear(in_channels, num)
        self.weight = Parameter(torch.Tensor(out_channels * in_channels * kernel_size * kernel_size // groups, num))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self._initialize_weights()
    def _initialize_weights(self):
        for name, m in self.named_modules():

            if isinstance(m, CConv2d):
                c, h, w = m.in_channels, m.kernel_size, m.kernel_size
                nn.init.normal_(m.weight, 0, math.sqrt(1.0 / (c * h * w)))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 1)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.running_mean, 1)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    def forward(self, inputs):
        '''
        if inputs.shape[-1] == 7:
            print('inputs_se:', inputs_se)

        '''
        x = inputs
        inputs_se = x.reshape(x.shape[0], x.shape[1], -1).mean(dim=-1, keepdim=False)
        inputs_se = torch.sigmoid(self.routing_fc(inputs_se))

        batchsize, channel, height, width = inputs.shape
        weight = F.linear(inputs_se, self.weight)
        weight = weight.reshape(batchsize * self.out_channels, self.in_channels // self.groups, self.kernel_size,
                                self.kernel_size)
        inputs = inputs.reshape(1, batchsize * channel, height, width)
        outputs = F.conv2d(inputs, weight, None, self.stride, self.padding, self.dilation,
                           groups=self.groups * batchsize)
        height, width = outputs.shape[2:]
        outputs = outputs.reshape(batchsize, self.out_channels, height, width)
        if self.bias is not None:
            outputs = outputs + self.bias.reshape(1, -1, 1, 1)
        
        return outputs
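
A usage sketch for the conditionally parameterized convolution above (assuming the class plus torch, math, torch.nn as nn, torch.nn.functional as F, and Parameter are in scope); each sample gets its own kernel via a single grouped convolution:

import torch

conv = CConv2d(num=4, in_channels=16, out_channels=32, kernel_size=3, padding=1)
x = torch.randn(8, 16, 14, 14)
y = conv(x)                # per-sample kernels, implemented as one grouped conv
print(y.shape)             # torch.Size([8, 32, 14, 14])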
Example #15
class _ConvBnNd(nn.modules.conv._ConvNd):

    _version = 2

    def __init__(
            self,
            # ConvNd args
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            # BatchNormNd args
            # num_features: out_channels
            eps=1e-05,
            momentum=0.1,
            # affine: True
            # track_running_stats: True
            # Args for this module
            freeze_bn=False,
            qconfig=None):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels,
                                         kernel_size, stride, padding,
                                         dilation, transposed, output_padding,
                                         groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for the conv bias, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super(_ConvBnNd, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape([-1, 1, 1, 1]))
        # this does not include the conv bias
        conv = self._conv_forward(input, scaled_weight)
        conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1])
        conv = self.bn(conv_orig)
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super(_ConvBnNd, self).extra_repr()

    def forward(self, input):
        return self._forward(input)

    def train(self, mode=True):
        """
        BatchNorm's training behavior depends on the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN behaves properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_ConvBnNd,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod` a float module, either produced by torch.quantization utilities
            or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation, conv.groups,
                         conv.bias is not None, conv.padding_mode, bn.eps,
                         bn.momentum, False, qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn
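The `_load_from_state_dict` override above migrates v1-style checkpoints, where the BN tensors were stored flat as `gamma`, `beta` and the running statistics, into the v2 layout where they live under the nested `bn.` module. Below is a minimal standalone sketch of that key migration; the `features.0.` prefix and the string values are placeholders, not real checkpoint contents, and the mapping simply mirrors `v2_to_v1_names` from the code above.

# Sketch only: prefix and values are placeholders; the mapping mirrors
# v2_to_v1_names in _load_from_state_dict above.
prefix = 'features.0.'
v1_state_dict = {
    prefix + 'weight': 'conv weight tensor',
    prefix + 'bias': 'conv bias tensor',
    prefix + 'gamma': 'bn weight tensor',
    prefix + 'beta': 'bn bias tensor',
    prefix + 'running_mean': 'bn running mean',
    prefix + 'running_var': 'bn running var',
    prefix + 'num_batches_tracked': 'bn batch counter',
}
v2_to_v1_names = {
    'bn.weight': 'gamma',
    'bn.bias': 'beta',
    'bn.running_mean': 'running_mean',
    'bn.running_var': 'running_var',
    'bn.num_batches_tracked': 'num_batches_tracked',
}
for v2_name, v1_name in v2_to_v1_names.items():
    if prefix + v1_name in v1_state_dict:
        v1_state_dict[prefix + v2_name] = v1_state_dict.pop(prefix + v1_name)

# The BN entries now live under the nested 'bn.' names expected by the v2 module.
assert prefix + 'bn.weight' in v1_state_dict and prefix + 'gamma' not in v1_state_dict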
Beispiel #16
0
class _ConvBnNd(nn.modules.conv._ConvNd):
    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups,
                 bias,
                 padding_mode,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True)
        self.activation_post_process = self.qconfig.activation()
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: the initialization below is actually for the conv bias, not BN
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super(_ConvBnNd, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        # Fold the BN statistics into the conv weight so fake quantization sees
        # the weight a fused conv would use at inference time (see the folding
        # sketch after this class).
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape([-1, 1, 1, 1]))
        # this does not include the conv bias
        conv = self._conv_forward(input, scaled_weight)
        # undo the weight scaling and add the conv bias, then run the real BN
        # module on the un-scaled output
        conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape([1, -1, 1, 1])
        conv = self.bn(conv_orig)
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super(_ConvBnNd, self).extra_repr()

    def forward(self, input):
        return self.activation_post_process(self._forward(input))

    def train(self, mode=True):
        """
        BatchNorm's training behavior is controlled by the self.training flag,
        so prevent changing it while BN is frozen. This makes sure that calling
        `model.train()` on a model with a frozen BN behaves properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a qat module from a float module or qparams_dict

            Args: `mod` a float module, either produced by torch.quantization utilities
            or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn
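The `_forward` above relies on the standard eval-time Conv+BN folding identity: scaling the conv weight by `bn.weight / sqrt(running_var + eps)` and shifting the bias accordingly reproduces conv followed by BN. A minimal sketch of that identity follows; `fold_conv_bn` is a hypothetical helper written for this check, not part of the module above.

import torch
import torch.nn as nn

def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Hypothetical helper: fold eval-mode BN statistics into a copy of the conv.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std  # same scale_factor as in _forward above
    with torch.no_grad():
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        conv_bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_(bn.bias + scale * (conv_bias - bn.running_mean))
    return fused

# Quick check: the folded conv matches conv followed by BN in eval mode.
conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
conv.eval(); bn.eval()
x = torch.randn(2, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fold_conv_bn(conv, bn)(x), atol=1e-5)

During QAT the module keeps the real BN in the loop after un-scaling the conv output; the folded weight is used only as the input to `weight_fake_quant`, so the quantizer observes the weights a fused inference kernel would actually use.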
class DCTConvolution(torch.nn.Conv2d):
    """
    Instantiate a learnable convolution operator - the trainable weights are the coefficients of the
    DCT decomposition of the convolution operator (i.e. this change is bijective for convex training procedures!)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 bias=False,
                 mean=False,
                 dilation=1):

        self.mean = mean
        padding = int(np.floor(kernel_size / 2))
        super().__init__(in_channels,
                         out_channels,
                         kernel_size,
                         padding=padding,
                         bias=bias,
                         dilation=dilation)

        self.initialize_DCT()
        torch.nn.init.orthogonal_(self.weight, gain=1)

    def initialize_DCT(self):
        num_basis_functions = self.kernel_size[0] * self.kernel_size[1]
        dct_basis = self.weight.new_zeros(num_basis_functions, 1,
                                          *self.kernel_size)

        b = 0  # enumerate basis functions
        for b1 in range(self.kernel_size[0]):
            for b2 in range(self.kernel_size[1]):
                for i in range(self.kernel_size[0]):
                    for j in range(self.kernel_size[1]):
                        dct_basis[b, 0, i,
                                  j] = (np.cos(np.pi / self.kernel_size[0] *
                                               (i + 0.5) * b1) *
                                        np.cos(np.pi / self.kernel_size[1] *
                                               (j + 0.5) * b2))
                b += 1

        if not self.mean:
            dct_basis = dct_basis[1:, 0:1, :, :]
            num_weights = (num_basis_functions - 1) * self.in_channels
            self.weight = Parameter(
                self.weight.new_zeros(self.out_channels, num_weights, 1, 1))
        else:
            num_weights = num_basis_functions * self.in_channels
            self.weight = Parameter(
                self.weight.new_zeros(self.out_channels, num_weights, 1, 1))

        self.register_buffer(
            'dct_basis',
            torch.cat([dct_basis] * self.in_channels, 0) /
            torch.norm(dct_basis))
        init.kaiming_uniform_(self.weight, a=np.sqrt(5))

    def forward(self, input, direction='op'):
        if direction == 'op':
            dct_response = F.conv2d(input,
                                    self.dct_basis,
                                    None,
                                    self.stride,
                                    self.padding,
                                    self.dilation,
                                    groups=self.in_channels)
            return F.conv2d(dct_response,
                            self.weight,
                            self.bias,
                            self.stride,
                            padding=0,
                            dilation=self.dilation,
                            groups=self.groups)
        elif direction == 't':
            input_weighted = F.conv_transpose2d(input,
                                                self.weight,
                                                None,
                                                self.stride,
                                                output_padding=0,
                                                padding=0,
                                                dilation=self.dilation,
                                                groups=self.groups)
            return F.conv_transpose2d(input_weighted,
                                      self.dct_basis,
                                      None,
                                      self.stride,
                                      output_padding=0,
                                      padding=self.padding,
                                      dilation=self.dilation,
                                      groups=self.in_channels)
        else:
            raise ValueError('Invalid Direction')

    def return_filters(self):
        with torch.no_grad():
            num_basis_functions = self.dct_basis.shape[0] // self.in_channels

            color_weights = self.weight.reshape(self.out_channels,
                                                num_basis_functions,
                                                self.in_channels, 1, 1)
            return (
                self.dct_basis[0:num_basis_functions, :, :, :].unsqueeze(0) *
                color_weights).sum(1)

    def normest(self):
        # return self.weight.norm(dim=[2, 3]).sum() / self.groups
        return normest(self, self.in_channels, verbose=False)
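`initialize_DCT` builds a separable 2-D DCT-II basis and registers it as a fixed buffer; the learnable weights then mix the per-channel basis responses with 1x1 convolutions. The short standalone check below (independent of the class, not part of its API) rebuilds the same cosine products for a k x k kernel and verifies that distinct basis functions are mutually orthogonal, which is what makes the reparameterization invertible.

import numpy as np

# Standalone check (not part of DCTConvolution): rebuild the cosine products used in
# initialize_DCT and verify orthogonality. The functions are orthogonal but not
# unit-norm; the class only rescales the stacked basis by a single global norm.
k = 3
basis = np.zeros((k * k, k * k))
b = 0  # running index over the k*k basis functions
for b1 in range(k):
    for b2 in range(k):
        for i in range(k):
            for j in range(k):
                basis[b, i * k + j] = (np.cos(np.pi / k * (i + 0.5) * b1) *
                                       np.cos(np.pi / k * (j + 0.5) * b2))
        b += 1

gram = basis @ basis.T
off_diagonal = gram - np.diag(np.diag(gram))
assert np.allclose(off_diagonal, 0.0, atol=1e-9)  # distinct basis functions are orthogonal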