class SVDConv2d(Module):
    """2-D convolution whose kernel is stored in factorised form ``W = U diag(d) V``.

    The kernel, flattened to a matrix with ``out_channels`` rows and
    ``total_in_dim = (in_channels // groups) * kH * kW`` columns, is rebuilt on
    every access of :attr:`W_` as ``U @ diag(d) @ V``, reshaped to the
    ``F.conv2d`` weight layout and multiplied by the learnable scalar
    ``scale``.  With ``rank = min(out_channels, total_in_dim)`` the factors are
    ``Uweight: (out_channels, rank)``, ``Dweight: (rank,)`` and
    ``Vweight: (rank, total_in_dim)``.

    ``project`` re-orthonormalises U/V (QR or SVD based), ``orth_reg`` returns
    a soft orthogonality penalty and ``spectral_reg`` returns ``-mean(log d)``.

    Arguments mirror ``torch.nn.Conv2d``; ``norm=True`` additionally registers
    an all-ones ``input_norm_wei`` buffer (not read by any method of this class).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, norm=False):
        # Initialise nn.Module bookkeeping before assigning any attribute.
        super(SVDConv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        kernel_size = _pair(kernel_size)
        self.eps = 1e-8
        self.norm = norm
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.output_padding = _pair(0)
        self.groups = groups
        # F.conv2d expects a weight of shape (out, in // groups, kH, kW), so the
        # factorised matrix needs (in // groups) * kH * kW columns.  (Using the
        # full in_channels made every groups > 1 layer fail in forward.)
        in_per_group = in_channels // groups
        self.total_in_dim = in_per_group * kernel_size[0] * kernel_size[1]
        self.weiSize = (out_channels, in_per_group,
                        kernel_size[0], kernel_size[1])

        self.scale = Parameter(torch.ones(1))
        # U is square (wide case) or tall (out_channels > total_in_dim); in both
        # cases its second dimension, len(d) and V's first dimension are `rank`.
        rank = min(out_channels, self.total_in_dim)
        self.Uweight = Parameter(torch.empty(out_channels, rank))
        self.Dweight = Parameter(torch.empty(rank))
        self.Vweight = Parameter(torch.empty(rank, self.total_in_dim))
        self.Uweight.data.normal_(0, math.sqrt(2. / out_channels))
        self.Vweight.data.normal_(0, math.sqrt(2. / self.total_in_dim))
        self.Dweight.data.fill_(1)

        self.projectiter = 0
        # Start training from orthonormal U and V.
        self.project(style='qr', interval=1)

        if bias:
            self.bias = Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)
        if norm:
            self.register_buffer(
                'input_norm_wei', torch.ones(1, in_per_group, *kernel_size))

    def update_sigma(self):
        """Rescale d in place (via ``.data``) so that ``max|d| == 1``."""
        self.Dweight.data = self.Dweight.data / self.Dweight.data.abs().max()

    def spectral_reg(self):
        """Return ``-mean(log d)``; not finite if any entry of d is <= 0."""
        return -(torch.log(self.Dweight)).mean()

    @property
    def W_(self):
        """Conv weight ``(U diag(d) V).view(weiSize) * scale``.

        Side effect: renormalises d through :meth:`update_sigma` first.
        """
        self.update_sigma()
        return self.Uweight.mm(self.Dweight.diag()).mm(self.Vweight).view(
            self.weiSize) * self.scale

    def forward(self, input):
        """Apply ``F.conv2d`` with the factorised weight."""
        return F.conv2d(input, self.W_, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    @staticmethod
    def _gram_penalty(W):
        """Return ``||W W^T - I||_F^2`` with I created on W's device/dtype."""
        WWt = W.mm(W.t())
        eye = torch.eye(WWt.size(0), dtype=WWt.dtype, device=WWt.device)
        return ((WWt - eye) ** 2).sum()

    def orth_reg(self):
        """Soft orthogonality penalty on U and V.

        Returns ``||U U^T - I||^2 + ||V V^T - I||^2``; in the tall case
        (out_channels > total_in_dim) ``U^T U`` is compared to I instead.
        Works on CPU as well as GPU (no hard-coded ``.cuda()``).
        """
        if self.out_channels <= self.total_in_dim:
            U = self.Uweight
        else:
            U = self.Uweight.t()
        penalty = self._gram_penalty(U)
        penalty = penalty + self._gram_penalty(self.Vweight)
        return penalty

    @staticmethod
    def _qr_q(M):
        """Q factor of the QR of M with columns multiplied by sign(diag(R)).

        The sign fix makes Q uniformly distributed
        (https://arxiv.org/pdf/math-ph/0609050.pdf).
        """
        q, r = torch.qr(M)
        return q * torch.diag(r, 0).sign()

    def project(self, style='none', interval=1):
        """Re-orthonormalise U and V on every ``interval``-th call.

        style='qr' : replace each factor by a sign-corrected QR Q factor.
        style='svd': replace each factor M = u s v^T by u v^T (may be unstable).
        Any other style only advances the call counter ``projectiter``.
        """
        self.projectiter = self.projectiter + 1
        # Style is tested first so that unknown styles never evaluate `%`.
        if style not in ('qr', 'svd') or self.projectiter % interval != 0:
            return
        if style == 'qr':
            if self.out_channels <= self.total_in_dim:
                self.Uweight.data = self._qr_q(self.Uweight.data.t()).t()
            else:
                self.Uweight.data = self._qr_q(self.Uweight.data)
            self.Vweight.data = self._qr_q(self.Vweight.data.t()).t()
        else:
            u, _, v = torch.svd(self.Uweight.data)
            self.Uweight.data = u.mm(v.t())
            u, _, v = torch.svd(self.Vweight.data)
            self.Vweight.data = u.mm(v.t())

    def showOrthInfo(self):
        """Print d vs. actual singular values of U diag(d) V and the
        orthogonality penalties; return the actual singular values."""
        s = self.Dweight.data
        W = self.Uweight.data.mm(s.diag()).mm(self.Vweight.data)
        _, ss, _ = torch.svd(W.t())
        print('Singular Value Summary: ')
        print('max :', s.max().item(), 'max* :', ss.max().item())
        print('mean:', s.mean().item(), 'mean*:', ss.mean().item())
        print('min :', s.min().item(), 'min* :', ss.min().item())
        print('var :', s.var().item(), 'var* :', ss.var().item())
        # torch.svd returns non-negative singular values in descending order,
        # so compare against |d| sorted the same way.
        s_sorted = s.abs().sort(descending=True)[0]
        print('s RMSE: ', ((s_sorted - ss) ** 2).mean().item() ** 0.5)
        if self.out_channels <= self.total_in_dim:
            U = self.Uweight.data
        else:
            U = self.Uweight.data.t()
        pu = self._gram_penalty(U).item()
        pv = self._gram_penalty(self.Vweight.data).item()
        print('penalty :', pu, ' (U) + ', pv, ' (V)')
        return ss
class LDS(GenerativeModel):
    """
    Gaussian latent LDS with (optional) NN observations:

    x(0) ~ N(x0, Q0 * Q0')
    x(t) ~ N(A x(t-1), Q * Q')
    y(t) ~ N(NN(x(t)), R * R')

    For a Kalman Filter model, choose the observation network, NN(x), to be
    a one-layer network with a linear output. The latent state has
    dimensionality n (parameter "xDim") and observations have dimensionality
    m (parameter "yDim").

    Inputs:
    (See GenerativeModel abstract class definition for a list of standard
    parameters.)

    GenerativeParams - Dictionary of LDS parameters (all optional)
        * A     : [n x n] linear dynamics matrix; should have eigenvalues with
                  magnitude strictly less than 1 (default 0.5 * I)
        * QChol : [n x n] square root of the innovation covariance Q
                  (default I)
        * Q0Chol: [n x n] square root of the initial innovation covariance
                  (default I)
        * RChol : [m] square root of the diagonal of the observation
                  covariance (default randn(m) / 10)
        * x0    : [n] mean of initial latent state (default zeros)
        * NN    : module specifying network transforming x to the mean of y
                  (input dim: n, output dim: m; default torch.nn.Linear(n, m))

    The precisions ``Rinv``, ``Lambda`` and ``Lambda0`` are properties that are
    recomputed from the current parameters on every access, so they never go
    stale during training.
    """

    def __init__(self, GenerativeParams, xDim, yDim):
        super(LDS, self).__init__(GenerativeParams, xDim, yDim)

        def _param(key, make_default):
            # Prefer the user-supplied value; the default is built lazily so
            # the random RChol default only consumes RNG when actually used.
            if key in GenerativeParams:
                return Parameter(torch.Tensor(GenerativeParams[key]))
            return Parameter(make_default())

        # TBD: MAKE A BETTER WAY OF SAMPLING DEFAULT A
        self.A = _param('A', lambda: torch.eye(xDim).mul_(0.5))
        self.QChol = _param('QChol', lambda: torch.eye(xDim))
        self.Q0Chol = _param('Q0Chol', lambda: torch.eye(xDim))
        self.RChol = _param('RChol', lambda: torch.randn(yDim).div_(10))
        self.x0 = _param('x0', lambda: torch.zeros(xDim))
        if 'NN' in GenerativeParams:
            self.add_module('NN', GenerativeParams['NN'])
        else:
            self.add_module('NN', torch.nn.Linear(xDim, yDim))

        # Observation mean for self.Xsamp, evaluated once here; used by
        # sampleY only (forward evaluates NN on its own X).
        # NOTE(review): Xsamp is not defined in this class; presumably set by
        # GenerativeModel.__init__ -- confirm.
        self.Ypred = self.NN(self.Xsamp)

    # We assume a diagonal observation covariance (RChol is a vector).
    @property
    def Rinv(self):
        """Observation precisions 1 / RChol**2 for the current RChol."""
        return self.RChol.pow(2).reciprocal()

    @property
    def Lambda(self):
        """Innovation precision (QChol QChol')^-1 for the current QChol."""
        return torch.inverse(torch.matmul(self.QChol, self.QChol.t()))

    @property
    def Lambda0(self):
        """Initial-state precision (Q0Chol Q0Chol')^-1 for the current Q0Chol."""
        return torch.inverse(torch.matmul(self.Q0Chol, self.Q0Chol.t()))

    def sampleX(self, N):
        """
        Sample N consecutive latent states from the generative model.
        Return as a torch tensor of shape [N, xDim].
        """
        _x0 = self.x0.data
        _Q0Chol = self.Q0Chol.data
        _QChol = self.QChol.data
        _A = self.A.data
        norm_samp = torch.normal(torch.zeros(N, self.xDim))
        x_vals = torch.zeros([N, self.xDim])
        x_vals[0] = _x0 + norm_samp[0].matmul(_Q0Chol.t())
        for ii in range(N - 1):
            x_vals[ii + 1] = x_vals[ii].matmul(
                _A.t()) + norm_samp[ii + 1].matmul(_QChol.t())
        return x_vals

    def sampleY(self):
        """
        Return a torch tensor sample of Y around self.Ypred.

        Independent standard-normal noise is drawn for every entry of Ypred
        (previously a single [yDim] draw was shared by all time steps) and
        scaled per observation dimension by RChol.
        """
        _Ypred = self.Ypred.data
        eps = torch.randn_like(_Ypred)
        return _Ypred + eps * self.RChol.data

    def sampleXY(self, N):
        """
        Return numpy samples [X, Y] of length N from the generative model.
        """
        with torch.no_grad():
            X = self.sampleX(N)
            eps = torch.randn([X.size(0), self.yDim])
            Y = self.NN(X) + eps * self.RChol.data
        return [X.numpy(), Y.numpy()]

    def forward(self, X, Y):
        """
        Calculate log p(X, Y) = log p(Y|X) + log p(X) under the model.

        X: latent trajectory, rows are time steps (X[0] is the initial state).
        Y: observations aligned with X.
        Side effect: stores X[0] - x0 in self.resX0.
        """
        # Observation mean must come from the X passed in, not the cached
        # construction-time prediction.
        Ypred = self.NN(X)
        resY = Y - Ypred
        resX = X[1:] - X[:-1].matmul(self.A.t())
        resX0 = X[0] - self.x0
        self.resX0 = resX0
        N = X.size(0)

        Rinv = self.Rinv
        Lambda = self.Lambda
        Lambda0 = self.Lambda0

        quad = (torch.matmul(resY.t(), resY) * torch.diag(Rinv)).sum()
        quad = quad + (torch.matmul(resX.t(), resX) * Lambda).sum()
        quad = quad + torch.matmul(torch.matmul(resX0, Lambda0), resX0)

        # log det(C C') = 2 log|det C| holds for any square C, not only for a
        # triangular factor with positive diagonal (QChol is a free parameter).
        logdet_Q = 2 * torch.slogdet(self.QChol)[1]
        logdet_Q0 = 2 * torch.slogdet(self.Q0Chol)[1]

        LogDensity = (-quad
                      + N * Rinv.log().sum()
                      - (N - 1) * logdet_Q
                      - logdet_Q0
                      - N * (self.xDim + self.yDim) * np.log(2 * np.pi))
        return 0.5 * LogDensity