Code Example #1
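These snippets are shown without their import headers. The preamble below is a sketch of what would plausibly make them run, assuming the classes come from a PyTorch sparsity/Bayesian-layers codebase; pair and ConvNd are assumed aliases for PyTorch's private helpers, and the 10-argument super() calls in the convolutional classes match an older _ConvNd signature that predates the padding_mode argument.

import math

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Hardshrink, Module, Parameter, Softshrink, init
from torch.nn.modules.conv import _ConvNd as ConvNd  # assumed alias for the conv base class
from torch.nn.modules.utils import _pair as pair  # assumed alias for the int-to-tuple helper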
class group_relaxed_L1L2Conv2d(Module):
    """Implementation of TF1 regularization for the feature maps of a convolutional layer"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 lamba=1.,
                 alpha=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the L1-L2 regularization
		"""
        super(group_relaxed_L1L2Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.u = torch.rand(out_channels, in_channels // groups,
                            *self.kernel_size)
        self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) +
                               self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            mask = self.u != 0
            n = torch.sum(mask)
            self.u[mask] = self.weight.sign()[mask] * self.alpha * self.lamba1 / (n**(1 / 2))

        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(
                torch.argmax(self.u.abs().cpu(), None), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w +
                               (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.in_channels * self.kernel_size[0] * self.kernel_size[1]
            ) * torch.sum(
                torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1) /
                          (self.in_channels * self.kernel_size[0] *
                           self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(
            torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)

        num_instances_per_filter = (
            (self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0]) /
            self.stride[0]) + 1
        num_instances_per_filter *= (
            (self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1]) /
            self.stride[1]) + 1

        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos

        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride,
                          self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
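A minimal usage sketch for the layer above, assuming the preamble under Code Example #1 and hypothetical shapes and hyperparameters; the layer drops in for nn.Conv2d, regularization() supplies a log-prior-style term for the objective, and constrain_parameters() refreshes the auxiliary tensor u between optimizer steps.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
layer = group_relaxed_L1L2Conv2d(3, 16, kernel_size=3, padding=1, lamba=0.1).to(device)
x = torch.randn(8, 3, 32, 32, device=device)
y = layer(x)                      # ordinary convolution forward pass
reg = layer.regularization()      # log-prior-style penalty term (larger is better)
layer.constrain_parameters()      # proximal update of the auxiliary tensor u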
Code Example #2
class group_relaxed_L0Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the L0 regularization
		"""
        super(group_relaxed_L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.beta = beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        #self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
        m = Hardshrink((2 * self.lamba / self.beta)**(1 / 2))
        self.u.data = m(self.weight.data)

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.out_features) * torch.sum(
                    torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
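One plausible way to fold this layer's regularizer into a training loss, assuming the preamble above; since regularization() returns a log-prior-style quantity (to be maximized), it is subtracted here, and the 1e-4 weighting is a hypothetical choice.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
fc = group_relaxed_L0Dense(784, 10, lamba=0.05).to(device)
x = torch.randn(32, 784, device=device)
target = torch.randint(0, 10, (32,), device=device)
loss = F.cross_entropy(fc(x), target) - 1e-4 * fc.regularization()
loss.backward()
fc.constrain_parameters()   # hard-thresholding update of u with threshold sqrt(2*lamba/beta)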
Code Example #3
class group_relaxed_L1L2Dense(Module):
    """Implementation of TFL regularization for the input units of a fully connected layer"""
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 lamba=1.,
                 alpha=1.,
                 beta=4.,
                 weight_decay=1.,
                 **kwargs):
        """
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the L1-L2 regularization
		"""
        super(group_relaxed_L1L2Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.u = torch.rand(in_features, out_features)
        self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')

        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) +
                               self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            mask = self.u != 0
            n = torch.sum(mask)
            self.u[mask] = self.weight.sign()[mask] * self.alpha * self.lamba1 / (n**(1 / 2))

        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            max_idx = np.unravel_index(
                torch.argmax(self.u.abs().cpu(), None), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w +
                               (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = -self.beta * torch.sum(
            0.5 * self.weight.add(-self.u).pow(2)) - self.lamba * np.sqrt(
                self.out_features) * torch.sum(
                    torch.pow(torch.sum(self.weight.pow(2), 1), 0.5))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(
            torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__+' (' \
         + str(self.in_features) + ' -> ' \
         + str(self.out_features) + ', lambda: ' \
         + str(self.lamba) + ')'
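For reference, constrain_parameters in the two group_relaxed_L1L2 classes appears to implement a closed-form proximal map of the relaxed penalty lambda_1(||w||_1 - alpha ||w||_2) with lambda_1 = lamba/beta. Written out as a reconstruction from the code (not a statement taken from the original source), with i* the index of the entry of largest magnitude:

$$
u \;=\;
\begin{cases}
z\,\dfrac{\lVert z\rVert_2+\alpha\lambda_1}{\lVert z\rVert_2},\quad z=\operatorname{sign}(w)\,(|w|-\lambda_1)_+, & \lVert w\rVert_\infty>\lambda_1,\\[8pt]
\dfrac{\alpha\lambda_1}{\sqrt{n}}\,\operatorname{sign}(w)\odot\mathbf{1}\{|w|\ge\lambda_1\},\quad n=\#\{|w|\ge\lambda_1\}, & \lVert w\rVert_\infty=\lambda_1,\\[8pt]
\bigl(\lVert w\rVert_\infty+(\alpha-1)\lambda_1\bigr)\operatorname{sign}(w_{i^*})\,e_{i^*}, & (1-\alpha)\lambda_1<\lVert w\rVert_\infty<\lambda_1,\\[8pt]
0, & \text{otherwise},
\end{cases}
\qquad \lambda_1=\frac{\lambda}{\beta}.
$$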
Code Example #4
class group_relaxed_SCAD_Dense(Module):
	"""Implementation of TFL regularization for the input units of a fully connected layer"""
	def __init__(self, in_features, out_features, bias=True, lamba=1., alpha = 3.7, beta = 4.0, weight_decay=1., **kwargs):
		"""
		:param in_features: input dimensionality
		:param out_features: output dimensionality
		:param bias: whether we use bias
		:param lamba: strength of the SCAD regularization
		"""
		super(group_relaxed_SCAD_Dense,self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.weight = Parameter(torch.Tensor(in_features, out_features))
		self.u = torch.rand(in_features, out_features)
		self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
		if bias:
			self.bias = Parameter(torch.Tensor(out_features))
		else:
			self.register_parameter('bias', None)
		self.lamba = lamba
		self.alpha = alpha
		self.beta = beta
		self.lamba1 = self.lamba/self.beta
		self.weight_decay = weight_decay
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.reset_parameters()
		print(self)

	def reset_parameters(self):
		init.kaiming_normal_(self.weight, mode='fan_out')

		if self.bias is not None:
			self.bias.data.normal_(0,1e-2)


	def constrain_parameters(self, **kwargs):
		self.u = self.weight.clone()
		s = Softshrink(self.lamba1)
		#shrinkage on values with absolute value less than 2*lamba1
		shrink_value = s(self.weight.data)
		self.u[self.weight.abs()<=2*self.lamba1] = shrink_value[self.weight.abs()<=2*self.lamba1]

		#modify values whose absolute values are between 2*lamba1 and alpha*lamba1
		modify_weight = self.weight.data
		modify_weight = ((self.alpha - 1)*modify_weight - modify_weight.sign()*(self.alpha*self.lamba1))/(self.alpha - 2)
		self.u[(self.weight.abs()>2*self.lamba1) & (self.weight.abs()<=self.alpha*self.lamba1)] = modify_weight[(self.weight.abs()>2*self.lamba1) & (self.weight.abs()<=self.alpha*self.lamba1)]


	def grow_beta(self, growth_factor):
		self.beta = self.beta*growth_factor
		self.lamba1 = self.lamba/self.beta

	def _reg_w(self, **kwargs):
		logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.out_features)*torch.sum(torch.pow(torch.sum(self.weight.pow(2),1),0.5))
		logpb = 0
		if self.bias is not None:
			logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
		return logpw + logpb

	def regularization(self):
		return self._reg_w()

	def count_zero_u(self):
		total = np.prod(self.u.size())
		zero = total - self.u.nonzero().size(0)
		return zero

	def count_zero_w(self):
		return torch.sum((self.weight.abs()<1e-5).int()).item()

	def count_weight(self):
		return np.prod(self.u.size())

	def count_active_neuron(self):
		return torch.sum(torch.sum(self.weight.abs()/self.out_features,1)>1e-5).item()

	def count_total_neuron(self):
		return self.in_features

	def count_expected_flops_and_l0(self):
		ppos = torch.sum(self.weight.abs()>0.000001).item()
		expected_flops = (2*ppos-1)*self.out_features
		expected_l0 = ppos*self.out_features
		if self.bias is not None:
			expected_flops += self.out_features
			expected_l0 += self.out_features
		return expected_flops, expected_l0

	def forward(self, input):
		output = input.mm(self.weight)
		if self.bias is not None:
			output.add_(self.bias.view(1, self.out_features).expand_as(output))
		return output

	def __repr__(self):
		return self.__class__.__name__+' (' \
			+ str(self.in_features) + ' -> ' \
			+ str(self.out_features) + ', lambda: ' \
			+ str(self.lamba) + ')'
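Reading constrain_parameters above, u is the SCAD thresholding of the weights with threshold lambda_1 = lamba/beta and shape parameter alpha (a reconstruction from the code, not a statement taken from the original source):

$$
u_{ij} \;=\;
\begin{cases}
\operatorname{sign}(w_{ij})\,\bigl(|w_{ij}|-\lambda_1\bigr)_+, & |w_{ij}|\le 2\lambda_1,\\[6pt]
\dfrac{(\alpha-1)\,w_{ij}-\operatorname{sign}(w_{ij})\,\alpha\lambda_1}{\alpha-2}, & 2\lambda_1<|w_{ij}|\le\alpha\lambda_1,\\[6pt]
w_{ij}, & |w_{ij}|>\alpha\lambda_1,
\end{cases}
\qquad \lambda_1=\frac{\lambda}{\beta}.
$$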
Code Example #5
class group_relaxed_TF1Conv2d(Module):
	"""Implementation of TF1 regularization for the feature maps of a convolutional layer"""
	def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
		lamba=1., alpha=1., beta=4., weight_decay = 1., **kwargs):
		"""
		:param in_channels: Number of input channels
		:param out_channels: Number of output channels
		:param kernel_size: size of the kernel
		:param stride: stride for the convolution
		:param padding: padding for the convolution
		:param dilation: dilation factor for the convolution
		:param groups: how many groups we will assume in the convolution
		:param bias: whether we will use a bias
		:param lamba: strength of the TF1 regularization
		"""
		super(group_relaxed_TF1Conv2d, self).__init__()
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.in_channels = in_channels
		self.out_channels = out_channels
		self.kernel_size = pair(kernel_size)
		self.stride = pair(stride)
		self.padding = pair(padding)
		self.dilation = pair(dilation)
		self.output_padding = pair(0)
		self.groups = groups
		self.lamba = lamba
		self.alpha = alpha
		self.beta = beta
		self.lamba1 = self.lamba/self.beta
		self.weight_decay = weight_decay
		self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
		self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
		self.u = self.u.to('cuda' if torch.cuda.is_available() else 'cpu')
		if bias:
			self.bias = Parameter(torch.Tensor(out_channels))
		else:
			self.register_parameter('bias', None)
		self.reset_parameters()
		self.input_shape = None
		print(self)

	def reset_parameters(self):
		init.kaiming_normal_(self.weight, mode='fan_in')
		

		if self.bias is not None:
			self.bias.data.normal_(0,1e-2)

	def phi(self,x):
		phi_x = torch.acos(1-27*(self.lamba1*self.alpha*(self.alpha+1))/(2*(self.alpha+x.abs())**3))
		return phi_x

	def g(self,x):
		g_x = x.sign()*(2/3*(self.alpha + x.abs())*torch.cos(self.phi(x)/3)-2*self.alpha/3+x.abs()/3)
		return g_x


	def constrain_parameters(self, thres_std=1.):
		#self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
		#print(torch.sum(self.weight.pow(2)))
		if self.lamba1 <= (self.alpha**2)/(2*(self.alpha+1)):
			t = self.lamba1*(self.alpha+1)/(self.alpha)
		else:
			t = np.sqrt(2*self.lamba1*(self.alpha+1))-self.alpha/2
		self.u.data = self.weight.data.clone()
		self.u.data[self.u.data.abs() <=t] = 0
		g_result = self.g(self.u)
		self.u.data[self.u.data.abs() > t] = g_result[self.u.data.abs() > t]

	def grow_beta(self, growth_factor):
		self.beta = self.beta*growth_factor
		self.lamba1 = self.lamba/self.beta

	def _reg_w(self, **kwargs):
		logpw = -self.beta*torch.sum(0.5*self.weight.add(-self.u).pow(2))-self.lamba*np.sqrt(self.in_channels*self.kernel_size[0]*self.kernel_size[1])*torch.sum(torch.pow(torch.sum(self.weight.pow(2),3).sum(2).sum(1),0.5))
		logpb = 0
		if self.bias is not None:
			logpb = - torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
		return logpw+logpb

	def regularization(self):
		return self._reg_w()

	def count_zero_u(self):
		total = np.prod(self.u.size())
		zero = total - self.u.nonzero().size(0)
		return zero

	def count_zero_w(self):
		return torch.sum((self.weight.abs()<1e-5).int()).item()

	def count_active_neuron(self):
		return torch.sum((torch.sum(self.weight.abs(),3).sum(2).sum(1)/(self.in_channels*self.kernel_size[0]*self.kernel_size[1]))>1e-5).item()

	def count_total_neuron(self):
		return self.out_channels


	def count_weight(self):
		return np.prod(self.u.size())

	def count_expected_flops_and_l0(self):
		#ppos = self.out_channels
		ppos = torch.sum(torch.sum(self.weight.abs(),3).sum(2).sum(1)>0.001).item()
		n = self.kernel_size[0]*self.kernel_size[1]*self.in_channels
		flops_per_instance = n+(n-1)

		num_instances_per_filter = ((self.input_shape[1] -self.kernel_size[0]+2*self.padding[0])/self.stride[0]) + 1
		num_instances_per_filter *=((self.input_shape[2] - self.kernel_size[1]+2*self.padding[1])/self.stride[1]) + 1

		flops_per_filter = num_instances_per_filter * flops_per_instance
		expected_flops = flops_per_filter*ppos
		expected_l0 = n*ppos

		if self.bias is not None:
			expected_flops += num_instances_per_filter*ppos
			expected_l0 += ppos
		return expected_flops, expected_l0

	def forward(self, input_):
		if self.input_shape is None:
			self.input_shape = input_.size()
		output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
		return output

	def __repr__(self):
		s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
			', stride={stride}')
		if self.padding != (0,) * len(self.padding):
			s += ', padding={padding}'
		if self.dilation != (1,) * len(self.dilation):
			s += ', dilation={dilation}'
		if self.output_padding != (0,) * len(self.output_padding):
			s += ', output_padding={output_padding}'
		if self.groups != 1:
			s += ', groups={groups}'
		if self.bias is None:
			s += ', bias=False'
		s += ')'
		return s.format(name=self.__class__.__name__, **self.__dict__)
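The helpers phi and g together with the threshold t in constrain_parameters correspond to the TL1 (transformed L1) thresholding operator with parameter alpha and lambda_1 = lamba/beta; reconstructed from the code:

$$
\varphi(x)=\arccos\!\Bigl(1-\tfrac{27\,\lambda_1\,\alpha(\alpha+1)}{2\,(\alpha+|x|)^3}\Bigr),\qquad
g(x)=\operatorname{sign}(x)\Bigl(\tfrac{2}{3}(\alpha+|x|)\cos\tfrac{\varphi(x)}{3}-\tfrac{2\alpha}{3}+\tfrac{|x|}{3}\Bigr),
$$
$$
u=\begin{cases}0, & |w|\le t,\\ g(w), & |w|>t,\end{cases}
\qquad
t=\begin{cases}\lambda_1\,\dfrac{\alpha+1}{\alpha}, & \lambda_1\le\dfrac{\alpha^2}{2(\alpha+1)},\\[8pt]\sqrt{2\lambda_1(\alpha+1)}-\dfrac{\alpha}{2}, & \text{otherwise}.\end{cases}
$$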
Code Example #6
class OrthogonalBayesianDense(Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 order: int = 8,
                 use_bias: bool = True,
                 add_diagonal: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-2,
                 **kwargs):
        super(OrthogonalBayesianDense, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.device = torch.device(
            'cpu') if not torch.cuda.is_available() else torch.device('cuda')
        self.in_features = in_features
        self.out_features = out_features
        self.add_diagonal = add_diagonal
        self.prior_std = prior_std
        self.bias_std = bias_std

        max_feature = max(self.in_features, self.out_features)

        self.order = order or max_feature
        assert 1 <= self.order <= max_feature

        self.v_mean = Parameter(self.floatTensor(self.order, max_feature))
        self.v_logvar = Parameter(self.floatTensor(self.order, max_feature))

        if self.add_diagonal:
            self.d_mean = Parameter(
                self.floatTensor(min(self.in_features, self.out_features)))
            self.d_logvar = Parameter(
                self.floatTensor(min(self.in_features, self.out_features)))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(self.floatTensor(out_features))
            self.bias_logvar = Parameter(self.floatTensor(out_features))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        self.v_mean.data.normal_(std=self.prior_std)
        self.v_logvar.data.normal_(mean=-5., std=self.prior_std)

        if self.add_diagonal:
            self.d_mean.data.normal_(std=self.prior_std)
            self.d_logvar.data.normal_(mean=-5., std=self.prior_std)

        if self.use_bias:
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(mean=-5., std=self.bias_std)

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_mean: float, prior_std: float,
                  mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add((prior_mean - mean)**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        v = self._eq_logpw(prior_mean=0.,
                           prior_std=self.prior_std,
                           mean=self.v_mean,
                           logvar=self.v_logvar)

        if self.add_diagonal:
            d = self._eq_logpw(prior_mean=0.,
                               prior_std=self.prior_std,
                               mean=self.d_mean,
                               logvar=self.d_logvar)
            logpw = v.add(d)
        else:
            logpw = v

        if self.use_bias:
            bias = self._eq_logpw(prior_mean=0.,
                                  prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        v = self._eq_logqw(logvar=self.v_logvar)

        if self.add_diagonal:
            d = self._eq_logqw(logvar=self.d_logvar)
            logqw = v.add(d)
        else:
            logqw = v

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor:
        p = t[0]
        for i in range(1, t.shape[0]):
            p = p.mm(t[i])
        return p

    def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor:
        norm = t.norm(p=2, dim=1)
        t = t.div(norm.unsqueeze(1))
        h = torch.einsum('ab,ac->abc', (t, t))
        return torch.eye(n=t.shape[1],
                         device=self.device).expand_as(h) - h.mul(2.)

    def _calc_weights(self, v: torch.Tensor, d: torch.Tensor) -> torch.Tensor:

        u = self._chain_multiply(self._calc_householder_tensor(v))

        if self.out_features <= self.in_features:
            D = torch.eye(n=self.in_features,
                          m=self.out_features,
                          device=self.device).mm(torch.diag(d))
            W = u.mm(D)
        else:
            D = torch.diag(d).mm(
                torch.eye(n=self.in_features,
                          m=self.out_features,
                          device=self.device))
            W = D.mm(u)
        return W

    def forward(self, input: torch.Tensor):

        v = self.v_mean.add(
            self.v_logvar.mul(0.5).exp().mul(
                self._sample_eps(self.v_logvar.shape)))

        if self.add_diagonal:
            d = self.d_mean.add(
                self.d_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.d_logvar.shape)))
        else:
            d = torch.ones(min(self.in_features, self.out_features),
                           device=self.device)

        w = self._calc_weights(v=v, d=d)
        y = input.mm(w)

        if self.use_bias:
            bias = self.bias_mean.add(
                self.bias_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.bias_logvar.shape)))
            return y.add(bias.view(1, self.out_features))
        return y

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', order: ' \
            + str(self.order) + ', add_diagonal: ' \
            + str(self.add_diagonal) + ')'
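A minimal training-step sketch for the layer above; kldiv() returns E_q[log p(w)] - E_q[log q(w)] (a negative KL term), so one plausible convention is to subtract it, with a weight, from the task loss. The squared-output loss and the 1e-3 weight are hypothetical placeholders.

layer = OrthogonalBayesianDense(64, 32, order=8)
x = torch.randn(16, 64, device=layer.device)
y = layer(x)                              # each forward pass draws one weight sample
neg_kl = layer.kldiv()                    # E_q[log p(w)] - E_q[log q(w)]
loss = y.pow(2).mean() - 1e-3 * neg_kl    # hypothetical task loss minus the weighted regularizer
loss.backward()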
Code Example #7
class KernelBayesianConv2(ConvNd):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 dim: int = 2,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 use_bias: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-3,
                 **kwargs):

        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)

        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.dim = dim
        self.kernel_size = pair(kernel_size)
        self.use_bias = use_bias
        self.prior_std = prior_std
        self.bias_std = bias_std

        super(KernelBayesianConv2, self).__init__(in_channels, out_channels,
                                                  self.kernel_size, stride,
                                                  padding, dilation, False,
                                                  pair(0), groups, use_bias)

        self.columns_mean = Parameter(
            self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)),
                             self.dim))
        self.columns_logvar = Parameter(
            self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)),
                             self.dim))

        self.rows_mean = Parameter(
            self.floatTensor(
                self.out_channels * int(np.prod(self.kernel_size)) // groups,
                self.dim))
        self.rows_logvar = Parameter(
            self.floatTensor(
                self.out_channels * int(np.prod(self.kernel_size)) // groups,
                self.dim))

        self.alpha_mean = Parameter(
            self.floatTensor(self.out_channels // groups, self.in_channels))
        self.alpha_logvar = Parameter(
            self.floatTensor(self.out_channels // groups, self.in_channels))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(
                self.floatTensor(self.out_channels // self.groups))
            self.bias_logvar = Parameter(
                self.floatTensor(self.out_channels // self.groups))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')

        if hasattr(self, 'columns_mean'):
            self.columns_mean.data.normal_(std=self.prior_std)
            self.columns_logvar.data.normal_(std=self.prior_std)

        if hasattr(self, 'rows_mean'):
            self.rows_mean.data.normal_(std=self.prior_std)
            self.rows_logvar.data.normal_(std=self.prior_std)

        if hasattr(self, 'alpha_mean'):
            self.alpha_mean.data.normal_(std=self.prior_std)
            self.alpha_logvar.data.normal_(std=self.prior_std)

        if hasattr(self, 'bias_mean'):
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(std=self.bias_std)

    def _calc_rbf_weights(self, rows: torch.Tensor, columns: torch.Tensor,
                          alpha: torch.Tensor) -> Parameter:
        w = self.floatTensor(self.out_channels // self.groups,
                             self.in_channels, np.prod(self.kernel_size))

        for i in range(int(np.prod(self.kernel_size))):
            row_start = i * (self.out_channels // self.groups)
            row_stop = (i + 1) * (self.out_channels // self.groups)
            col_start = i * self.in_channels
            col_stop = (i + 1) * self.in_channels

            x2 = rows[row_start:row_stop, :].pow(2).sum(dim=1).view(
                self.out_channels // self.groups, 1)
            y2 = columns[col_start:col_stop, :].pow(2).sum(dim=1).view(
                1, self.in_channels)
            xy = rows[row_start:row_stop, :].mm(
                columns[col_start:col_stop, :].t()).mul(-2.)

            w[:, :, i] = x2.add(y2).add(xy).mul(-1).exp().mul(alpha)

        return w.view(self.out_channels // self.groups, self.in_channels,
                      *self.kernel_size)

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_std: float, mean: torch.Tensor,
                  logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add(mean**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        rows = self._eq_logpw(prior_std=self.prior_std,
                              mean=self.rows_mean,
                              logvar=self.rows_logvar)
        columns = self._eq_logpw(prior_std=self.prior_std,
                                 mean=self.columns_mean,
                                 logvar=self.columns_logvar)
        alpha = self._eq_logpw(prior_std=self.prior_std,
                               mean=self.alpha_mean,
                               logvar=self.alpha_logvar)
        logpw = rows.add(columns).add(alpha)

        if self.use_bias:
            bias = self._eq_logpw(prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        rows = self._eq_logqw(logvar=self.rows_logvar)
        columns = self._eq_logqw(logvar=self.columns_logvar)
        alpha = self._eq_logqw(logvar=self.alpha_logvar)
        logqw = rows.add(columns).add(alpha)

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def forward(self, input: torch.Tensor):

        rows = self.rows_mean
        columns = self.columns_mean
        alpha = self.alpha_mean

        if self.training:
            rows = self.rows_mean.add(
                self.rows_logvar.mul(0.5).exp().mul(
                    self._sample_eps(rows.shape)))
            columns = self.columns_mean.add(
                self.columns_logvar.mul(0.5).exp().mul(
                    self._sample_eps(columns.shape)))
            alpha = self.alpha_mean.add(
                self.alpha_logvar.mul(0.5).exp().mul(
                    self._sample_eps(alpha.shape)))

        weight = self._calc_rbf_weights(rows=rows,
                                        columns=columns,
                                        alpha=alpha)

        bias = None
        if self.use_bias:
            bias = self.bias_mean
            if self.training:
                bias = self.bias_mean.add(
                    self.bias_logvar.mul(0.5).exp().mul(
                        self._sample_eps(self.bias_logvar.shape)))

        return F.conv2d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
             ', stride={stride}, dim={dim}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
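The weight construction in _calc_rbf_weights above amounts to an RBF kernel between learned row and column embeddings, scaled per channel pair; reconstructed from the code, with r_{o,k} the dim-dimensional embedding of output channel o at kernel position k and c_{c,k} that of input channel c:

$$
W_{o,c,k} \;=\; \alpha_{o,c}\,\exp\!\bigl(-\lVert r_{o,k}-c_{c,k}\rVert_2^{2}\bigr).
$$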
Code Example #8
class OrthogonalBayesianConv2d(ConvNd):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=True,
                 simple=True,
                 add_diagonal=True,
                 weight_decay=1.,
                 prior_std=1.,
                 bias_std=1e-2,
                 **kwargs):
        kernel_size = pair(kernel_size)
        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available(
        ) else torch.cuda.FloatTensor
        self.device = torch.device(
            'cpu') if not torch.cuda.is_available() else torch.device('cuda')
        self.simple = simple
        self.add_diagonal = add_diagonal
        super(OrthogonalBayesianConv2d,
              self).__init__(in_channels, out_channels,
                             kernel_size, stride, padding, dilation, False,
                             pair(0), groups, use_bias)

        self.in_channels = in_channels
        self.out_channels = out_channels // groups
        self.kernel_size = kernel_size[0]

        self.prior_std = prior_std
        self.bias_std = bias_std

        if simple:
            self.r_mean = Parameter(
                self.floatTensor(self.kernel_size * self.kernel_size,
                                 self.out_channels))
            self.r_logvar = Parameter(
                self.floatTensor(self.kernel_size * self.kernel_size,
                                 self.out_channels))
        else:
            self.r_mean = Parameter(self.floatTensor(2, self.out_channels))
            self.r_logvar = Parameter(self.floatTensor(2, self.out_channels))
            self.t_mean = Parameter(
                self.floatTensor(2 * (self.kernel_size - 1),
                                 self.out_channels))
            self.t_logvar = Parameter(
                self.floatTensor(2 * (self.kernel_size - 1),
                                 self.out_channels))

        if self.add_diagonal:
            self.d_mean = Parameter(
                self.floatTensor(self.kernel_size, self.kernel_size,
                                 min(self.in_channels, self.out_channels)))
            self.d_logvar = Parameter(
                self.floatTensor(self.kernel_size, self.kernel_size,
                                 min(self.in_channels, self.out_channels)))

        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(
                self.floatTensor(self.out_channels // self.groups))
            self.bias_logvar = Parameter(
                self.floatTensor(self.out_channels // self.groups))

        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        if hasattr(self, 'r_mean'):
            torch.nn.init.orthogonal_(self.r_mean)
            self.r_logvar.data.normal_(mean=-5., std=self.prior_std)

        if hasattr(self, 't_mean'):
            torch.nn.init.orthogonal_(self.t_mean)
            self.t_logvar.data.normal_(mean=-5., std=self.prior_std)

        if hasattr(self, 'd_mean'):
            self.d_mean.data.normal_(std=self.prior_std)
            self.d_logvar.data.normal_(mean=-5, std=self.prior_std)

        if hasattr(self, 'bias_mean'):
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(mean=-5., std=self.bias_std)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_mean: float, prior_std: float,
                  mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add((prior_mean - mean)**2).div(prior_std**2).add(
            math.log(2. * math.pi * (prior_std**2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        r = self._eq_logpw(prior_mean=0.,
                           prior_std=self.prior_std,
                           mean=self.r_mean,
                           logvar=self.r_logvar)

        if self.add_diagonal:
            d = self._eq_logpw(prior_mean=0.,
                               prior_std=self.prior_std,
                               mean=self.d_mean,
                               logvar=self.d_logvar)
            logpw = r.add(d)
        else:
            logpw = r

        if not self.simple:
            t = self._eq_logpw(prior_mean=0.,
                               prior_std=self.prior_std,
                               mean=self.t_mean,
                               logvar=self.t_logvar)
            logpw = logpw.add(t)

        if self.use_bias:
            bias = self._eq_logpw(prior_mean=0.,
                                  prior_std=self.prior_std,
                                  mean=self.bias_mean,
                                  logvar=self.bias_logvar)
            logpw = logpw.add(bias)

        return logpw

    def eq_logqw(self):
        r = self._eq_logqw(logvar=self.r_logvar)

        if self.add_diagonal:
            d = self._eq_logqw(logvar=self.d_logvar)
            logqw = r.add(d)
        else:
            logqw = r

        if not self.simple:
            t = self._eq_logqw(logvar=self.t_logvar)
            logqw = logqw.add(t)

        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor:
        p = t[0]
        for i in range(1, t.shape[0]):
            p = p.mm(t[i])
        return p

    def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor:
        norm = t.norm(p=2, dim=1)
        t = t.div(norm.unsqueeze(1))
        h = torch.einsum('ab,ac->abc', (t, t))
        return torch.eye(n=t.shape[1],
                         device=self.device).expand_as(h) - h.mul(2.)

    def _calc_block_orthogonal_tensor(self, size: int, t: torch.Tensor):
        h = torch.zeros(size, 2, 2, t.shape[1], t.shape[2], device=self.device)

        for i in range(size):
            p = t[2 * i]
            q = t[2 * i + 1]
            pq = p.mm(q)
            h[i, 0, 0] = pq
            h[i, 0, 1] = p.sub(pq)
            h[i, 1, 0] = q.sub(pq)
            h[i, 1, 1] = torch.eye(p.shape[0],
                                   device=self.device).add(pq).sub(p).sub(q)
        return h

    def _matrix_conv(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        assert s[0, 0].shape[0] == t[0, 0].shape[0]

        n = s[0, 0].shape[0]
        k = int(np.sqrt(s.shape[0] * s.shape[1]))
        l = int(np.sqrt(t.shape[0] * t.shape[1]))
        size = k + l - 1

        result = torch.zeros(size, size, n, n, device=self.device)

        for i in range(size):
            for j in range(size):
                for index1 in range(min(k, i + 1)):
                    for index2 in range(min(k, j + 1)):
                        if (i - index1) < l and (j - index2) < l:
                            result[i, j] += torch.mm(s[index1, index2],
                                                     t[i - index1, j - index2])
        return result

    def _orthogonal_kernel(self,
                           r: torch.Tensor,
                           t: torch.Tensor,
                           d: torch.Tensor = None,
                           transpose: bool = False) -> torch.Tensor:
        assert self.in_channels <= self.out_channels

        if self.in_channels <= self.out_channels:
            d = torch.einsum('abc,cd->abcd', (d,
                                              torch.eye(n=self.in_channels,
                                                        m=self.out_channels,
                                                        device=self.device)))
        else:
            d = torch.einsum('ab, cda->cdab',
                             (torch.eye(n=self.in_channels,
                                        m=self.out_channels,
                                        device=self.device), d))

        if self.simple:
            q = self._calc_householder_tensor(r).view(self.kernel_size,
                                                      self.kernel_size,
                                                      self.out_channels,
                                                      self.out_channels)
        else:
            r = self._chain_multiply(self._calc_householder_tensor(r))

            if self.kernel_size == 1:
                return torch.unsqueeze(torch.unsqueeze(r, 0), 0)

            t = self._calc_block_orthogonal_tensor(
                size=self.kernel_size - 1, t=self._calc_householder_tensor(t))

            s = t[0]
            for i in range(1, self.kernel_size - 1):
                s = self._matrix_conv(s=s, t=t[i])

            q = torch.einsum('ab,debc->deac', (r, s))

        q = torch.einsum('abcd,abde->abce', (d, q))

        if transpose:
            q = q.permute(2, 3, 1, 0)
        else:
            q = q.permute(3, 2, 1, 0)

        return q

    def forward(self, input_):
        r = self.r_mean.add(
            self.r_logvar.mul(0.5).exp().mul(
                self._sample_eps(self.r_logvar.shape)))

        if self.add_diagonal:
            d = self.d_mean.add(
                self.d_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.d_logvar.shape)))
        else:
            d = torch.ones(self.kernel_size,
                           self.kernel_size,
                           min(self.in_channels, self.out_channels),
                           device=self.device)

        if self.simple:
            weight = self._orthogonal_kernel(r=r, t=None, d=d)
        else:
            t = self.t_mean.add(
                self.t_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.t_logvar.shape)))
            weight = self._orthogonal_kernel(r=r, t=t, d=d)

        bias = None
        if self.use_bias:
            bias = self.bias_mean.add(
                self.bias_logvar.mul(0.5).exp().mul(
                    self._sample_eps(self.bias_logvar.shape)))

        return F.conv2d(input_, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} '
             ', stride={stride}, weight_decay={weight_decay}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', simple={simple}'
        s += ', add_diagonal={add_diagonal}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
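Both orthogonal layers above build their orthogonal factors from the helper _calc_householder_tensor, which maps each parameter row v to a Householder reflection, and _chain_multiply, which forms the product of these reflections (reconstructed from the code):

$$
H(v) \;=\; I-2\,\frac{v\,v^{\top}}{\lVert v\rVert_2^{2}},\qquad
Q \;=\; H(v_1)\,H(v_2)\cdots H(v_m).
$$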