# Imports assumed by the layers in this module.
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Hardshrink, Module, Parameter, Softshrink, init
from torch.nn.modules.conv import _ConvNd as ConvNd
from torch.nn.modules.utils import _pair as pair


class group_relaxed_L1L2Conv2d(Module):
    """Implementation of the group relaxed L1-L2 regularization for the feature maps of a convolutional layer"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, lamba=1., alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: size of the kernel
        :param stride: stride for the convolution
        :param padding: padding for the convolution
        :param dilation: dilation factor for the convolution
        :param groups: how many groups we will assume in the convolution
        :param bias: whether we will use a bias
        :param lamba: strength of the L1-L2 regularization
        """
        super(group_relaxed_L1L2Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        # Auxiliary (relaxed) variable u, updated in closed form by constrain_parameters().
        self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        # Closed-form proximal update of u for the L1-L2 penalty, split by the
        # cases of the prox operator.
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) + self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0).item()
            self.u[self.u != 0] = self.weight[self.u != 0].sign() * self.alpha * self.lamba1 / (n ** (1 / 2))
        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            # The inf-norm is the largest absolute entry, so locate the max of |u|.
            max_idx = np.unravel_index(torch.argmax(self.u.abs().cpu(), None), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.in_channels * self.kernel_size[0] * self.kernel_size[1])
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1)
                          / (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)
        num_instances_per_filter = ((self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0])
                                    / self.stride[0]) + 1
        num_instances_per_filter *= ((self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1])
                                     / self.stride[1]) + 1
        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos
        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
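# Usage sketch (hypothetical shapes and hyperparameters; not part of the
# original module). Training alternates a gradient step on the penalized loss
# with the closed-form update of the auxiliary variable u; regularization()
# returns a log-probability-style term, so it is subtracted from the loss.
#
#   conv = group_relaxed_L1L2Conv2d(3, 16, kernel_size=3, lamba=0.1, beta=4.)
#   out = conv(torch.randn(8, 3, 32, 32))
#   loss = out.pow(2).mean() - conv.regularization()
#   loss.backward()
#   conv.constrain_parameters()  # prox update of u
#   conv.grow_beta(1.1)          # anneal the coupling strength beta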
class group_relaxed_L0Dense(Module):
    """Implementation of the group relaxed L0 regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, lamba=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_features: input dimensionality
        :param out_features: output dimensionality
        :param bias: whether we use a bias
        :param lamba: strength of the L0 regularization
        """
        super(group_relaxed_L0Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        # Auxiliary (relaxed) variable u, updated in closed form by constrain_parameters().
        self.u = torch.rand(in_features, out_features)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.beta = beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        #self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
        # Hard-threshold prox of the L0 penalty: zero every entry whose
        # absolute value is at most sqrt(2*lamba/beta).
        m = Hardshrink((2 * self.lamba / self.beta) ** (1 / 2))
        self.u.data = m(self.weight.data)

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.out_features)
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', lambda: ' \
            + str(self.lamba) + ')'
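# The relaxed L0 prox above is a hard threshold: u keeps an entry of the
# weight only when its magnitude exceeds sqrt(2*lamba/beta). A minimal check
# (hypothetical sizes; not part of the original module):
#
#   fc = group_relaxed_L0Dense(784, 300, lamba=0.5, beta=4.)
#   fc.constrain_parameters()
#   thresh = (2 * fc.lamba / fc.beta) ** 0.5
#   assert (fc.u[fc.weight.data.abs() <= thresh] == 0).all()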
class group_relaxed_L1L2Dense(Module):
    """Implementation of the group relaxed L1-L2 regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, lamba=1., alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_features: input dimensionality
        :param out_features: output dimensionality
        :param bias: whether we use a bias
        :param lamba: strength of the L1-L2 regularization
        """
        super(group_relaxed_L1L2Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        # Auxiliary (relaxed) variable u, updated in closed form by constrain_parameters().
        self.u = torch.rand(in_features, out_features)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        # Closed-form proximal update of u for the L1-L2 penalty, split by the
        # cases of the prox operator.
        norm_w = self.weight.data.norm(p=float('inf'))
        if norm_w > self.lamba1:
            m = Softshrink(self.lamba1)
            z = m(self.weight.data)
            self.u.data = z * (z.data.norm(p=2) + self.alpha * self.lamba1) / (z.data.norm(p=2))
        elif norm_w == self.lamba1:
            self.u = self.weight.clone()
            self.u[self.u.abs() < self.lamba1] = 0
            n = torch.sum(self.u != 0).item()
            self.u[self.u != 0] = self.weight[self.u != 0].sign() * self.alpha * self.lamba1 / (n ** (1 / 2))
        elif (1 - self.alpha) * self.lamba1 < norm_w and norm_w < self.lamba1:
            self.u = self.weight.clone()
            # The inf-norm is the largest absolute entry, so locate the max of |u|.
            max_idx = np.unravel_index(torch.argmax(self.u.abs().cpu(), None), self.u.shape)
            max_value_sign = self.u[max_idx].sign()
            self.u[:] = 0
            self.u[max_idx] = (norm_w + (self.alpha - 1) * self.lamba1) * max_value_sign
        else:
            self.u = self.weight.clone()
            self.u[:] = 0

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.out_features)
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', lambda: ' \
            + str(self.lamba) + ')'
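# The counting helpers report sparsity after training. A hypothetical
# inspection (not part of the original module):
#
#   fc = group_relaxed_L1L2Dense(784, 300, lamba=0.1)
#   fc.constrain_parameters()
#   print(fc.count_zero_u(), '/', fc.count_weight())               # zeros in u
#   print(fc.count_active_neuron(), '/', fc.count_total_neuron())  # surviving input units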
class group_relaxed_SCAD_Dense(Module):
    """Implementation of the group relaxed SCAD regularization for the input units of a fully connected layer"""

    def __init__(self, in_features, out_features, bias=True, lamba=1., alpha=3.7, beta=4.0, weight_decay=1., **kwargs):
        """
        :param in_features: input dimensionality
        :param out_features: output dimensionality
        :param bias: whether we use a bias
        :param lamba: strength of the SCAD regularization
        """
        super(group_relaxed_SCAD_Dense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        # Auxiliary (relaxed) variable u, updated in closed form by constrain_parameters().
        self.u = torch.rand(in_features, out_features)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_out')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def constrain_parameters(self, **kwargs):
        self.u = self.weight.clone()
        s = Softshrink(self.lamba1)
        # Soft-shrink values whose absolute value is at most 2*lamba1.
        shrink_value = s(self.weight.data)
        small = self.weight.abs() <= 2 * self.lamba1
        self.u[small] = shrink_value[small]
        # Modify values whose absolute value lies between 2*lamba1 and alpha*lamba1;
        # values above alpha*lamba1 are left untouched by the initial clone.
        modify_weight = self.weight.data
        modify_weight = ((self.alpha - 1) * modify_weight
                         - modify_weight.sign() * (self.alpha * self.lamba1)) / (self.alpha - 2)
        middle = (self.weight.abs() > 2 * self.lamba1) & (self.weight.abs() <= self.alpha * self.lamba1)
        self.u[middle] = modify_weight[middle]

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.out_features)
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_weight(self):
        return np.prod(self.u.size())

    def count_active_neuron(self):
        return torch.sum(torch.sum(self.weight.abs() / self.out_features, 1) > 1e-5).item()

    def count_total_neuron(self):
        return self.in_features

    def count_expected_flops_and_l0(self):
        ppos = torch.sum(self.weight.abs() > 0.000001).item()
        expected_flops = (2 * ppos - 1) * self.out_features
        expected_l0 = ppos * self.out_features
        if self.bias is not None:
            expected_flops += self.out_features
            expected_l0 += self.out_features
        return expected_flops, expected_l0

    def forward(self, input):
        output = input.mm(self.weight)
        if self.bias is not None:
            output.add_(self.bias.view(1, self.out_features).expand_as(output))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', lambda: ' \
            + str(self.lamba) + ')'
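# For reference, constrain_parameters() above implements the piecewise SCAD
# prox with threshold lamba1 = lamba/beta:
#   |w| <= 2*lamba1:                u = soft-threshold(w, lamba1)
#   2*lamba1 < |w| <= alpha*lamba1: u = ((alpha - 1)*w - sign(w)*alpha*lamba1) / (alpha - 2)
#   |w| > alpha*lamba1:             u = w  (left untouched by the initial clone)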
class group_relaxed_TF1Conv2d(Module):
    """Implementation of the group relaxed TF1 regularization for the feature maps of a convolutional layer"""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, lamba=1., alpha=1., beta=4., weight_decay=1., **kwargs):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: size of the kernel
        :param stride: stride for the convolution
        :param padding: padding for the convolution
        :param dilation: dilation factor for the convolution
        :param groups: how many groups we will assume in the convolution
        :param bias: whether we will use a bias
        :param lamba: strength of the TF1 regularization
        """
        super(group_relaxed_TF1Conv2d, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.output_padding = pair(0)
        self.groups = groups
        self.lamba = lamba
        self.alpha = alpha
        self.beta = beta
        self.lamba1 = self.lamba / self.beta
        self.weight_decay = weight_decay
        self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        # Auxiliary (relaxed) variable u, updated in closed form by constrain_parameters().
        self.u = torch.rand(out_channels, in_channels // groups, *self.kernel_size)
        if torch.cuda.is_available():
            self.u = self.u.to('cuda')
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.input_shape = None
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')
        if self.bias is not None:
            self.bias.data.normal_(0, 1e-2)

    def phi(self, x):
        phi_x = torch.acos(1 - 27 * (self.lamba1 * self.alpha * (self.alpha + 1))
                           / (2 * (self.alpha + x.abs()) ** 3))
        return phi_x

    def g(self, x):
        g_x = x.sign() * (2 / 3 * (self.alpha + x.abs()) * torch.cos(self.phi(x) / 3)
                          - 2 * self.alpha / 3 + x.abs() / 3)
        return g_x

    def constrain_parameters(self, thres_std=1.):
        #self.weight.data = F.normalize(self.weight.data, p=2, dim=1)
        #print(torch.sum(self.weight.pow(2)))
        # TF1 prox: zero entries below the threshold t, map the rest through g.
        # The form of t depends on the size of lamba1 relative to alpha.
        if self.lamba1 <= (self.alpha ** 2) / (2 * (self.alpha + 1)):
            t = self.lamba1 * (self.alpha + 1) / self.alpha
        else:
            t = np.sqrt(2 * self.lamba1 * (self.alpha + 1)) - self.alpha / 2
        self.u.data = self.weight.data.clone()
        self.u.data[self.u.data.abs() <= t] = 0
        g_result = self.g(self.u)
        self.u.data[self.u.data.abs() > t] = g_result[self.u.data.abs() > t]

    def grow_beta(self, growth_factor):
        self.beta = self.beta * growth_factor
        self.lamba1 = self.lamba / self.beta

    def _reg_w(self, **kwargs):
        logpw = (-self.beta * torch.sum(0.5 * self.weight.add(-self.u).pow(2))
                 - self.lamba * np.sqrt(self.in_channels * self.kernel_size[0] * self.kernel_size[1])
                 * torch.sum(torch.pow(torch.sum(self.weight.pow(2), 3).sum(2).sum(1), 0.5)))
        logpb = 0
        if self.bias is not None:
            logpb = -torch.sum(self.weight_decay * .5 * (self.bias.pow(2)))
        return logpw + logpb

    def regularization(self):
        return self._reg_w()

    def count_zero_u(self):
        total = np.prod(self.u.size())
        zero = total - self.u.nonzero().size(0)
        return zero

    def count_zero_w(self):
        return torch.sum((self.weight.abs() < 1e-5).int()).item()

    def count_active_neuron(self):
        return torch.sum((torch.sum(self.weight.abs(), 3).sum(2).sum(1)
                          / (self.in_channels * self.kernel_size[0] * self.kernel_size[1])) > 1e-5).item()

    def count_total_neuron(self):
        return self.out_channels

    def count_weight(self):
        return np.prod(self.u.size())

    def count_expected_flops_and_l0(self):
        #ppos = self.out_channels
        ppos = torch.sum(torch.sum(self.weight.abs(), 3).sum(2).sum(1) > 0.001).item()
        n = self.kernel_size[0] * self.kernel_size[1] * self.in_channels
        flops_per_instance = n + (n - 1)
        num_instances_per_filter = ((self.input_shape[1] - self.kernel_size[0] + 2 * self.padding[0])
                                    / self.stride[0]) + 1
        num_instances_per_filter *= ((self.input_shape[2] - self.kernel_size[1] + 2 * self.padding[1])
                                     / self.stride[1]) + 1
        flops_per_filter = num_instances_per_filter * flops_per_instance
        expected_flops = flops_per_filter * ppos
        expected_l0 = n * ppos
        if self.bias is not None:
            expected_flops += num_instances_per_filter * ppos
            expected_l0 += ppos
        return expected_flops, expected_l0

    def forward(self, input_):
        if self.input_shape is None:
            self.input_shape = input_.size()
        output = F.conv2d(input_, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
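# Usage sketch for the TF1 layer (hypothetical hyperparameters; not part of
# the original module). Entries of the weight with |w| <= t are zeroed in u,
# and the rest are mapped through g(.) defined above:
#
#   conv = group_relaxed_TF1Conv2d(3, 16, kernel_size=3, lamba=0.1, alpha=1.)
#   conv.constrain_parameters()
#   print(conv.count_zero_u(), '/', conv.count_weight())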
class OrthogonalBayesianDense(Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 order: int = 8,
                 use_bias: bool = True,
                 add_diagonal: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-2,
                 **kwargs):
        super(OrthogonalBayesianDense, self).__init__()
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
        self.in_features = in_features
        self.out_features = out_features
        self.add_diagonal = add_diagonal
        self.prior_std = prior_std
        self.bias_std = bias_std
        max_feature = max(self.in_features, self.out_features)
        self.order = order or max_feature
        assert 1 <= self.order <= max_feature
        self.v_mean = Parameter(self.floatTensor(self.order, max_feature))
        self.v_logvar = Parameter(self.floatTensor(self.order, max_feature))
        if self.add_diagonal:
            self.d_mean = Parameter(self.floatTensor(min(self.in_features, self.out_features)))
            self.d_logvar = Parameter(self.floatTensor(min(self.in_features, self.out_features)))
        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(self.floatTensor(out_features))
            self.bias_logvar = Parameter(self.floatTensor(out_features))
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        self.v_mean.data.normal_(std=self.prior_std)
        self.v_logvar.data.normal_(mean=-5., std=self.prior_std)
        if self.add_diagonal:
            self.d_mean.data.normal_(std=self.prior_std)
            self.d_logvar.data.normal_(mean=-5., std=self.prior_std)
        if self.use_bias:
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(mean=-5., std=self.bias_std)

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_mean: float, prior_std: float, mean: torch.Tensor,
                  logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add((prior_mean - mean) ** 2).div(prior_std ** 2).add(
            math.log(2. * math.pi * (prior_std ** 2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        v = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.v_mean, logvar=self.v_logvar)
        if self.add_diagonal:
            d = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.d_mean, logvar=self.d_logvar)
            logpw = v.add(d)
        else:
            logpw = v
        if self.use_bias:
            bias = self._eq_logpw(prior_mean=0., prior_std=self.prior_std,
                                  mean=self.bias_mean, logvar=self.bias_logvar)
            # Tensor.add is out of place, so the result must be reassigned.
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        v = self._eq_logqw(logvar=self.v_logvar)
        if self.add_diagonal:
            d = self._eq_logqw(logvar=self.d_logvar)
            logqw = v.add(d)
        else:
            logqw = v
        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor:
        p = t[0]
        for i in range(1, t.shape[0]):
            p = p.mm(t[i])
        return p

    def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor:
        # Map each row v to the Householder reflection I - 2 v v^T / ||v||^2.
        norm = t.norm(p=2, dim=1)
        t = t.div(norm.unsqueeze(1))
        h = torch.einsum('ab,ac->abc', (t, t))
        return torch.eye(n=t.shape[1], device=self.device).expand_as(h) - h.mul(2.)

    def _calc_weights(self, v: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
        u = self._chain_multiply(self._calc_householder_tensor(v))
        if self.out_features <= self.in_features:
            D = torch.eye(n=self.in_features, m=self.out_features, device=self.device).mm(torch.diag(d))
            W = u.mm(D)
        else:
            D = torch.diag(d).mm(torch.eye(n=self.in_features, m=self.out_features, device=self.device))
            W = D.mm(u)
        return W

    def forward(self, input: torch.Tensor):
        v = self.v_mean.add(self.v_logvar.mul(0.5).exp().mul(self._sample_eps(self.v_logvar.shape)))
        if self.add_diagonal:
            d = self.d_mean.add(self.d_logvar.mul(0.5).exp().mul(self._sample_eps(self.d_logvar.shape)))
        else:
            d = torch.ones(min(self.in_features, self.out_features), device=self.device)
        w = self._calc_weights(v=v, d=d)
        y = input.mm(w)
        if self.use_bias:
            bias = self.bias_mean.add(self.bias_logvar.mul(0.5).exp().mul(self._sample_eps(self.bias_logvar.shape)))
            return y.add(bias.view(1, self.out_features))
        return y

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', order: ' \
            + str(self.order) + ', add_diagonal: ' \
            + str(self.add_diagonal) + ')'
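# The weight above is assembled as a product of Householder reflections (each
# row of v gives one reflection I - 2*v*v^T/||v||^2) times the diagonal d, so
# it is orthogonal up to the scaling in d. Hypothetical usage (not part of the
# original module); each forward pass resamples v and d:
#
#   layer = OrthogonalBayesianDense(64, 32, order=4)
#   y = layer(torch.randn(16, 64))
#   kl_term = layer.kldiv()  # E_q[log p(w)] - E_q[log q(w)]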
class KernelBayesianConv2(ConvNd):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 dim: int = 2,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 use_bias: bool = True,
                 prior_std: float = 1.,
                 bias_std: float = 1e-3,
                 **kwargs):
        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.dim = dim
        self.kernel_size = pair(kernel_size)
        self.use_bias = use_bias
        self.prior_std = prior_std
        self.bias_std = bias_std
        super(KernelBayesianConv2, self).__init__(in_channels, out_channels, self.kernel_size, stride,
                                                  padding, dilation, False, pair(0), groups, use_bias)
        self.columns_mean = Parameter(
            self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)), self.dim))
        self.columns_logvar = Parameter(
            self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)), self.dim))
        self.rows_mean = Parameter(
            self.floatTensor(self.out_channels * int(np.prod(self.kernel_size)) // groups, self.dim))
        self.rows_logvar = Parameter(
            self.floatTensor(self.out_channels * int(np.prod(self.kernel_size)) // groups, self.dim))
        self.alpha_mean = Parameter(self.floatTensor(self.out_channels // groups, self.in_channels))
        self.alpha_logvar = Parameter(self.floatTensor(self.out_channels // groups, self.in_channels))
        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(self.floatTensor(self.out_channels // self.groups))
            self.bias_logvar = Parameter(self.floatTensor(self.out_channels // self.groups))
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        init.kaiming_normal_(self.weight, mode='fan_in')
        if hasattr(self, 'columns_mean'):
            self.columns_mean.data.normal_(std=self.prior_std)
            self.columns_logvar.data.normal_(std=self.prior_std)
        if hasattr(self, 'rows_mean'):
            self.rows_mean.data.normal_(std=self.prior_std)
            self.rows_logvar.data.normal_(std=self.prior_std)
        if hasattr(self, 'alpha_mean'):
            self.alpha_mean.data.normal_(std=self.prior_std)
            self.alpha_logvar.data.normal_(std=self.prior_std)
        if hasattr(self, 'bias_mean'):
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(std=self.bias_std)

    def _calc_rbf_weights(self, rows: torch.Tensor, columns: torch.Tensor,
                          alpha: torch.Tensor) -> torch.Tensor:
        # Each spatial position i of the kernel is an RBF Gram matrix between
        # the per-output-channel embeddings (rows) and per-input-channel
        # embeddings (columns), scaled elementwise by alpha.
        w = self.floatTensor(self.out_channels // self.groups, self.in_channels, np.prod(self.kernel_size))
        for i in range(int(np.prod(self.kernel_size))):
            row_start = i * (self.out_channels // self.groups)
            row_stop = (i + 1) * (self.out_channels // self.groups)
            col_start = i * self.in_channels
            col_stop = (i + 1) * self.in_channels
            x2 = rows[row_start:row_stop, :].pow(2).sum(dim=1).view(self.out_channels // self.groups, 1)
            y2 = columns[col_start:col_stop, :].pow(2).sum(dim=1).view(1, self.in_channels)
            xy = rows[row_start:row_stop, :].mm(columns[col_start:col_stop, :].t()).mul(-2.)
            w[:, :, i] = x2.add(y2).add(xy).mul(-1).exp().mul(alpha)
        return w.view(self.out_channels // self.groups, self.in_channels, *self.kernel_size)

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_std: float, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add(mean ** 2).div(prior_std ** 2).add(
            math.log(2. * math.pi * (prior_std ** 2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        rows = self._eq_logpw(prior_std=self.prior_std, mean=self.rows_mean, logvar=self.rows_logvar)
        columns = self._eq_logpw(prior_std=self.prior_std, mean=self.columns_mean, logvar=self.columns_logvar)
        alpha = self._eq_logpw(prior_std=self.prior_std, mean=self.alpha_mean, logvar=self.alpha_logvar)
        logpw = rows.add(columns).add(alpha)
        if self.use_bias:
            bias = self._eq_logpw(prior_std=self.prior_std, mean=self.bias_mean, logvar=self.bias_logvar)
            # Tensor.add is out of place, so the result must be reassigned.
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        rows = self._eq_logqw(logvar=self.rows_logvar)
        columns = self._eq_logqw(logvar=self.columns_logvar)
        alpha = self._eq_logqw(logvar=self.alpha_logvar)
        logqw = rows.add(columns).add(alpha)
        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def forward(self, input: torch.Tensor):
        rows = self.rows_mean
        columns = self.columns_mean
        alpha = self.alpha_mean
        if self.training:
            rows = self.rows_mean.add(self.rows_logvar.mul(0.5).exp().mul(self._sample_eps(rows.shape)))
            columns = self.columns_mean.add(
                self.columns_logvar.mul(0.5).exp().mul(self._sample_eps(columns.shape)))
            alpha = self.alpha_mean.add(self.alpha_logvar.mul(0.5).exp().mul(self._sample_eps(alpha.shape)))
        weight = self._calc_rbf_weights(rows=rows, columns=columns, alpha=alpha)
        # The bias parameters only exist when use_bias is True.
        bias = None
        if self.use_bias:
            bias = self.bias_mean
            if self.training:
                bias = self.bias_mean.add(
                    self.bias_logvar.mul(0.5).exp().mul(self._sample_eps(self.bias_logvar.shape)))
        return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, dim={dim}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
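# In _calc_rbf_weights, tap i of the kernel is the scaled RBF Gram matrix
# w[o, c, i] = alpha[o, c] * exp(-||rows_o - columns_c||^2) between the
# dim-dimensional embeddings of output and input channels. Hypothetical usage
# (not part of the original module):
#
#   conv = KernelBayesianConv2(3, 16, kernel_size=3, dim=2)
#   out = conv(torch.randn(8, 3, 32, 32))
#   kl_term = conv.kldiv()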
class OrthogonalBayesianConv2d(ConvNd):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=True,
                 simple=True,
                 add_diagonal=True,
                 weight_decay=1.,
                 prior_std=1.,
                 bias_std=1e-2,
                 **kwargs):
        kernel_size = pair(kernel_size)
        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)
        self.weight_decay = weight_decay
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
        self.simple = simple
        self.add_diagonal = add_diagonal
        super(OrthogonalBayesianConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                                       padding, dilation, False, pair(0), groups, use_bias)
        self.in_channels = in_channels
        self.out_channels = out_channels // groups
        self.kernel_size = kernel_size[0]
        self.prior_std = prior_std
        self.bias_std = bias_std
        if simple:
            self.r_mean = Parameter(self.floatTensor(self.kernel_size * self.kernel_size, self.out_channels))
            self.r_logvar = Parameter(self.floatTensor(self.kernel_size * self.kernel_size, self.out_channels))
        else:
            self.r_mean = Parameter(self.floatTensor(2, self.out_channels))
            self.r_logvar = Parameter(self.floatTensor(2, self.out_channels))
            self.t_mean = Parameter(self.floatTensor(2 * (self.kernel_size - 1), self.out_channels))
            self.t_logvar = Parameter(self.floatTensor(2 * (self.kernel_size - 1), self.out_channels))
        if self.add_diagonal:
            self.d_mean = Parameter(self.floatTensor(self.kernel_size, self.kernel_size,
                                                     min(self.in_channels, self.out_channels)))
            self.d_logvar = Parameter(self.floatTensor(self.kernel_size, self.kernel_size,
                                                       min(self.in_channels, self.out_channels)))
        self.use_bias = use_bias
        if self.use_bias:
            self.bias_mean = Parameter(self.floatTensor(self.out_channels // self.groups))
            self.bias_logvar = Parameter(self.floatTensor(self.out_channels // self.groups))
        self.reset_parameters()
        print(self)

    def reset_parameters(self):
        if hasattr(self, 'r_mean'):
            torch.nn.init.orthogonal_(self.r_mean)
            self.r_logvar.data.normal_(mean=-5., std=self.prior_std)
        if hasattr(self, 't_mean'):
            torch.nn.init.orthogonal_(self.t_mean)
            self.t_logvar.data.normal_(mean=-5., std=self.prior_std)
        if hasattr(self, 'd_mean'):
            self.d_mean.data.normal_(std=self.prior_std)
            self.d_logvar.data.normal_(mean=-5., std=self.prior_std)
        if hasattr(self, 'bias_mean'):
            self.bias_mean.data.normal_(std=self.bias_std)
            self.bias_logvar.data.normal_(mean=-5., std=self.bias_std)

    def constrain_parameters(self, thres_std=1.):
        pass

    def _sample_eps(self, shape: tuple):
        return Variable(self.floatTensor(shape).normal_())

    def _eq_logpw(self, prior_mean: float, prior_std: float, mean: torch.Tensor,
                  logvar: torch.Tensor) -> torch.Tensor:
        logpw = logvar.exp().add((prior_mean - mean) ** 2).div(prior_std ** 2).add(
            math.log(2. * math.pi * (prior_std ** 2))).mul(-0.5)
        return torch.sum(logpw)

    def _eq_logqw(self, logvar: torch.Tensor):
        logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5)
        return torch.sum(logqw)

    def eq_logpw(self) -> torch.Tensor:
        r = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.r_mean, logvar=self.r_logvar)
        if self.add_diagonal:
            d = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.d_mean, logvar=self.d_logvar)
            logpw = r.add(d)
        else:
            logpw = r
        if not self.simple:
            t = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.t_mean, logvar=self.t_logvar)
            logpw = logpw.add(t)
        if self.use_bias:
            bias = self._eq_logpw(prior_mean=0., prior_std=self.prior_std,
                                  mean=self.bias_mean, logvar=self.bias_logvar)
            # Tensor.add is out of place, so the result must be reassigned.
            logpw = logpw.add(bias)
        return logpw

    def eq_logqw(self):
        r = self._eq_logqw(logvar=self.r_logvar)
        if self.add_diagonal:
            d = self._eq_logqw(logvar=self.d_logvar)
            logqw = r.add(d)
        else:
            logqw = r
        if not self.simple:
            t = self._eq_logqw(logvar=self.t_logvar)
            logqw = logqw.add(t)
        if self.use_bias:
            bias = self._eq_logqw(logvar=self.bias_logvar)
            logqw = logqw.add(bias)
        return logqw

    def kldiv(self):
        return self.eq_logpw() - self.eq_logqw()

    def kldiv_aux(self) -> float:
        return 0.

    def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor:
        p = t[0]
        for i in range(1, t.shape[0]):
            p = p.mm(t[i])
        return p

    def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor:
        # Map each row v to the Householder reflection I - 2 v v^T / ||v||^2.
        norm = t.norm(p=2, dim=1)
        t = t.div(norm.unsqueeze(1))
        h = torch.einsum('ab,ac->abc', (t, t))
        return torch.eye(n=t.shape[1], device=self.device).expand_as(h) - h.mul(2.)

    def _calc_block_orthogonal_tensor(self, size: int, t: torch.Tensor):
        h = torch.zeros(size, 2, 2, t.shape[1], t.shape[2], device=self.device)
        for i in range(size):
            p = t[2 * i]
            q = t[2 * i + 1]
            pq = p.mm(q)
            h[i, 0, 0] = pq
            h[i, 0, 1] = p.sub(pq)
            h[i, 1, 0] = q.sub(pq)
            h[i, 1, 1] = torch.eye(p.shape[0], device=self.device).add(pq).sub(p).sub(q)
        return h

    def _matrix_conv(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        assert s[0, 0].shape[0] == t[0, 0].shape[0]
        n = s[0, 0].shape[0]
        k = int(np.sqrt(s.shape[0] * s.shape[1]))
        l = int(np.sqrt(t.shape[0] * t.shape[1]))
        size = k + l - 1
        result = torch.zeros(size, size, n, n, device=self.device)
        for i in range(size):
            for j in range(size):
                for index1 in range(min(k, i + 1)):
                    for index2 in range(min(k, j + 1)):
                        if (i - index1) < l and (j - index2) < l:
                            result[i, j] += torch.mm(s[index1, index2], t[i - index1, j - index2])
        return result

    def _orthogonal_kernel(self, r: torch.Tensor, t: torch.Tensor, d: torch.Tensor = None,
                           transpose: bool = False) -> torch.Tensor:
        assert self.in_channels <= self.out_channels
        if self.in_channels <= self.out_channels:
            d = torch.einsum('abc,cd->abcd',
                             (d, torch.eye(n=self.in_channels, m=self.out_channels, device=self.device)))
        else:
            d = torch.einsum('ab,cda->cdab',
                             (torch.eye(n=self.in_channels, m=self.out_channels, device=self.device), d))
        if self.simple:
            q = self._calc_householder_tensor(r).view(self.kernel_size, self.kernel_size,
                                                      self.out_channels, self.out_channels)
        else:
            r = self._chain_multiply(self._calc_householder_tensor(r))
            if self.kernel_size == 1:
                return torch.unsqueeze(torch.unsqueeze(r, 0), 0)
            t = self._calc_block_orthogonal_tensor(size=self.kernel_size - 1,
                                                   t=self._calc_householder_tensor(t))
            s = t[0]
            for i in range(1, self.kernel_size - 1):
                s = self._matrix_conv(s=s, t=t[i])
            q = torch.einsum('ab,debc->deac', (r, s))
        q = torch.einsum('abcd,abde->abce', (d, q))
        if transpose:
            q = q.permute(2, 3, 1, 0)
        else:
            q = q.permute(3, 2, 1, 0)
        return q

    def forward(self, input_):
        r = self.r_mean.add(self.r_logvar.mul(0.5).exp().mul(self._sample_eps(self.r_logvar.shape)))
        if self.add_diagonal:
            d = self.d_mean.add(self.d_logvar.mul(0.5).exp().mul(self._sample_eps(self.d_logvar.shape)))
        else:
            d = torch.ones(self.kernel_size, self.kernel_size,
                           min(self.in_channels, self.out_channels), device=self.device)
        if self.simple:
            weight = self._orthogonal_kernel(r=r, t=None, d=d)
        else:
            t = self.t_mean.add(self.t_logvar.mul(0.5).exp().mul(self._sample_eps(self.t_logvar.shape)))
            weight = self._orthogonal_kernel(r=r, t=t, d=d)
        # The bias parameters only exist when use_bias is True.
        bias = None
        if self.use_bias:
            bias = self.bias_mean.add(
                self.bias_logvar.mul(0.5).exp().mul(self._sample_eps(self.bias_logvar.shape)))
        return F.conv2d(input_, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, weight_decay={weight_decay}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', simple={simple}'
        s += ', add_diagonal={add_diagonal}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
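# Hypothetical usage (not part of the original module). The kernel is built
# from per-position Householder reflections when simple=True, or from the
# block-convolution construction when simple=False; _orthogonal_kernel asserts
# in_channels <= out_channels:
#
#   conv = OrthogonalBayesianConv2d(16, 32, kernel_size=3, simple=True)
#   out = conv(torch.randn(8, 16, 28, 28))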