Example 1

import math

import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torch.nn.modules.utils import _pair


class SVDConv2d(Module):
    '''
    2d convolution with a kernel parametrized as W = U diag(d) V, where U and
    V are kept (approximately) orthogonal and d holds the singular values.
    A short usage sketch follows the class definition.
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm=False):
        super(SVDConv2d, self).__init__()
        self.eps = 1e-8
        self.norm = norm

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.total_in_dim = in_channels * kernel_size[0] * kernel_size[1]
        self.weiSize = (self.out_channels, in_channels, kernel_size[0],
                        kernel_size[1])

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.output_padding = _pair(0)
        self.groups = groups

        self.scale = Parameter(torch.Tensor(1))
        self.scale.data.fill_(1)

        # Factor the (out_channels x total_in_dim) weight matrix as
        # U diag(d) V with r = min(out_channels, total_in_dim):
        # U is (out_channels x r), d has length r, V is (r x total_in_dim).
        if self.out_channels <= self.total_in_dim:
            self.Uweight = Parameter(
                torch.Tensor(self.out_channels, self.out_channels))
            self.Dweight = Parameter(torch.Tensor(self.out_channels))
            self.Vweight = Parameter(
                torch.Tensor(self.out_channels, self.total_in_dim))
            self.Uweight.data.normal_(0, math.sqrt(2. / self.out_channels))
            self.Vweight.data.normal_(0, math.sqrt(2. / self.total_in_dim))
            self.Dweight.data.fill_(1)
        else:
            self.Uweight = Parameter(
                torch.Tensor(self.out_channels, self.total_in_dim))
            self.Dweight = Parameter(torch.Tensor(self.total_in_dim))
            self.Vweight = Parameter(
                torch.Tensor(self.total_in_dim, self.total_in_dim))
            self.Uweight.data.normal_(0, math.sqrt(2. / self.out_channels))
            self.Vweight.data.normal_(0, math.sqrt(2. / self.total_in_dim))
            self.Dweight.data.fill_(1)
        self.projectiter = 0
        self.project(style='qr', interval=1)

        if bias:
            self.bias = Parameter(torch.Tensor(self.out_channels))
            self.bias.data.fill_(0)
        else:
            self.register_parameter('bias', None)

        if norm:
            self.register_buffer(
                'input_norm_wei',
                torch.ones(1, in_channels // groups, *kernel_size))

    def update_sigma(self):
        # Rescale the singular values so the largest magnitude equals 1.
        self.Dweight.data = self.Dweight.data / self.Dweight.data.abs().max()

    def spectral_reg(self):
        return -(torch.log(self.Dweight)).mean()

    @property
    def W_(self):
        # Reassemble the convolution kernel scale * U diag(d) V, reshaped to
        # (out_channels, in_channels, kH, kW).
        self.update_sigma()
        return self.Uweight.mm(self.Dweight.diag()).mm(self.Vweight).view(
            self.weiSize) * self.scale

    def forward(self, input):
        _output = F.conv2d(input, self.W_, self.bias, self.stride,
                           self.padding, self.dilation, self.groups)
        return _output

    def orth_reg(self):
        # Soft orthogonality penalty ||M M^T - I||_F^2 applied to U and V.
        penalty = 0

        if self.out_channels <= self.total_in_dim:
            W = self.Uweight
        else:
            W = self.Uweight.t()
        WWt = W.mm(W.t())
        I = torch.eye(WWt.size(0), device=WWt.device)
        penalty = penalty + ((WWt - I)**2).sum()

        W = self.Vweight
        WWt = W.mm(W.t())
        I = torch.eye(WWt.size(0), device=WWt.device)
        penalty = penalty + ((WWt - I)**2).sum()
        return penalty

    def project(self, style='none', interval=1):
        '''
        Project U and V back onto the set of (semi-)orthogonal matrices,
        using either a QR or an SVD factorization.
        '''
        self.projectiter = self.projectiter + 1
        if style == 'qr' and self.projectiter % interval == 0:
            # Compute the qr factorization for U
            if self.out_channels <= self.total_in_dim:
                q, r = torch.qr(self.Uweight.data.t())
            else:
                q, r = torch.qr(self.Uweight.data)
            # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
            d = torch.diag(r, 0)
            ph = d.sign()
            q *= ph
            if self.out_channels <= self.total_in_dim:
                self.Uweight.data = q.t()
            else:
                self.Uweight.data = q

            # Compute the qr factorization for V
            q, r = torch.qr(self.Vweight.data.t())
            # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
            d = torch.diag(r, 0)
            ph = d.sign()
            q *= ph
            self.Vweight.data = q.t()
        elif style == 'svd' and self.projectiter % interval == 0:
            # Compute the svd factorization (may be not stable) for U
            u, s, v = torch.svd(self.Uweight.data)
            self.Uweight.data = u.mm(v.t())

            # Compute the svd factorization (may be not stable) for V
            u, s, v = torch.svd(self.Vweight.data)
            self.Vweight.data = u.mm(v.t())

    def showOrthInfo(self):
        # Compare the stored singular values (s) with those of the composed
        # kernel (ss) and report how far U and V are from orthogonality.
        s = self.Dweight.data
        _D = self.Dweight.data.diag()
        W = self.Uweight.data.mm(_D).mm(self.Vweight.data)
        _, ss, _ = torch.svd(W.t())
        print('Singular Value Summary: ')
        print('max :', s.max().item(), 'max* :', ss.max().item())
        print('mean:', s.mean().item(), 'mean*:', ss.mean().item())
        print('min :', s.min().item(), 'min* :', ss.min().item())
        print('var :', s.var().item(), 'var* :', ss.var().item())
        print('s RMSE: ', ((s - ss)**2).mean().item()**0.5)
        if self.out_channels <= self.total_in_dim:
            pu = (self.Uweight.data.mm(self.Uweight.data.t()) -
                  torch.eye(self.Uweight.size(0),
                            device=self.Uweight.device)).norm().item()**2
        else:
            pu = (self.Uweight.data.t().mm(self.Uweight.data) -
                  torch.eye(self.Uweight.size(1),
                            device=self.Uweight.device)).norm().item()**2
        pv = (self.Vweight.data.mm(self.Vweight.data.t()) -
              torch.eye(self.Vweight.size(0),
                        device=self.Vweight.device)).norm().item()**2
        print('penalty :', pu, ' (U) + ', pv, ' (V)')
        return ss
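
A minimal usage sketch for SVDConv2d (the layer sizes, learning rate, and loss
below are illustrative assumptions, not taken from the original example):

# Usage sketch: build an SVDConv2d layer, take one optimizer step, then
# re-project U and V onto orthogonal matrices.
conv = SVDConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
opt = torch.optim.SGD(conv.parameters(), lr=0.1)

x = torch.randn(8, 3, 32, 32)
out = conv(x)                                 # uses W = scale * U diag(d) V
loss = out.pow(2).mean() + 1e-2 * conv.orth_reg()

opt.zero_grad()
loss.backward()
opt.step()
conv.project(style='qr', interval=1)          # restore orthogonality of U and V
conv.showOrthInfo()                           # print singular value / orthogonality stats
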
Example 2

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn import Parameter


# GenerativeModel is the project's abstract base class; it is assumed to be
# importable from the surrounding package.
class LDS(GenerativeModel):
    """
    Gaussian latent LDS with (optional) NN observations:
    x(0) ~ N(x0, Q0 * Q0')
    x(t) ~ N(A x(t-1), Q * Q')
    y(t) ~ N(NN(x(t)), R * R')
    For a Kalman Filter model, choose the observation network, NN(x), to be
    a one-layer network with a linear output. The latent state has dimensionality
    n (parameter "xDim") and observations have dimensionality m (parameter "yDim").
    Inputs:
    (See GenerativeModel abstract class definition for a list of standard parameters.)
    GenerativeParams  -  Dictionary of LDS parameters
                           * A     : [n x n] linear dynamics matrix; should
                                     have eigenvalues with magnitude strictly
                                     less than 1
                           * QChol : [n x n] square root of the innovation
                                     covariance Q
                           * Q0Chol: [n x n] square root of the initial innovation
                                     covariance Q0
                           * RChol : [m x 1] square root of the diagonal of the
                                     observation covariance R
                           * x0    : [n x 1] mean of initial latent state
                           * NN: module specifying network transforming x to
                                 the mean of y (input dim: n, output dim: m)
    """
    def __init__(self, GenerativeParams, xDim, yDim):

        super(LDS, self).__init__(GenerativeParams, xDim, yDim)

        # parameters
        if 'A' in GenerativeParams:
            self.A = Parameter(torch.Tensor(GenerativeParams['A']))
        else:
            # TODO: find a better way of sampling a default A
            self.A = Parameter(torch.eye(xDim).mul_(0.5))

        if 'QChol' in GenerativeParams:
            self.QChol = Parameter(torch.Tensor(GenerativeParams['QChol']))
        else:
            self.QChol = Parameter(torch.eye(xDim))

        if 'Q0Chol' in GenerativeParams:
            self.Q0Chol = Parameter(torch.Tensor(GenerativeParams['Q0Chol']))
        else:
            self.Q0Chol = Parameter(torch.eye(xDim))

        if 'RChol' in GenerativeParams:
            self.RChol = Parameter(torch.Tensor(GenerativeParams['RChol']))
        else:
            self.RChol = Parameter(torch.randn(yDim).div_(10))

        if 'x0' in GenerativeParams:
            self.x0 = Parameter(torch.Tensor(GenerativeParams['x0']))
        else:
            self.x0 = Parameter(torch.zeros(xDim))

        if 'NN' in GenerativeParams:
            self.add_module('NN', GenerativeParams['NN'])
        else:
            self.add_module('NN', torch.nn.Linear(xDim, yDim))

        # Derived quantities, computed once from the current parameter values.
        # We assume a diagonal observation covariance (RChol is a vector).
        self.Rinv = self.RChol.pow(2).reciprocal()
        self.Lambda = torch.inverse(torch.matmul(self.QChol, self.QChol.t()))
        self.Lambda0 = torch.inverse(torch.matmul(self.Q0Chol,
                                                  self.Q0Chol.t()))

        # Predicted observation mean for the latent sample Xsamp, which is
        # assumed to be provided by the GenerativeModel base class.
        self.Ypred = self.NN(self.Xsamp)

    def sampleX(self, N):
        """
        Sample latent state from the generative model. Return as a torch tensor.
        """
        _x0 = self.x0.data
        _Q0Chol = self.Q0Chol.data
        _QChol = self.QChol.data
        _A = self.A.data

        norm_samp = torch.normal(torch.zeros(N, self.xDim))
        x_vals = torch.zeros([N, self.xDim])

        x_vals[0] = _x0 + norm_samp[0].matmul(_Q0Chol.t())

        for ii in range(N - 1):
            x_vals[ii + 1] = x_vals[ii].matmul(
                _A.t()) + norm_samp[ii + 1].matmul(_QChol.t())

        return x_vals

    def sampleY(self):
        """ Return a torch tensor sample from the generative model. """
        eps = torch.normal(torch.zeros([self.yDim]))
        return self.Ypred.data + torch.matmul(eps, torch.diag(self.RChol.data))

    def sampleXY(self, N):
        """ Return numpy samples from the generative model. """
        X = Variable(self.sampleX(N), requires_grad=False)
        eps = torch.randn([X.size(0), self.yDim])
        _RChol = self.RChol.data
        Y = self.NN(X).data + torch.matmul(eps, torch.diag(_RChol))
        return [X.data.numpy(), Y.numpy()]

    def forward(self, X, Y):
        """
        Calculate log p(Y|X).
        """
        resY = Y - self.Ypred
        resX = X[1:] - X[:-1].matmul(self.A.t())
        resX0 = X[0] - self.x0
        self.resX0 = resX0
        N = X.size(0)

        LogDensity = -(torch.matmul(resY.t(), resY) *
                       torch.diag(self.Rinv)).sum()
        LogDensity += -(torch.matmul(resX.t(), resX) * self.Lambda).sum()
        LogDensity += -torch.matmul(torch.matmul(resX0, self.Lambda0), resX0)
        LogDensity += N * self.Rinv.log().sum()
        LogDensity += -2 * (N - 1) * self.QChol.diag().log().sum()
        LogDensity += -2 * self.Q0Chol.diag().log().sum()
        LogDensity += -N * (self.xDim + self.yDim) * np.log(2 * np.pi)
        LogDensity *= 0.5

        return LogDensity
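
A minimal usage sketch for LDS (the dimensions and trajectory length are
illustrative; it assumes the GenerativeModel base class stores xDim and yDim
and provides the latent sample Xsamp used by the constructor above):

# Usage sketch: build an LDS with default dynamics/noise parameters and draw
# a sample trajectory from the generative model.
xDim, yDim = 2, 5
lds = LDS({'NN': torch.nn.Linear(xDim, yDim)}, xDim, yDim)

X_latent = lds.sampleX(N=100)   # torch tensor, shape (100, 2)
X, Y = lds.sampleXY(N=100)      # numpy arrays, shapes (100, 2) and (100, 5)

Note that forward compares Y against self.Ypred = NN(Xsamp) rather than
NN(X), so evaluating the log density requires the surrounding training code to
keep Xsamp in sync with the latent trajectory being scored.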