class RiemannLayer(nn.Module):
    def __init__(self, in_features, out_features, manifold, over_param):
        super(RiemannLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Stored as ``_weight`` so it does not clash with the read-only
        # ``weight`` property defined below.
        self._weight = Parameter(torch.Tensor(out_features, in_features))
        self.over_param = over_param
        if self.over_param:
            self._bias = ManifoldParameter(torch.Tensor(out_features, in_features),
                                           manifold=manifold)
        else:
            self._bias = Parameter(torch.Tensor(out_features, 1))
        self.manifold = manifold
        self.reset_parameters()

    @property
    def weight(self):
        return self.manifold.transp0(self._bias, self._weight)

    @property
    def bias(self):
        return self.manifold.expmap0(self.weight.mul(self._bias))

    def reset_parameters(self):
        # Initialise the underlying parameters directly; ``weight`` and ``bias``
        # are derived tensors and cannot be initialised in place.
        nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))
        if self._bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._weight)
            bound = 4 / math.sqrt(fan_in)
            nn.init.uniform_(self._bias, -bound, bound)
        if self.over_param:
            with torch.no_grad():
                self._bias.set_(self.manifold.expmap0(self._bias))
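
# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hedged example of constructing RiemannLayer.  It assumes a manifold
# object that exposes ``expmap0``/``transp0`` and a ``ManifoldParameter`` wrapper
# as used above; geoopt's PoincareBall provides this API, but the original code
# may target a different manifold class.
def _riemann_layer_demo():
    import geoopt  # assumption: geoopt supplies the manifold / ManifoldParameter API

    ball = geoopt.PoincareBall(c=1.0)
    layer = RiemannLayer(in_features=8, out_features=4, manifold=ball, over_param=True)
    # ``weight`` and ``bias`` are recomputed from the underlying parameters on
    # every access (parallel transport / exponential map at the origin).
    return layer.weight.shape, layer.bias.shape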
class ConvexNet(Module):
    __constants__ = ['width']

    def __init__(self, width):
        super(ConvexNet, self).__init__()
        self.width = width
        self.weight = Parameter(torch.Tensor(2, width))
        self.bias = Parameter(torch.Tensor(width))

    def set_value(self, lines: List[Line]):
        assert len(lines) == self.width
        weight = [[], []]
        bias = []
        for line in lines:
            weight[0].append(line.a)
            weight[1].append(line.b)
            bias.append(line.c)
        self.weight = Parameter(torch.Tensor(weight))
        self.bias = Parameter(torch.Tensor(bias))

    def get_value(self) -> List[Line]:
        ret = []
        weight_list = self.weight.tolist()
        bias_list = self.bias.tolist()
        for i in range(self.width):
            ret.append(Line(weight_list[0][i], weight_list[1][i], bias_list[i]))
        return ret

    def forward(self, input):
        # Column i of ``weight`` and entry i of ``bias`` encode the line
        # a_i * x + b_i * y + c_i = 0.
        length2 = torch.sum(self.weight.mul(self.weight), dim=0)
        length = torch.sqrt(length2)
        # Signed distance of every input point to every line.
        approx_dis = torch.addmm(self.bias, input, self.weight)
        real_dis = torch.div(approx_dis, length)
        # Take the maximum signed distance over all lines and squash it with a
        # sharp sigmoid to get a soft indicator.
        max_dis, _ = torch.max(real_dis, dim=1)
        ret = torch.sigmoid(max_dis * 1e2)
        return ret
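
# --- Usage sketch (not part of the original module) ---------------------------
# ``Line`` is only referenced above; a minimal stand-in with the fields the code
# reads (a, b, c for the line a*x + b*y + c = 0) is assumed here purely for
# illustration.
def _convex_net_demo():
    from collections import namedtuple
    LineDemo = namedtuple('LineDemo', ['a', 'b', 'c'])  # hypothetical stand-in for Line

    net = ConvexNet(width=3)
    # Three example half-planes.
    net.set_value([LineDemo(1.0, 0.0, -1.0),
                   LineDemo(0.0, 1.0, -1.0),
                   LineDemo(-1.0, -1.0, 0.0)])
    points = torch.tensor([[0.5, 0.5], [2.0, 2.0]])
    return net(points)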
class NoisyConv2d(Module):
    """Applies a noisy conv2d transformation to the incoming data.

    More details can be found in the paper `Noisy Networks for Exploration`_ .

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: size of the convolving kernel
        bias: If set to False, the layer will not learn an additive bias.
            Default: True
        factorised: whether or not to use factorised noise. Default: True
        std_init: initialization constant for the standard-deviation component
            of the weights. If None, defaults to 0.017 for independent and 0.4
            for factorised noise. Default: None

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`

    Attributes:
        weight_mu, weight_sigma: the learnable weight parameters of the module,
            of shape (out_channels x in_channels // groups x *kernel_size)
        bias_mu, bias_sigma: the learnable bias parameters of the module,
            of shape (out_channels)

    Examples::

        >>> m = NoisyConv2d(4, 2, (3, 1))
        >>> input = torch.autograd.Variable(torch.randn(1, 4, 51, 3))
        >>> output = m(input)
        >>> print(output.size())
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True,
                 stride=1, padding=1, dilation=1, groups=1, factorised=True,
                 std_init=None, gpu_id=0):
        super(NoisyConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.factorised = factorised
        self.weight_mu = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        self.weight_sigma = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        self.gpu_id = gpu_id
        if bias:
            self.bias_mu = Parameter(torch.Tensor(out_channels))
            self.bias_sigma = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        if not std_init:
            self.std_init = 0.4 if self.factorised else 0.017
        else:
            self.std_init = std_init
        self.reset_parameters(bias)

    def reset_parameters(self, bias):
        if self.factorised:
            mu_range = 1. / math.sqrt(self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(
                self.std_init / math.sqrt(self.weight_sigma.size(1)))
            if bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(
                    self.std_init / math.sqrt(self.bias_sigma.size(0)))
        else:
            mu_range = math.sqrt(3. / self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(self.std_init)
            if bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(self.std_init)

    def scale_noise(self, size):
        with torch.cuda.device(self.gpu_id):
            x = torch.Tensor(size).normal_().cuda()
            x = x.sign().mul(x.abs().sqrt())
        return x

    def forward(self, input):
        if self.factorised:
            # Build factorised noise as an outer product over the weight dims.
            epsilon = None
            for dim in self.weight_sigma.size():
                if epsilon is None:
                    epsilon = self.scale_noise(dim)
                else:
                    epsilon = epsilon.unsqueeze(-1) * self.scale_noise(dim)
            weight_epsilon = Variable(epsilon)
            bias_epsilon = Variable(self.scale_noise(self.out_channels))
        else:
            with torch.cuda.device(self.gpu_id):
                weight_epsilon = Variable(
                    torch.Tensor(self.out_channels, self.in_channels,
                                 *self.kernel_size).normal_()).cuda()
                bias_epsilon = Variable(
                    torch.Tensor(self.out_channels).normal_()).cuda()
        bias = None
        if self.bias_mu is not None:
            bias = self.bias_mu + self.bias_sigma.mul(bias_epsilon)
        return F.conv2d(input,
                        self.weight_mu + self.weight_sigma.mul(weight_epsilon),
                        bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias_mu is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
class NoisyLinear(Module):
    r"""Applies a noisy linear transformation to the incoming data.

    During training:

    .. math:: y = (\mu_w + \sigma_w \cdot \epsilon_w) x + \mu_b + \sigma_b \cdot \epsilon_b

    During evaluation:

    .. math:: y = \mu_w x + \mu_b

    More details can be found in the paper `Noisy Networks for Exploration`_ .

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias.
            Default: True
        factorized: whether or not to use factorized noise. Default: True
        std_init: constant for weight_sigma and bias_sigma initialization.
            If None, defaults to 0.017 for independent and 0.4 for factorized
            noise. Default: None

    Shape:
        - Input: :math:`(N, in\_features)`
        - Output: :math:`(N, out\_features)`

    Attributes:
        weight_mu, weight_sigma: the learnable weight parameters of the module,
            of shape (out_features x in_features)
        bias_mu, bias_sigma: the learnable bias parameters of the module,
            of shape (out_features)

    Methods:
        resample: resamples the noise tensors

    Examples::

        >>> m = NoisyLinear(20, 30)
        >>> input = torch.autograd.Variable(torch.randn(128, 20))
        >>> output = m(input)
        >>> m.resample()
        >>> output_new = m(input)
        >>> print(output)
        >>> print(output_new)
    """

    def __init__(self, in_features, out_features, bias=True, factorized=True,
                 std_init=None):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.factorized = factorized
        self.include_bias = bias
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.register_buffer('weight_epsilon',
                             torch.Tensor(out_features, in_features))
        if self.include_bias:
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_sigma = Parameter(torch.Tensor(out_features))
            self.register_buffer('bias_epsilon', torch.Tensor(out_features))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        if not std_init:
            self.std_init = 0.4 if self.factorized else 0.017
        else:
            self.std_init = std_init
        self.reset_parameters()
        self.resample()

    def reset_parameters(self):
        if self.factorized:
            mu_range = 1. / math.sqrt(self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(
                self.std_init / math.sqrt(self.weight_sigma.size(1)))
            if self.include_bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(
                    self.std_init / math.sqrt(self.bias_sigma.size(0)))
        else:
            mu_range = math.sqrt(3. / self.weight_mu.size(1))
            self.weight_mu.data.uniform_(-mu_range, mu_range)
            self.weight_sigma.data.fill_(self.std_init)
            if self.include_bias:
                self.bias_mu.data.uniform_(-mu_range, mu_range)
                self.bias_sigma.data.fill_(self.std_init)

    def _scale_noise(self, size):
        x = torch.randn(size)
        x = x.sign().mul(x.abs().sqrt())
        return x

    def resample(self):
        if self.factorized:
            # Factorized Gaussian noise: outer product of per-input and
            # per-output noise vectors.
            epsilon_in = self._scale_noise(self.in_features)
            epsilon_out = self._scale_noise(self.out_features)
            self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
            if self.include_bias:
                self.bias_epsilon.copy_(self._scale_noise(self.out_features))
        else:
            self.weight_epsilon.normal_()
            if self.include_bias:
                self.bias_epsilon.normal_()

    def forward(self, input):
        if self.training:
            weight = self.weight_mu + self.weight_sigma.mul(Variable(self.weight_epsilon))
            bias = None
            if self.include_bias:
                bias = self.bias_mu + self.bias_sigma.mul(Variable(self.bias_epsilon))
            return F.linear(input, weight, bias)
        else:
            return F.linear(input, self.weight_mu, self.bias_mu)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', Factorized: ' \
            + str(self.factorized) + ')'
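
# --- Usage sketch (not part of the original module) ---------------------------
# A hedged illustration of how the noise buffers are typically refreshed: call
# ``resample()`` before a forward pass during training so the exploration noise
# changes between updates.  The tiny network and data below are made up for the
# example.
def _noisy_linear_demo():
    layer = NoisyLinear(4, 2, factorized=True)
    optimizer = torch.optim.SGD(layer.parameters(), lr=1e-2)
    x = torch.randn(8, 4)
    target = torch.randn(8, 2)
    for _ in range(3):
        layer.resample()              # draw fresh epsilon_w / epsilon_b
        loss = F.mse_loss(layer(x), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    layer.eval()                      # evaluation uses the mean weights only
    return layer(x)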
class HSDense(Module): def __init__(self, in_features, out_features, bias=True, prior_std=1., prior_std_z=1., dof=1., **kwargs): super(HSDense, self).__init__() self.in_features = in_features self.out_features = out_features self.prior_std = prior_std self.mean_w = Parameter(torch.Tensor(in_features, out_features)) self.logvar_w = Parameter(torch.Tensor(in_features, out_features)) self.qz_mean = Parameter(torch.Tensor(in_features)) self.qz_logvar = Parameter(torch.Tensor(in_features)) self.dof = dof self.prior_std_z = prior_std_z self.use_bias = False if bias: self.mean_bias = Parameter(torch.Tensor(out_features)) self.logvar_bias = Parameter(torch.Tensor(out_features)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_out') self.logvar_w.data.normal_(-9., 1e-4) self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3) self.qz_logvar.data.normal_(math.log(0.1), 1e-4) if self.use_bias: self.mean_bias.data.normal_(0, 1e-2) self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = - .5 * math.log(2 * math.pi * self.prior_std ** 2) - .5 * self.logvar_bias.exp().div \ (self.prior_std ** 2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) return torch.sum(logpw) + torch.sum(logpb) def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): z = self.sample_z(1) z = z.view(self.in_features) logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1)) logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log())) logpm = torch.sum( 2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) - math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) - .5 * (self.dof + 1) * torch.log(1. 
+ z.pow(2) / (self.dof * self.prior_std_z**2))) return logpm - logqm def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_z(self, batch_size): z = self.qz_mean.view(1, self.in_features) if self.training: eps = self.get_eps(self.floatTensor(batch_size, self.in_features)) z = z + eps.mul( self.qz_logvar.view(1, self.in_features).mul(0.5).exp_()) return F.softplus(z) def sample_W(self): W = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return W def sample_b(self): b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def get_mean_x(self, input): mean_xin = input.mm(self.mean_w) if self.use_bias: mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features)) return mean_xin def get_var_x(self, input): var_xin = input.pow(2).mm(self.logvar_w.exp()) if self.use_bias: var_xin = var_xin.add(self.logvar_bias.exp().view( 1, self.out_features)) return var_xin def forward(self, input): # sampling batch_size = input.size(0) z = self.sample_z(batch_size) xin = input.mul(z) mean_xin = self.get_mean_x(xin) output = mean_xin if self.training: var_xin = self.get_var_x(xin) eps = self.get_eps(self.floatTensor(batch_size, self.out_features)) output = output.add(var_xin.sqrt().mul(eps)) return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', dof: ' \ + str(self.dof) + ', prior_std_z: ' \ + str(self.prior_std_z) + ')'
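
# --- Usage sketch (not part of the original module) ---------------------------
# The Bayesian layers above expose ``kldiv()``; a common way to train them is to
# minimise a negative ELBO: the data loss plus the per-layer KL terms scaled by
# the dataset size.  This sketch is illustrative only; layer sizes, data and the
# dataset size are made up.
def _hsdense_elbo_demo():
    model = nn.Sequential(HSDense(16, 32), nn.ReLU(), HSDense(32, 1))
    x, y = torch.randn(64, 16), torch.randn(64, 1)
    n_train = 10000  # assumed dataset size used to weight the KL term

    nll = F.mse_loss(model(x), y)
    kl = sum(m.kldiv() for m in model.modules() if hasattr(m, 'kldiv'))
    # kldiv() returns E[log p] - E[log q] (i.e. minus the KL divergence),
    # so it enters the loss with a minus sign.
    loss = nll - kl / n_train
    return loss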
class KernelDenseBayesian(Module): def __init__(self, in_features: int, out_features: int, dim: int = 2, use_bias: bool = True, prior_std: float = 1., bias_std: float = 1e-3, **kwargs): super(KernelDenseBayesian, self).__init__() self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.in_features = in_features self.out_features = out_features self.dim = dim self.prior_std = prior_std self.bias_std = bias_std self.columns_mean = Parameter( self.floatTensor(self.in_features, self.dim)) self.columns_logvar = Parameter( self.floatTensor(self.in_features, self.dim)) self.rows_mean = Parameter( self.floatTensor(self.out_features, self.dim)) self.rows_logvar = Parameter( self.floatTensor(self.out_features, self.dim)) self.alpha_mean = Parameter(self.floatTensor(self.in_features)) self.alpha_logvar = Parameter(self.floatTensor(self.in_features)) self.use_bias = use_bias if self.use_bias: self.bias_mean = Parameter(self.floatTensor(self.out_features)) self.bias_logvar = Parameter(self.floatTensor(self.out_features)) self.reset_parameters() print(self) def reset_parameters(self): self.columns_mean.data.normal_(std=self.prior_std) self.columns_logvar.data.normal_(std=self.prior_std) self.rows_mean.data.normal_(std=self.prior_std) self.rows_logvar.data.normal_(std=self.prior_std) self.alpha_mean.data.normal_(std=self.prior_std) self.alpha_logvar.data.normal_(std=self.prior_std) if self.use_bias: self.bias_mean.data.normal_(std=self.bias_std) self.bias_logvar.data.normal_(std=self.bias_std) def _calc_rbf_weights(self, rows: torch.Tensor, columns: torch.Tensor): x2 = rows.pow(2).sum(dim=1).view(1, self.out_features) y2 = columns.pow(2).sum(dim=1).view(self.in_features, 1) xy = columns.mm(rows.t()).mul(-2.) return x2.add(y2).add(xy).mul(-1).exp() def _sample_eps(self, shape: tuple): return Variable(self.floatTensor(shape).normal_()) def _eq_logpw(self, prior_std: float, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: logpw = logvar.exp().add(mean**2).div(prior_std**2).add( math.log(2. * math.pi * (prior_std**2))).mul(-0.5) return torch.sum(logpw) def _eq_logqw(self, logvar: torch.Tensor): logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5) return torch.sum(logqw) def eq_logpw(self) -> torch.Tensor: rows = self._eq_logpw(prior_std=self.prior_std, mean=self.rows_mean, logvar=self.rows_logvar) columns = self._eq_logpw(prior_std=self.prior_std, mean=self.columns_mean, logvar=self.columns_logvar) alpha = self._eq_logpw(prior_std=self.prior_std, mean=self.alpha_mean, logvar=self.alpha_logvar) logpw = rows.add(columns).add(alpha) if self.use_bias: bias = self._eq_logpw(prior_std=self.prior_std, mean=self.bias_mean, logvar=self.bias_logvar) logpw.add(bias) return logpw def eq_logqw(self): rows = self._eq_logqw(logvar=self.rows_logvar) columns = self._eq_logqw(logvar=self.columns_logvar) alpha = self._eq_logqw(logvar=self.alpha_logvar) logqw = rows.add(columns).add(alpha) if self.use_bias: bias = self._eq_logqw(logvar=self.bias_logvar) logqw.add(bias) return logqw def kldiv(self): return self.eq_logpw() - self.eq_logqw() def kldiv_aux(self) -> float: return 0. 
    def forward(self, input: torch.Tensor):
        rows = self.rows_mean
        columns = self.columns_mean
        alpha = self.alpha_mean
        if self.training:
            # Reparameterized samples; the results must be assigned back, since
            # Tensor.add is out-of-place.
            rows = rows.add(
                self.rows_logvar.mul(0.5).exp().mul(
                    self._sample_eps(rows.shape)))
            columns = columns.add(
                self.columns_logvar.mul(0.5).exp().mul(
                    self._sample_eps(columns.shape)))
            alpha = alpha.add(
                self.alpha_logvar.mul(0.5).exp().mul(
                    self._sample_eps(alpha.shape)))
        w = self._calc_rbf_weights(rows=rows, columns=columns)
        y = input.mul(alpha).mm(w)
        if self.use_bias:
            y = y.add(self.bias_mean.view(1, self.out_features))
            if self.training:
                y = y.add(
                    self.bias_logvar.mul(0.5).exp().mul(
                        self._sample_eps(self.bias_logvar.shape)).view(
                            1, self.out_features))
        return y

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', dim: ' \
            + str(self.dim) + ')'
class OrthogonalBayesianDense(Module): def __init__(self, in_features: int, out_features: int, order: int = 8, use_bias: bool = True, add_diagonal: bool = True, prior_std: float = 1., bias_std: float = 1e-2, **kwargs): super(OrthogonalBayesianDense, self).__init__() self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.device = torch.device( 'cpu') if not torch.cuda.is_available() else torch.device('cuda') self.in_features = in_features self.out_features = out_features self.add_diagonal = add_diagonal self.prior_std = prior_std self.bias_std = bias_std max_feature = max(self.in_features, self.out_features) self.order = order or max_feature assert 1 <= self.order <= max_feature self.v_mean = Parameter(self.floatTensor(self.order, max_feature)) self.v_logvar = Parameter(self.floatTensor(self.order, max_feature)) if self.add_diagonal: self.d_mean = Parameter( self.floatTensor(min(self.in_features, self.out_features))) self.d_logvar = Parameter( self.floatTensor(min(self.in_features, self.out_features))) self.use_bias = use_bias if self.use_bias: self.bias_mean = Parameter(self.floatTensor(out_features)) self.bias_logvar = Parameter(self.floatTensor(out_features)) self.reset_parameters() print(self) def reset_parameters(self): self.v_mean.data.normal_(std=self.prior_std) self.v_logvar.data.normal_(mean=-5., std=self.prior_std) if self.add_diagonal: self.d_mean.data.normal_(std=self.prior_std) self.d_logvar.data.normal_(mean=-5., std=self.prior_std) if self.use_bias: self.bias_mean.data.normal_(std=self.bias_std) self.bias_logvar.data.normal_(mean=-5., std=self.bias_std) def _sample_eps(self, shape: tuple): return Variable(self.floatTensor(shape).normal_()) def _eq_logpw(self, prior_mean: float, prior_std: float, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: logpw = logvar.exp().add((prior_mean - mean)**2).div(prior_std**2).add( math.log(2. * math.pi * (prior_std**2))).mul(-0.5) return torch.sum(logpw) def _eq_logqw(self, logvar: torch.Tensor): logqw = logvar.add(math.log(2. * math.pi)).add(1.).mul(-0.5) return torch.sum(logqw) def eq_logpw(self) -> torch.Tensor: v = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.v_mean, logvar=self.v_logvar) if self.add_diagonal: d = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.d_mean, logvar=self.d_logvar) logpw = v.add(d) else: logpw = v if self.use_bias: bias = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.bias_mean, logvar=self.bias_logvar) logpw.add(bias) return logpw def eq_logqw(self): v = self._eq_logqw(logvar=self.v_logvar) if self.add_diagonal: d = self._eq_logqw(logvar=self.d_logvar) logqw = v.add(d) else: logqw = v if self.use_bias: bias = self._eq_logqw(logvar=self.bias_logvar) logqw.add(bias) return logqw def kldiv(self): return self.eq_logpw() - self.eq_logqw() def kldiv_aux(self) -> float: return 0. def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor: p = t[0] for i in range(1, t.shape[0]): p = p.mm(t[i]) return p def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor: norm = t.norm(p=2, dim=1) t = t.div(norm.unsqueeze(1)) h = torch.einsum('ab,ac->abc', (t, t)) return torch.eye(n=t.shape[1], device=self.device).expand_as(h) - h.mul(2.) 
def _calc_weights(self, v: torch.Tensor, d: torch.Tensor) -> torch.Tensor: u = self._chain_multiply(self._calc_householder_tensor(v)) if self.out_features <= self.in_features: D = torch.eye(n=self.in_features, m=self.out_features, device=self.device).mm(torch.diag(d)) W = u.mm(D) else: D = torch.diag(d).mm( torch.eye(n=self.in_features, m=self.out_features, device=self.device)) W = D.mm(u) return W def forward(self, input: torch.Tensor): v = self.v_mean.add( self.v_logvar.mul(0.5).exp().mul( self._sample_eps(self.v_logvar.shape))) if self.add_diagonal: d = self.d_mean.add( self.d_logvar.mul(0.5).exp().mul( self._sample_eps(self.d_logvar.shape))) else: d = torch.ones(min(self.in_features, self.out_features), device=self.device) w = self._calc_weights(v=v, d=d) y = input.mm(w) if self.use_bias: bias = self.bias_mean.add( self.bias_logvar.mul(0.5).exp().mul( self._sample_eps(self.bias_logvar.shape))) return y.add(bias.view(1, self.out_features)) return y def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', order: ' \ + str(self.order) + ', add_diagonal: ' \ + str(self.add_diagonal) + ')'
class FFGaussDense(Module): def __init__(self, in_features, out_features, bias=True, prior_std=1., **kwargs): super(FFGaussDense, self).__init__() self.in_features = in_features self.out_features = out_features self.prior_std = prior_std self.mean_w = Parameter(torch.Tensor(in_features, out_features)) self.logvar_w = Parameter(torch.Tensor(in_features, out_features)) self.use_bias = False if bias: self.mean_bias = Parameter(torch.Tensor(out_features)) self.logvar_bias = Parameter(torch.Tensor(out_features)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_out') self.logvar_w.data.normal_(-9., 1e-4) if self.use_bias: self.mean_bias.data.zero_() self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = -.5 * math.log(2 * math.pi * self.prior_std** 2) - .5 * self.logvar_bias.exp().div( self.prior_std**2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) return torch.sum(logpw) + torch.sum(logpb) def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): return 0. def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_pW(self): return self.floatTensor(self.in_features, self.out_features).normal_() def sample_pb(self): return self.floatTensor(self.out_features).normal_() def sample_W(self): w = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) w = w.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return w def sample_b(self): b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def get_mean_x(self, input): mean_xin = input.mm(self.mean_w) if self.use_bias: mean_xin = mean_xin.add(self.mean_bias.view(1, self.out_features)) return mean_xin def get_var_x(self, input): var_xin = input.pow(2).mm(self.logvar_w.exp()) if self.use_bias: var_xin = var_xin.add(self.logvar_bias.exp().view( 1, self.out_features)) return var_xin def forward(self, input): batch_size = input.size(0) mean_xin = self.get_mean_x(input) if self.training: var_xin = self.get_var_x(input) eps = self.get_eps(self.floatTensor(batch_size, self.out_features)) output = mean_xin.add(var_xin.sqrt().mul(eps)) else: output = mean_xin return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ', prior_std: ' \ + str(self.prior_std) + ')'
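
# --- Illustration sketch (not part of the original module) --------------------
# FFGaussDense.forward uses the local reparameterization trick: instead of
# sampling W and computing xW, it samples the pre-activation directly from a
# Gaussian with mean x @ mean_w and variance x^2 @ exp(logvar_w), which has the
# same distribution but lower gradient variance.  The hedged check below
# compares the two estimators' means on a fixed input.
def _ffgauss_local_reparam_demo():
    torch.manual_seed(0)
    layer = FFGaussDense(8, 4)
    x = torch.randn(32, 8)

    # Mean over many samples of the locally reparameterized output ...
    local = torch.stack([layer(x) for _ in range(200)]).mean(0)
    # ... versus sampling the weights explicitly and applying them.
    direct = torch.stack([x.mm(layer.sample_W()) + layer.sample_b()
                          for _ in range(200)]).mean(0)
    # Both should be close to the deterministic mean x @ mean_w + mean_bias.
    return (local - direct).abs().max()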
class KernelBayesianConv2(ConvNd): def __init__(self, in_channels: int, out_channels: int, kernel_size: int, dim: int = 2, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, use_bias: bool = True, prior_std: float = 1., bias_std: float = 1e-3, **kwargs): stride = pair(stride) padding = pair(padding) dilation = pair(dilation) self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.in_channels = in_channels self.out_channels = out_channels self.groups = groups self.dim = dim self.kernel_size = pair(kernel_size) self.use_bias = use_bias self.prior_std = prior_std self.bias_std = bias_std super(KernelBayesianConv2, self).__init__(in_channels, out_channels, self.kernel_size, stride, padding, dilation, False, pair(0), groups, use_bias) self.columns_mean = Parameter( self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)), self.dim)) self.columns_logvar = Parameter( self.floatTensor(self.in_channels * int(np.prod(self.kernel_size)), self.dim)) self.rows_mean = Parameter( self.floatTensor( self.out_channels * int(np.prod(self.kernel_size)) // groups, self.dim)) self.rows_logvar = Parameter( self.floatTensor( self.out_channels * int(np.prod(self.kernel_size)) // groups, self.dim)) self.alpha_mean = Parameter( self.floatTensor(self.out_channels // groups, self.in_channels)) self.alpha_logvar = Parameter( self.floatTensor(self.out_channels // groups, self.in_channels)) self.use_bias = use_bias if self.use_bias: self.bias_mean = Parameter( self.floatTensor(self.out_channels // self.groups)) self.bias_logvar = Parameter( self.floatTensor(self.out_channels // self.groups)) self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.weight, mode='fan_in') if hasattr(self, 'columns_mean'): self.columns_mean.data.normal_(std=self.prior_std) self.columns_logvar.data.normal_(std=self.prior_std) if hasattr(self, 'rows_mean'): self.rows_mean.data.normal_(std=self.prior_std) self.rows_logvar.data.normal_(std=self.prior_std) if hasattr(self, 'alpha_mean'): self.alpha_mean.data.normal_(std=self.prior_std) self.alpha_logvar.data.normal_(std=self.prior_std) if hasattr(self, 'bias_mean'): self.bias_mean.data.normal_(std=self.bias_std) self.bias_logvar.data.normal_(std=self.bias_std) def _calc_rbf_weights(self, rows: torch.Tensor, columns: torch.Tensor, alpha: torch.Tensor) -> Parameter: w = self.floatTensor(self.out_channels // self.groups, self.in_channels, np.prod(self.kernel_size)) for i in range(int(np.prod(self.kernel_size))): row_start = i * (self.out_channels // self.groups) row_stop = (i + 1) * (self.out_channels // self.groups) col_start = i * self.in_channels col_stop = (i + 1) * self.in_channels x2 = rows[row_start:row_stop, :].pow(2).sum(dim=1).view( self.out_channels // self.groups, 1) y2 = columns[col_start:col_stop, :].pow(2).sum(dim=1).view( 1, self.in_channels) xy = rows[row_start:row_stop, :].mm( columns[col_start:col_stop, :].t()).mul(-2.) w[:, :, i] = x2.add(y2).add(xy).mul(-1).exp().mul(alpha) return w.view(self.out_channels // self.groups, self.in_channels, *self.kernel_size) def _sample_eps(self, shape: tuple): return Variable(self.floatTensor(shape).normal_()) def _eq_logpw(self, prior_std: float, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: logpw = logvar.exp().add(mean**2).div(prior_std**2).add( math.log(2. * math.pi * (prior_std**2))).mul(-0.5) return torch.sum(logpw) def _eq_logqw(self, logvar: torch.Tensor): logqw = logvar.add(math.log(2. 
* math.pi)).add(1.).mul(-0.5) return torch.sum(logqw) def eq_logpw(self) -> torch.Tensor: rows = self._eq_logpw(prior_std=self.prior_std, mean=self.rows_mean, logvar=self.rows_logvar) columns = self._eq_logpw(prior_std=self.prior_std, mean=self.columns_mean, logvar=self.columns_logvar) alpha = self._eq_logpw(prior_std=self.prior_std, mean=self.alpha_mean, logvar=self.alpha_logvar) logpw = rows.add(columns).add(alpha) if self.use_bias: bias = self._eq_logpw(prior_std=self.prior_std, mean=self.bias_mean, logvar=self.bias_logvar) logpw.add(bias) return logpw def eq_logqw(self): rows = self._eq_logqw(logvar=self.rows_logvar) columns = self._eq_logqw(logvar=self.columns_logvar) alpha = self._eq_logqw(logvar=self.alpha_logvar) logqw = rows.add(columns).add(alpha) if self.use_bias: bias = self._eq_logqw(logvar=self.bias_logvar) logqw.add(bias) return logqw def kldiv(self): return self.eq_logpw() - self.eq_logqw() def kldiv_aux(self) -> float: return 0. def forward(self, input: torch.Tensor): rows = self.rows_mean columns = self.columns_mean alpha = self.alpha_mean if self.training: rows = self.rows_mean.add( self.rows_logvar.mul(0.5).exp().mul( self._sample_eps(rows.shape))) columns = self.columns_mean.add( self.columns_logvar.mul(0.5).exp().mul( self._sample_eps(columns.shape))) alpha = self.alpha_mean.add( self.alpha_logvar.mul(0.5).exp().mul( self._sample_eps(alpha.shape))) weight = self._calc_rbf_weights(rows=rows, columns=columns, alpha=alpha) bias = self.bias_mean if self.training: bias = self.bias_mean.add( self.bias_logvar.mul(0.5).exp().mul( self._sample_eps(self.bias_logvar.shape))) if not self.use_bias: bias = None return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}, dim={dim}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class FFGaussConv2d(Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, prior_std=1, **kwargs): super(FFGaussConv2d, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.prior_std = prior_std self.use_bias = False self.mean_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.logvar_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) if bias: self.mean_bias = Parameter(torch.Tensor(out_channels)) self.logvar_bias = Parameter(torch.Tensor(out_channels)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_in') self.logvar_w.data.normal_(-9., 1e-4) if self.use_bias: self.mean_bias.data.zero_() self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = -.5 * math.log(2 * math.pi * self.prior_std** 2) - .5 * self.logvar_bias.exp().div( self.prior_std**2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) return torch.sum(logpw) + torch.sum(logpb) def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): return 0. def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_W(self): W = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return W def sample_b(self): if not self.use_bias: return None b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def forward(self, input_): W = self.sample_W() b = self.sample_b() return F.conv2d(input_, W, b, self.stride, self.padding, self.dilation, self.groups) def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, prior_std={prior_std}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if not self.use_bias: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
class HSConv2d(Module): '''Input channel noise''' def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, prior_std=1., prior_std_z=1., dof=1., **kwargs): super(HSConv2d, self).__init__() if in_channels % groups != 0: raise ValueError('in_channels must be divisible by groups') if out_channels % groups != 0: raise ValueError('out_channels must be divisible by groups') self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = pair(kernel_size) self.stride = pair(stride) self.padding = pair(padding) self.dilation = pair(dilation) self.output_padding = pair(0) self.groups = groups self.prior_std = prior_std self.prior_std_z = prior_std_z self.use_bias = False self.dof = dof self.mean_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.logvar_w = Parameter( torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) self.qz_mean = Parameter(torch.Tensor(in_channels // groups)) self.qz_logvar = Parameter(torch.Tensor(in_channels // groups)) self.dim_z = in_channels // groups if bias: self.mean_bias = Parameter(torch.Tensor(out_channels)) self.logvar_bias = Parameter(torch.Tensor(out_channels)) self.use_bias = True self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.reset_parameters() print(self) def reset_parameters(self): init.kaiming_normal_(self.mean_w, mode='fan_in') self.logvar_w.data.normal_(-9., 1e-4) self.qz_mean.data.normal_(math.log(math.exp(1) - 1), 1e-3) self.qz_logvar.data.normal_(math.log(0.1), 1e-4) if self.use_bias: self.mean_bias.data.normal_(0, 1e-2) self.logvar_bias.data.normal_(-9., 1e-4) def constrain_parameters(self, thres_std=1.): self.logvar_w.data.clamp_(max=2. * math.log(thres_std)) if self.use_bias: self.logvar_bias.data.clamp_(max=2. * math.log(thres_std)) def eq_logpw(self): logpw = -.5 * math.log( 2 * math.pi * self.prior_std**2) - .5 * self.logvar_w.exp().div( self.prior_std**2) logpw -= .5 * self.mean_w.pow(2).div(self.prior_std**2) logpb = 0. if self.use_bias: logpb = -.5 * math.log(2 * math.pi * self.prior_std** 2) - .5 * self.logvar_bias.exp().div( self.prior_std**2) logpb -= .5 * self.mean_bias.pow(2).div(self.prior_std**2) logpb = torch.sum(logpb) return torch.sum(logpw) + logpb def eq_logqw(self): logqw = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_w + 1)) logqb = 0. if self.use_bias: logqb = -torch.sum(.5 * (math.log(2 * math.pi) + self.logvar_bias + 1)) return logqw + logqb def kldiv_aux(self): z = self.sample_z(1) z = z.view(self.dim_z) logqm = -torch.sum(.5 * (math.log(2 * math.pi) + self.qz_logvar + 1)) logqm = logqm.add(-torch.sum(F.sigmoid(z.exp().add(-1).log()).log())) logpm = torch.sum( 2 * math.lgamma(.5 * (self.dof + 1)) - math.lgamma(.5 * self.dof) - math.log(self.prior_std_z) - .5 * math.log(self.dof * math.pi) - .5 * (self.dof + 1) * torch.log(1. 
+ z.pow(2) / (self.dof * self.prior_std_z**2))) return logpm - logqm def kldiv(self): return self.kldiv_aux() + self.eq_logpw() - self.eq_logqw() def get_eps(self, size): eps = self.floatTensor(size).normal_() eps = Variable(eps) return eps def sample_z(self, batch_size): z = self.qz_mean.view(1, self.dim_z).expand(batch_size, self.dim_z) if self.training: eps = self.get_eps(self.floatTensor(batch_size, self.dim_z)) z = z.add( eps.mul( self.qz_logvar.view(1, self.dim_z).expand( batch_size, self.dim_z).mul(0.5).exp_())) z = z.contiguous().view(batch_size, self.dim_z, 1, 1) return F.softplus(z) def sample_W(self): W = self.mean_w if self.training: eps = self.get_eps(self.mean_w.size()) W = W.add(eps.mul(self.logvar_w.mul(0.5).exp_())) return W def sample_b(self): if not self.use_bias: return None b = self.mean_bias if self.training: eps = self.get_eps(self.mean_bias.size()) b = b.add(eps.mul(self.logvar_bias.mul(0.5).exp_())) return b def forward(self, input_): z = self.sample_z(input_.size(0)) W = self.sample_W() b = self.sample_b() return F.conv2d(input_.mul(z.expand_as(input_)), W, b, self.stride, self.padding, self.dilation, self.groups) def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}, prior_std_z={prior_std_z}, dof={dof}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if not self.use_bias: s += ', bias=False' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)
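
# --- Inspection sketch (not part of the original module) ----------------------
# HSConv2d scales each input channel by a gate z = softplus(qz); channels whose
# expected gate is close to zero contribute little and are candidates for
# pruning.  This hedged sketch just reads the gate means in eval mode; the
# threshold is arbitrary and purely illustrative.
def _hsconv_gate_demo():
    layer = HSConv2d(8, 16, kernel_size=3, padding=1)
    layer.eval()
    gates = layer.sample_z(1).view(-1)           # deterministic in eval mode
    keep = (gates > 1e-2).nonzero().view(-1)     # arbitrary pruning threshold
    return gates, keep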
class OrthogonalBayesianConv2d(ConvNd): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, use_bias=True, simple=True, add_diagonal=True, weight_decay=1., prior_std=1., bias_std=1e-2, **kwargs): kernel_size = pair(kernel_size) stride = pair(stride) padding = pair(padding) dilation = pair(dilation) self.weight_decay = weight_decay self.floatTensor = torch.FloatTensor if not torch.cuda.is_available( ) else torch.cuda.FloatTensor self.device = torch.device( 'cpu') if not torch.cuda.is_available() else torch.device('cuda') self.simple = simple self.add_diagonal = add_diagonal super(OrthogonalBayesianConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False, pair(0), groups, use_bias) self.in_channels = in_channels self.out_channels = out_channels // groups self.kernel_size = kernel_size[0] self.prior_std = prior_std self.bias_std = bias_std if simple: self.r_mean = Parameter( self.floatTensor(self.kernel_size * self.kernel_size, self.out_channels)) self.r_logvar = Parameter( self.floatTensor(self.kernel_size * self.kernel_size, self.out_channels)) else: self.r_mean = Parameter(self.floatTensor(2, self.out_channels)) self.r_logvar = Parameter(self.floatTensor(2, self.out_channels)) self.t_mean = Parameter( self.floatTensor(2 * (self.kernel_size - 1), self.out_channels)) self.t_logvar = Parameter( self.floatTensor(2 * (self.kernel_size - 1), self.out_channels)) if self.add_diagonal: self.d_mean = Parameter( self.floatTensor(self.kernel_size, self.kernel_size, min(self.in_channels, self.out_channels))) self.d_logvar = Parameter( self.floatTensor(self.kernel_size, self.kernel_size, min(self.in_channels, self.out_channels))) self.use_bias = use_bias if self.use_bias: self.bias_mean = Parameter( self.floatTensor(self.out_channels // self.groups)) self.bias_logvar = Parameter( self.floatTensor(self.out_channels // self.groups)) self.reset_parameters() print(self) def reset_parameters(self): if hasattr(self, 'r_mean'): torch.nn.init.orthogonal_(self.r_mean) self.r_logvar.data.normal_(mean=-5., std=self.prior_std) if hasattr(self, 't_mean'): torch.nn.init.orthogonal_(self.t_mean) self.t_logvar.data.normal_(mean=-5., std=self.prior_std) if hasattr(self, 'd_mean'): self.d_mean.data.normal_(std=self.prior_std) self.d_logvar.data.normal_(mean=-5, std=self.prior_std) if hasattr(self, 'bias_mean'): self.bias_mean.data.normal_(std=self.bias_std) self.bias_logvar.data.normal_(mean=-5., std=self.bias_std) def constrain_parameters(self, thres_std=1.): pass def _sample_eps(self, shape: tuple): return Variable(self.floatTensor(shape).normal_()) def _eq_logpw(self, prior_mean: float, prior_std: float, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: logpw = logvar.exp().add((prior_mean - mean)**2).div(prior_std**2).add( math.log(2. * math.pi * (prior_std**2))).mul(-0.5) return torch.sum(logpw) def _eq_logqw(self, logvar: torch.Tensor): logqw = logvar.add(math.log(2. 
* math.pi)).add(1.).mul(-0.5) return torch.sum(logqw) def eq_logpw(self) -> torch.Tensor: r = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.r_mean, logvar=self.r_logvar) if self.add_diagonal: d = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.d_mean, logvar=self.d_logvar) logpw = r.add(d) else: logpw = r if not self.simple: t = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.t_mean, logvar=self.t_logvar) logpw = logpw.add(t) if self.use_bias: bias = self._eq_logpw(prior_mean=0., prior_std=self.prior_std, mean=self.bias_mean, logvar=self.bias_logvar) logpw.add(bias) return logpw def eq_logqw(self): r = self._eq_logqw(logvar=self.r_logvar) if self.add_diagonal: d = self._eq_logqw(logvar=self.d_logvar) logqw = r.add(d) else: logqw = r if not self.simple: t = self._eq_logqw(logvar=self.t_logvar) logqw.add(t) if self.use_bias: bias = self._eq_logqw(logvar=self.bias_logvar) logqw.add(bias) return logqw def kldiv(self): return self.eq_logpw() - self.eq_logqw() def kldiv_aux(self) -> float: return 0. def _chain_multiply(self, t: torch.Tensor) -> torch.Tensor: p = t[0] for i in range(1, t.shape[0]): p = p.mm(t[i]) return p def _calc_householder_tensor(self, t: torch.Tensor) -> torch.Tensor: norm = t.norm(p=2, dim=1) t = t.div(norm.unsqueeze(1)) h = torch.einsum('ab,ac->abc', (t, t)) return torch.eye(n=t.shape[1], device=self.device).expand_as(h) - h.mul(2.) def _calc_block_orthogonal_tensor(self, size: int, t: torch.Tensor): h = torch.zeros(size, 2, 2, t.shape[1], t.shape[2], device=self.device) for i in range(size): p = t[2 * i] q = t[2 * i + 1] pq = p.mm(q) h[i, 0, 0] = pq h[i, 0, 1] = p.sub(pq) h[i, 1, 0] = q.sub(pq) h[i, 1, 1] = torch.eye(p.shape[0], device=self.device).add(pq).sub(p).sub(q) return h def _matrix_conv(self, s: torch.Tensor, t: torch.Tensor) -> torch.Tensor: assert s[0, 0].shape[0] == t[0, 0].shape[0] n = s[0, 0].shape[0] k = int(np.sqrt(s.shape[0] * s.shape[1])) l = int(np.sqrt(t.shape[0] * t.shape[1])) size = k + l - 1 result = torch.zeros(size, size, n, n, device=self.device) for i in range(size): for j in range(size): for index1 in range(min(k, i + 1)): for index2 in range(min(k, j + 1)): if (i - index1) < l and (j - index2) < l: result[i, j] += torch.mm(s[index1, index2], t[i - index1, j - index2]) return result def _orthogonal_kernel(self, r: torch.Tensor, t: torch.Tensor, d: torch.Tensor = None, transpose: bool = False) -> torch.Tensor: assert self.in_channels <= self.out_channels if self.in_channels <= self.out_channels: d = torch.einsum('abc,cd->abcd', (d, torch.eye(n=self.in_channels, m=self.out_channels, device=self.device))) else: d = torch.einsum('ab, cda->cdab', (torch.eye(n=self.in_channels, m=self.out_channels, device=self.device), d)) if self.simple: q = self._calc_householder_tensor(r).view(self.kernel_size, self.kernel_size, self.out_channels, self.out_channels) else: r = self._chain_multiply(self._calc_householder_tensor(r)) if self.kernel_size == 1: return torch.unsqueeze(torch.unsqueeze(r, 0), 0) t = self._calc_block_orthogonal_tensor( size=self.kernel_size - 1, t=self._calc_householder_tensor(t)) s = t[0] for i in range(1, self.kernel_size - 1): s = self._matrix_conv(s=s, t=t[i]) q = torch.einsum('ab,debc->deac', (r, s)) q = torch.einsum('abcd,abde->abce', (d, q)) if transpose: q = q.permute(2, 3, 1, 0) else: q = q.permute(3, 2, 1, 0) return q def forward(self, input_): r = self.r_mean.add( self.r_logvar.mul(0.5).exp().mul( self._sample_eps(self.r_logvar.shape))) if self.add_diagonal: d = 
self.d_mean.add( self.d_logvar.mul(0.5).exp().mul( self._sample_eps(self.d_logvar.shape))) else: d = torch.ones(self.kernel_size, self.kernel_size, min(self.in_channels, self.out_channels), device=self.device) if self.simple: weight = self._orthogonal_kernel(r=r, t=None, d=d) else: t = self.t_mean.add( self.t_logvar.mul(0.5).exp().mul( self._sample_eps(self.t_logvar.shape))) weight = self._orthogonal_kernel(r=r, t=t, d=d) bias = self.bias_mean.add( self.bias_logvar.mul(0.5).exp().mul( self._sample_eps(self.bias_logvar.shape))) if not self.use_bias: bias = None return F.conv2d(input_, weight, bias, self.stride, self.padding, self.dilation, self.groups) def __repr__(self): s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size} ' ', stride={stride}, weight_decay={weight_decay}') if self.padding != (0, ) * len(self.padding): s += ', padding={padding}' if self.dilation != (1, ) * len(self.dilation): s += ', dilation={dilation}' if self.output_padding != (0, ) * len(self.output_padding): s += ', output_padding={output_padding}' if self.groups != 1: s += ', groups={groups}' if self.bias is None: s += ', bias=False' s += ', simple={simple}' s += ', add_diagonal={add_diagonal}' s += ')' return s.format(name=self.__class__.__name__, **self.__dict__)