Example #1
    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))
        if self.bn:
            x = self.bn1(x)
        x = F.relu(self.fc2(x))
        if self.bn:
            x = self.bn2(x)

        L = self.M * self.L
        Q = L.mm(L.t()) + self.eps * Variable(torch.eye(self.nCls)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nCls, self.nCls)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nCls)
        z0 = self.qp_z0(x)
        s0 = self.qp_s0(x)
        h = z0.mm(self.G.t()) + s0
        e = Variable(torch.Tensor())
        inputs = x
        x = QPFunction()(Q.double(), inputs.double(), G.double(), h.double(),
                         e, e)
        x = x.float()
        # x = x[:,:10].float()

        return F.log_softmax(x, dim=1)
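
For orientation across all of these snippets: qpth's QPFunction solves a batch of QPs of the form min_z 1/2 z^T Q z + p^T z subject to G z <= h, A z = b, and takes its arguments in exactly that order (Q, p, G, h, A, b), with an empty tensor standing in for an absent equality constraint. A minimal self-contained sketch with random data (modeled on the setup in Example #26 below; the shapes and constants are illustrative assumptions, not taken from any one snippet):

    import torch
    from qpth.qp import QPFunction

    torch.manual_seed(0)
    nBatch, nz, nineq = 4, 6, 8

    # Random positive-definite Q and a strictly feasible inequality system.
    L = torch.randn(nBatch, nz, nz)
    Q = L.bmm(L.transpose(1, 2)) + 1e-3 * torch.eye(nz)
    p = torch.randn(nBatch, nz)
    G = torch.randn(nBatch, nineq, nz)
    z0 = torch.randn(nBatch, nz)
    s0 = torch.rand(nBatch, nineq)  # positive slacks keep z0 strictly feasible
    h = G.bmm(z0.unsqueeze(2)).squeeze(2) + s0
    e = torch.Tensor()  # empty tensor: no equality constraints

    # Argument order is (Q, p, G, h, A, b); the solution is differentiable
    # with respect to all six arguments.
    zhat = QPFunction(verbose=False)(Q.double(), p.double(),
                                     G.double(), h.double(), e, e)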
Example #2
    def forward(self, inputs):
        shot = inputs.shape[1]

        kernel_matrices = torch.bmm(inputs, inputs.transpose(1, 2))
        kernel_matrices += self._eps * torch.eye(shot, device=inputs.device)
        kernel_diags = torch.diagonal(kernel_matrices, dim1=-2, dim2=-1)
        Q = 2 * kernel_matrices
        p = -kernel_diags
        # Constraints: alpha >= 0 (G, h) and sum(alpha) == 1 (A, b).
        A = torch.ones(1, shot, device=inputs.device)
        b = torch.ones(1, device=inputs.device)
        G = -torch.eye(shot, device=inputs.device)
        h = torch.zeros(shot, device=inputs.device)
        alphas = QPFunction(verbose=False)(
            Q,
            p,
            G.detach(),
            h.detach(),
            A.detach(),
            b.detach(),
        )

        alphas = alphas.unsqueeze(-1)
        centers = torch.sum(alphas * inputs, dim=self._dim)
        # `keepdim=True` here would avoid unsqueezing in `CentersDistance`,
        # which could then be used for a vanilla ProtoNet.

        return centers
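
Reading the matrices off this forward pass (an interpretation inferred from the Q, p, G, h, A, b assembled above, not stated in the snippet): per class, with kernel K = X X^T over the shot embeddings, the layer solves

    min_{\alpha}  \alpha^T K \alpha - diag(K)^T \alpha    s.t.  \alpha >= 0,  \sum_i \alpha_i = 1,

the standard dual of the minimum-enclosing-ball problem, and the returned centers are the convex combinations \sum_i \alpha_i x_i.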
Example #3
 def projF(x):
     nBatch = x.size(0)
     Q = self.Q.unsqueeze(0).expand(nBatch, nCls, nCls)
     G = self.G.unsqueeze(0).expand(nBatch, nCls, nCls)
     h = self.h.unsqueeze(0).expand(nBatch, nCls)
     A = self.A.unsqueeze(0).expand(nBatch, 1, nCls)
     b = self.b.unsqueeze(0).expand(nBatch, 1)
     x = QPFunction()(Q, -x.double(), G, h, A, b).float()
     x = x.log()
     return x
Example #4
    def project_action(self, u, x):
        Px = x @ self.P
        G = 2 * Px.expand(self.B.shape[0], Px.shape[0], Px.shape[1]).bmm(self.B).transpose(0, 1)
        h = (Px * x).sum(-1).unsqueeze(1) + \
            2 * Px.expand(self.B.shape[0], Px.shape[0], Px.shape[1]).bmm(self.A).transpose(0, 1).matmul(
            x.unsqueeze(2)).squeeze(2)

        Q = torch.eye(u.shape[-1], device=x.device).unsqueeze(0).expand(u.shape[0], u.shape[-1], u.shape[-1])
        res = QPFunction(verbose=-1)(Q.double(), -u.double(), G.double(), -h.double(), self.e, self.e)
        return res.type(TORCH_DTYPE) 
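
A hedged reading of project_action, inferred from the matrices rather than stated in the code: with Q = I and linear term -u, the objective 1/2 z^T z - u^T z equals 1/2 ||z - u||^2 up to a constant, so the QP is the Euclidean projection of the nominal action u onto the polytope built from x, P, A and B:

    min_z  1/2 ||z - u||^2    s.t.  G z <= -h.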
Example #5
    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        x = puzzles.view(nBatch,-1)
        x = self.fc_in(x)

        e = Variable(torch.Tensor())

        h = self.G.mv(self.z)+self.s
        x = QPFunction(verbose=False)(
            self.Q, x, self.G, h, e, e,
        )

        x = self.fc_out(x)
        x = x.view_as(puzzles)
        return x
Example #6
def svm_logits(query, support, labels, ways, shots, C_reg=0.1, max_iters=15):
    num_support = support.size(0)
    num_query = query.size(0)
    device = support.device
    kernel = support @ support.t()
    I_ways = torch.eye(ways).to(device)
    block_kernel = kronecker(kernel, I_ways)
    block_kernel.add_(torch.eye(ways * num_support, device=device))
    labels_onehot = onehot(labels, dim=ways).view(1, -1).to(device)
    I_sw = torch.eye(num_support * ways, device=device)
    I_s = torch.eye(num_support, device=device)
    h = C_reg * labels_onehot
    A = kronecker(I_s, torch.ones(1, ways, device=device))
    b = torch.zeros(1, num_support, device=device)
    qp = QPFunction(verbose=False, maxIter=max_iters)
    qp_solution = qp(block_kernel, -labels_onehot, I_sw, h, A, b)
    qp_solution = qp_solution.reshape(num_support, ways)

    qp_solution = qp_solution.unsqueeze(1).expand(num_support, num_query, ways)
    compatibility = support @ query.t()
    compatibility = compatibility.unsqueeze(2).expand(num_support, num_query,
                                                      ways)
    logits = qp_solution * compatibility
    logits = torch.sum(logits, dim=0)
    return logits
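
The QP assembled in svm_logits matches the Crammer-Singer multi-class SVM dual that Example #24 below spells out in its comments:

    min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
    s.t.  \alpha^m_i <= C^m_i  \forall m, i    and    \sum_m \alpha^m_i = 0  \forall i,

with w_m(\alpha) = \sum_i \alpha^m_i x_i and C^m_i = C_reg if m = y_i, 0 otherwise. Here block_kernel (plus the identity regularizer) is the quadratic term, -labels_onehot the linear term, h = C_reg * labels_onehot the per-coordinate upper bounds, and the Kronecker-structured A the per-sample equality constraints.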
Example #7
 def fit_(self, support, labels, ways=None, C_reg=None, max_iters=None):
     if C_reg is None:
         C_reg = self.C_reg
     if max_iters is None:
         max_iters = self.max_iters
     if self._normalize:
         support = self.normalize(support)
     if ways is None:
         ways = len(torch.unique(labels))
     num_support = support.size(0)
     device = support.device
     kernel = support @ support.t()
     I_ways = torch.eye(ways).to(device)
     block_kernel = kronecker(kernel, I_ways)
     block_kernel.add_(torch.eye(ways * num_support, device=device))
     labels_onehot = onehot(labels, dim=ways).view(1, -1).to(device)
     I_sw = torch.eye(num_support * ways, device=device)
     I_s = torch.eye(num_support, device=device)
     h = C_reg * labels_onehot
     A = kronecker(I_s, torch.ones(1, ways, device=device))
     b = torch.zeros(1, num_support, device=device)
     qp = QPFunction(verbose=False, maxIter=max_iters)
     qp_solution = qp(block_kernel, -labels_onehot, I_sw, h, A, b)
     self.qp_solution = qp_solution.reshape(num_support, ways)
     self.support = support
     self.num_support = num_support
     self.ways = ways
Example #8
    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax
        x = x.view(nBatch, -1)

        x = x.unsqueeze(0)

        x = x.float()
        tmp = self.fc1(x)

        x = F.relu(tmp)

        x = x.squeeze(2)

        # if self.bn:
        #     x = self.bn1(x)
        # x = F.relu(self.fc2(x))
        # if self.bn:
        #     x = self.bn2(x)

        L = self.M * self.L
        Q = L.mm(L.t()) + self.eps * Variable(torch.eye(self.nCls)).cuda()
        p = self.p.double()
        h = self.G.mv(self.z0) + self.s0
        G = self.G.double()
        Q = Q.double()
        h = h.double()
        print(Q.size(), p.size(), G.size(), h.size())

        e = Variable(torch.Tensor())

        x = QPFunction(verbose=True)(Q, p, G, h, e, e).cuda()
        print(x)
        return F.log_softmax(x, dim=1)
Example #9
    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        p = -puzzles.view(nBatch, -1)
        b = self.A.mv(self.log_z0.exp())

        if self.qp_solver == 'qpth':
            y = QPFunction(verbose=-1)(self.Q, p.double(), self.G, self.h,
                                       self.A, b).float().view_as(puzzles)
        elif self.qp_solver == 'osqpth':
            _l = torch.cat((b,
                            torch.full(self.h.shape,
                                       float('-inf'),
                                       device=self.h.device,
                                       dtype=self.h.dtype)),
                           dim=0)
            _u = torch.cat((b, self.h), dim=0)
            Q_data = self.Q[self.Q_idx[0], self.Q_idx[1]]

            AG = torch.cat((self.A, self.G), dim=0)
            AG_data = AG[self.AG_idx[0], self.AG_idx[1]]
            y = OSQP(self.Q_idx,
                     self.Q.shape,
                     self.AG_idx,
                     AG.shape,
                     diff_mode=DiffModes.FULL)(Q_data, p.double(), AG_data, _l,
                                               _u).float().view_as(puzzles)
        else:
            assert False

        return y
Example #10
    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        p = -puzzles.view(nBatch, -1)

        return QPFunction(verbose=-1)(
            self.Q, p.double(), self.G, self.h, self.A, self.b
        ).float().view_as(puzzles)
Example #11
 def forward(self, z0, mu, dg, d2g):
     nBatch, n = z0.size()
     
     Q = torch.cat([torch.diag(d2g[i] + 1).unsqueeze(0) 
         for i in range(nBatch)], 0).double()
     p = (dg - d2g*z0 - mu).double()
     G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
     h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
     
     out = QPFunction(verbose=False)(Q, p, G, h, self.e, self.e)
     return out
Example #12
    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        Q = self.Q.unsqueeze(0).expand(nBatch, self.Q.size(0), self.Q.size(1))
        p = -puzzles.view(nBatch, -1)
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
        A = self.A.unsqueeze(0).expand(nBatch, self.A.size(0), self.A.size(1))
        b = self.b.unsqueeze(0).expand(nBatch, self.b.size(0))

        return QPFunction(verbose=False)(Q, p.double(), G, h, A,
                                         b).float().view_as(puzzles)
Example #13
    def forward(self, puzzles):
        nBatch = puzzles.size(0)

        p = -puzzles.view(nBatch, -1)

        h2 = self.G2.mv(self.z2) + self.s2
        G = torch.cat((self.G1, self.G2), 0)
        h = torch.cat((self.h1, h2), 0)
        e = Variable(torch.Tensor())

        return QPFunction(verbose=False)(self.Q, p.double(), G, h, e,
                                         e).float().view_as(puzzles)
Example #14
    def forward(self, x, Q, p, G, h, m):
        print("Cuda current device", torch.cuda.current_device())
        nBatch = x.size(0)

        if (m > 1):
            p = p.float().t()
        else:
            p = p.float()

        G = G.float()  #.cuda()
        Q = Q.float()
        if (m >= 2):
            Q = Q.unsqueeze(0)
        h = h.float()
        #print(Q.size(),p.size(),G.size(),h.size())

        e = Variable(torch.Tensor(), requires_grad=True)

        x = QPFunction(verbose=True)(Q, p, G, h, e, e)  #.cuda()

        x = x.view(10, -1)  # reshape before the softmax (this was not needed earlier)

        return F.log_softmax(x, dim=1)
Example #15
    def do_lipschitz_projection(self):
        """
        Perform the Lipschitz projection step by solving the QP
        """

        with torch.no_grad():
            if self.QP == "qpth":
                # qpth library
                proj_coefficients = QPFunction(verbose=False)(
                    nn.Parameter(self.Q), -2.0 * self.coefficients,
                    nn.Parameter(self.G), nn.Parameter(self.h),
                    nn.Parameter(self.e), nn.Parameter(self.e))
                self.coefficients_vect_.data = proj_coefficients.view(-1)

            elif self.QP == "cvxpy":
                # cvxpylayers library
                """
                # row_wise verification 
                proj_coefficients = torch.empty(self.coefficients.data.shape)
                for i in range(self.coefficients.data.shape[0]):
                    proj_coefficient, = self.qp(-2.0*self.coefficients.data[i, :])
                    proj_coefficients[i, :] = proj_coefficient
                self.coefficients_vect_.data = proj_coefficients.view(-1)
                """
                proj_coefficients, = self.qp(-2.0 * self.coefficients.data)
                self.coefficients_vect_.data = proj_coefficients.view(-1)
Example #16
    def forward(self, y):
        nBatch, k = y.size()

        Q_scale = torch.cat([torch.diag(torch.cat(
            [self.one, y[i], y[i]])).unsqueeze(0) for i in range(nBatch)], 0)
        Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)
        p_scale = torch.cat([Variable(torch.ones(nBatch,1).cuda()), y, y], 1)
        p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
        e = Variable(torch.Tensor().cuda()).double()

        out = QPFunction(verbose=False)\
            (Q.double(), p.double(), G.double(), h.double(), e, e).float()

        return out[:,:1]
Example #17
    def forward(self, x):
        nBatch = x.size(0)

        x = self.fc1(x)

        L = self.M*self.L
        Q = L.mm(L.t()) + self.args.eps*Variable(torch.eye(self.nHidden)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
        h = self.G.mv(self.z0)+self.s0
        h = h.unsqueeze(0).expand(nBatch, self.nineq)
        e = Variable(torch.Tensor())
        x = QPFunction()(Q, x, G, h, e, e)
        x = x[:,:self.nFeatures]

        return x
Example #18
    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-QP-FC-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))

        Q = self.Q.unsqueeze(0).expand(nBatch, self.Q.size(0), self.Q.size(1))
        p = -x.view(nBatch, -1)
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
        A = self.A.unsqueeze(0).expand(nBatch, self.A.size(0), self.A.size(1))
        b = self.b.unsqueeze(0).expand(nBatch, self.b.size(0))

        x = QPFunction(verbose=False)(Q, p.double(), G, h, A, b).float()
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)
Example #19
    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))
        if self.bn:
            x = self.bn1(x)
        x = F.relu(self.fc2(x))
        if self.bn:
            x = self.bn2(x)

        L = self.M * self.L
        Q = L.mm(L.t()) + self.eps * Variable(torch.eye(self.nCls)).cuda()
        G = self.G
        h = self.G.mv(self.z0) + self.s0
        e = Variable(torch.Tensor())
        x = QPFunction(verbose=False)(Q, x, G, h, e, e)

        return F.log_softmax(x, dim=1)
Example #20
    def forward(self, x):
        nBatch = x.size(0)

        x = F.max_pool2d(self.conv1(x), 2)
        x = F.max_pool2d(self.conv2(x), 2)
        x = x.view(nBatch, -1)

        L = self.M * self.L
        Q = L.mm(L.t()) + self.eps * Variable(torch.eye(self.nHidden)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
        z0 = self.qp_z0(x)
        s0 = self.qp_s0(x)
        h = z0.mm(self.G.t()) + s0
        e = Variable(torch.Tensor())
        inputs = self.qp_o(x)
        x = QPFunction()(Q, inputs, G, h, e, e)
        x = x[:, :10]

        return F.log_softmax(x, dim=1)
Example #21
    def forward(self, x):
        nBatch = x.size(0)

        L = self.M*self.L
        Q = L.mm(L.t()) + self.args.eps*Variable(torch.eye(self.nHidden)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)
        nI = Variable(-torch.eye(self.nFeatures-1).type_as(Q.data))
        G = torch.cat((
              torch.cat(( self.D, nI), 1),
              torch.cat((-self.D, nI), 1)
        ))
        G = G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)
        h = self.h.unsqueeze(0).expand(nBatch, self.nineq)
        e = Variable(torch.Tensor())
        # p = torch.cat((-x, self.lam.unsqueeze(0).expand(nBatch, self.nFeatures-1)), 1)
        p = torch.cat((-x, Parameter(13.*torch.ones(nBatch, self.nFeatures-1).cuda())), 1)
        x = QPFunction()(Q.double(), p.double(), G.double(), h.double(), e, e).float()
        x = x[:,:self.nFeatures]

        return x
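
Inferred from the block structure (not stated in the snippet): with stacked variables z = (x', t), the rows D x' - t <= h and -D x' - t <= h encode |D x'| <= t elementwise when h = 0, and the linear term charges a fixed weight on t (hard-coded to 13 here, with the learned self.lam variant commented out). Together with the learned quadratic term this resembles the total-variation denoising QP from the OptNet paper,

    min_{x', t}  1/2 z^T Q z - x^T x' + \lambda 1^T t    s.t.  |D x'| <= t,

whose first nFeatures coordinates are returned as the denoised signal.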
Example #22
    def forward(self, log_prices):
        prices = torch.exp(log_prices)
        
        nBatch = prices.size(0)
        T = self.T

        Q = self.Q.unsqueeze(0).expand(nBatch, self.Q.size(0), self.Q.size(1))
        c = torch.cat(
            [prices, -prices, 
            -(self.lam * self.B * torch.ones(T, device=DEVICE)).unsqueeze(0).expand(nBatch,T)], 
            1)
        A = self.A.unsqueeze(0).expand(nBatch, self.A.size(0), self.A.size(1))
        b = self.b.unsqueeze(0).expand(nBatch, self.b.size(0))
        Ae = self.Ae.unsqueeze(0).expand(nBatch, self.Ae.size(0), self.Ae.size(1))
        be = self.be.unsqueeze(0).expand(nBatch, self.be.size(0))
                
        out = QPFunction(verbose=True)\
            (Q.double(), c.double(), A.double(), b.double(), Ae.double(), be.double())
        
        return out
Example #23
    def forward(self, y):
        nBatch, k = y.size()

        eps2 = 1e-8
        Q_scale = torch.cat([
            torch.diag(torch.cat([self.one, y[i] + eps2, y[i] + eps2
                                  ])).unsqueeze(0) for i in range(nBatch)
        ], 0)
        Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)
        p_scale = torch.cat([torch.ones(nBatch, 1, device=DEVICE), y, y], 1)
        p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)
        G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
        h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
        e = torch.DoubleTensor()
        if USE_GPU:
            e = e.cuda()

        out = QPFunction(verbose=False)\
            (Q.double(), p.double(), G.double(), h.double(), e, e).float()

        return out[:, :1]
Example #24
def MetaOptNetHead_SVM(query,
                       support,
                       support_labels,
                       n_way,
                       n_shot,
                       C_reg=0.1,
                       double_precision=False,
                       maxIter=15):
    """
    Fits the support set with multi-class SVM and
    returns the classification score on the query set.

    This is the multi-class SVM presented in:
    On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
    (Crammer and Singer, Journal of Machine Learning Research 2001).
    This model is the classification head that we use for the final version.
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)
    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert n_support == n_way * n_shot  # n_support must equal n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
    # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i

    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m  = y_i,
    # C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.

    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)
    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)

    block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(
        tasks_per_batch, n_way * n_support, n_way * n_support).cuda()
    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support),
        n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.view(
        tasks_per_batch, n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(
        tasks_per_batch, n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot
    # print (G.size())
    # This part is for the inequality constraints:
    # \alpha^m_i <= C^m_i \forall m,i
    # where C^m_i = C if m  = y_i,
    # C^m_i = 0 if m != y_i.
    id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch,
                                                      n_way * n_support,
                                                      n_way * n_support)
    C = Variable(id_matrix_1)
    h = Variable(C_reg * support_labels_one_hot)
    # print (C.size(), h.size())
    # This part is for the equality constraints:
    # \sum_m \alpha^m_i=0 \forall i
    id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support,
                                              n_support).cuda()
    A = Variable(
        batched_kronecker(id_matrix_2,
                          torch.ones(tasks_per_batch, 1, n_way).cuda()))
    b = Variable(torch.zeros(tasks_per_batch, n_support))

    if double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]
    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(),
                                                        C.detach(), h.detach(),
                                                        A.detach(), b.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)
    return logits
Example #25
def MetaOptNetHead_Ridge(query,
                         support,
                         support_labels,
                         n_way,
                         n_shot,
                         lambda_reg=50.0,
                         double_precision=False):
    """
    Fits the support set with ridge regression and
    returns the classification score on the query set.
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      lambda_reg: a scalar. Represents the strength of L2 regularization.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert n_support == n_way * n_shot  # n_support must equal n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i

    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,

    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(
        tasks_per_batch, n_support, n_support).cuda()
    block_kernel_matrix = kernel_matrix.repeat(
        n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)
    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support),
        n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(
        0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(
        n_way * tasks_per_batch,
        n_support)  # (n_way*tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # This is a fake inequality constraint, as qpth does not support QPs without an inequality constraint.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]

    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit the ridge regression:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())
    # qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    # qp_sol (n_way*tasks_per_batch, n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
    # qp_sol (n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)
    # qp_sol (tasks_per_batch, n_support, n_way)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits
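
Because the inequality constraint here is all zeros (0 <= 0 always holds), the QP is effectively unconstrained and has a closed form: for each class m,

    \alpha_m = 2 (K + \lambda I)^{-1} y_m,

i.e. kernel ridge regression against the one-hot targets, up to the constant factor 2 introduced by e = -2.0 * support_labels_one_hot (a constant scale on the logits). Routing it through QPFunction rather than a direct linear solve keeps the code path uniform with the SVM heads; this closed form is our reading of the assembled matrices, not something the snippet itself claims.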
Example #26
def prof_instance(nz, nBatch, nTrials, cuda=True):
    nineq, neq = nz, 0
    assert (neq == 0)
    L = npr.rand(nBatch, nz, nz)
    Q = np.matmul(L, L.transpose((0, 2, 1))) + 1e-3 * np.eye(nz, nz)
    G = npr.randn(nBatch, nineq, nz)
    z0 = npr.randn(nBatch, nz)
    s0 = npr.rand(nBatch, nineq)
    p = npr.randn(nBatch, nz)
    h = np.matmul(G, np.expand_dims(z0, axis=(2))).squeeze(2) + s0
    A = npr.randn(nBatch, neq, nz)
    b = np.matmul(A, np.expand_dims(z0, axis=(2))).squeeze(2)

    lm = nn.Linear(nz, nz)

    p, L, Q, G, z0, s0, h = [torch.Tensor(x) for x in [p, L, Q, G, z0, s0, h]]
    if cuda:
        p, L, Q, G, z0, s0, h = [x.cuda() for x in [p, L, Q, G, z0, s0, h]]
        lm = lm.cuda()
    if neq > 0:
        A = torch.Tensor(A)
        b = torch.Tensor(b)
    else:
        A, b = [torch.Tensor()] * 2
    if cuda:
        A = A.cuda()
        b = b.cuda()

    p, L, Q, G, z0, s0, h, A, b = [
        Variable(x) for x in [p, L, Q, G, z0, s0, h, A, b]
    ]
    p.requires_grad = True

    # Backward seed on the same device as the rest of the problem data.
    grad_out = torch.ones(nBatch, nz)
    if cuda:
        grad_out = grad_out.cuda()

    linearf_times = []
    linearb_times = []
    for i in range(nTrials + 1):
        start = time.time()
        zhat_l = lm(p)
        linearf_times.append(time.time() - start)
        start = time.time()
        zhat_l.backward(grad_out)
        linearb_times.append(time.time() - start)
    linearf_times = linearf_times[1:]
    linearb_times = linearb_times[1:]

    qpthf_times = []
    qpthb_times = []
    for i in range(nTrials + 1):
        start = time.time()
        qpf = QPFunction()
        zhat_b = qpf(Q, p, G, h, A, b)
        qpthf_times.append(time.time() - start)

        start = time.time()
        zhat_b.backward(grad_out)
        qpthb_times.append(time.time() - start)
    qpthf_times = qpthf_times[1:]
    qpthb_times = qpthb_times[1:]

    return np.array(linearf_times), np.array(qpthf_times), \
        np.array(linearb_times), np.array(qpthb_times)
Example #27
def MetaOptNetHead_SVM_WW(query,
                          support,
                          support_labels,
                          n_way,
                          n_shot,
                          C_reg=0.00001,
                          double_precision=False):
    """
    Fits the support set with multi-class SVM and 
    returns the classification score on the query set.
    
    This is the multi-class SVM presented in:
    Support Vector Machines for Multi Class Pattern Recognition
    (Weston and Watkins, ESANN 1999).
    
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    """
    Fits the support set with multi-class SVM and 
    returns the classification score on the query set.
    
    This is the multi-class SVM presented in:
    Support Vector Machines for Multi Class Pattern Recognition
    (Weston and Watkins, ESANN 1999).
    
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert n_support == n_way * n_shot  # n_support must equal n_way * n_shot

    #In theory, \alpha is an (n_support, n_way) matrix
    #NOTE: In this implementation, we solve for a flattened vector of size (n_way*n_support)
    #In order to turn it into a matrix, you must first reshape it into an (n_way, n_support) matrix
    #then transpose it, resulting in (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support) + torch.ones(
        tasks_per_batch, n_support, n_support).cuda()

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)

    kernel_matrix_mask_x = support_labels.reshape(
        tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support,
                                              n_support)
    kernel_matrix_mask_y = support_labels.reshape(
        tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support,
                                              n_support)
    kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()

    block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
    block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)

    kernel_matrix_mask_second_term = support_labels.reshape(
        tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support,
                                              n_support * n_way)
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(
        n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(
            1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()

    block_kernel_matrix -= (
        2.0 - 1e-4) * (kernel_matrix_mask_second_term *
                       kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)

    Y_support = one_hot(support_labels.view(tasks_per_batch * n_support),
                        n_way)
    Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
    Y_support = Y_support.transpose(1,
                                    2)  # (tasks_per_batch, n_way, n_support)
    Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)

    G = block_kernel_matrix

    e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
    id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch,
                                                    n_way * n_support,
                                                    n_way * n_support)

    C_mat = C_reg * torch.ones(tasks_per_batch,
                               n_way * n_support).cuda() - C_reg * Y_support

    C = Variable(torch.cat((id_matrix, -id_matrix), 1))
    #C = Variable(torch.cat((id_matrix_masked, -id_matrix_masked), 1))
    zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda()

    h = Variable(torch.cat((C_mat, zer), 1))

    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    #qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
    qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(),
                                       dummy.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query) + torch.ones(
        tasks_per_batch, n_support, n_query).cuda()
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way,
                                                      n_support, n_query)
    qp_sol = qp_sol.float()
    qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support)
    A_i = torch.sum(qp_sol, 1)  # (tasks_per_batch, n_support)
    A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support)
    qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way,
                                                n_support, n_query)
    Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support)
    Y_support_reshaped = A_i * Y_support_reshaped
    Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(
        tasks_per_batch, n_way, n_support, n_query)
    logits = (Y_support_reshaped - qp_sol) * compatibility

    logits = torch.sum(logits, 2)

    return logits.transpose(1, 2)
Example #28
def MetaOptNetHead_SVM_He(query,
                          support,
                          support_labels,
                          n_way,
                          n_shot,
                          C_reg=0.01,
                          double_precision=False):
    """
    Fits the support set with multi-class SVM and 
    returns the classification score on the query set.
    
    This is the multi-class SVM presented in:
    A simplified multi-class support vector machine with reduced dual optimization
    (He et al., Pattern Recognition Letters 2012).
    
    This SVM is desirable because its dual variable is of size n_support
    (as opposed to n_way * n_support in the Weston & Watkins or Crammer & Singer multi-class SVMs).
    This model is the classification head that we initially used for our project.
    It was dropped after it turned out to perform suboptimally in meta-learning scenarios.
    
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert n_support == n_way * n_shot  # n_support must equal n_way * n_shot

    kernel_matrix = computeGramMatrix(support, support)

    V = (support_labels * n_way -
         torch.ones(tasks_per_batch, n_support, n_way).cuda()) / (n_way - 1)
    G = computeGramMatrix(V, V).detach()
    G = kernel_matrix * G

    e = Variable(-1.0 * torch.ones(tasks_per_batch, n_support))
    id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support,
                                            n_support)
    C = Variable(torch.cat((id_matrix, -id_matrix), 1))
    h = Variable(
        torch.cat((C_reg * torch.ones(tasks_per_batch, n_support),
                   torch.zeros(tasks_per_batch, n_support)), 1))
    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(query, support)
    compatibility = compatibility.float()

    logits = qp_sol.float().unsqueeze(1).expand(tasks_per_batch, n_query,
                                                n_support)
    logits = logits * compatibility
    logits = logits.view(tasks_per_batch, n_query, n_shot, n_way)
    logits = torch.sum(logits, 2)

    return logits
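
Reading the constraint blocks: C = [I; -I] with h = [C_reg * 1; 0] box-constrains the dual variable, so the QP solved here is

    min_{\alpha}  0.5 \alpha^T (K \circ V V^T) \alpha - 1^T \alpha    s.t.  0 <= \alpha <= C_reg,

with one dual variable per support point (K \circ V V^T is the elementwise product of the kernel and label-code Gram matrices), which is exactly the reduced dual size the docstring credits to He et al.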
Example #29
    def forward(self, data):
        n = data.shape[0]
        vs = torch.unsqueeze(data[:, 0], 1)
        vnexts = torch.unsqueeze(data[:, 1], 1)
        us = torch.unsqueeze(data[:, 2], 1)

        mu = self.mu
        #mu = torch.tensor([1.0])

        beta = vnexts - vs - us

        G = torch.tensor([[1.0, -1, 1], [-1, 1, 1], [-1, -1, 0]])

        Gpad = torch.tensor([[1.0, -1, 1, 0, 0, 0], [-1, 1, 1, 0, 0, 0],
                             [-1, -1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                             [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]])

        fmats = torch.tensor([[1.0, 1, 0], [-1, -1, 0],
                              [0, 0, 1]]).unsqueeze(0).repeat(n, 1, 1)
        fvecs = torch.cat((vs, us, mu * torch.ones(vs.shape)), 1).unsqueeze(2)
        f = torch.bmm(fmats, fvecs)
        batch_zeros = torch.zeros(f.shape)
        fpad = torch.cat((f, batch_zeros), 1)

        # For prediction error
        A = torch.tensor([[1.0, -1, 0, 0, 0, 0], [-1, 1, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0,
                                               0]]).repeat(n, 1, 1)
        beta_zeros = torch.zeros(beta.shape)
        b = torch.cat((-2 * beta, 2 * beta, beta_zeros, beta_zeros, beta_zeros,
                       beta_zeros), 1).unsqueeze(2)

        slack_penalty = torch.tensor([0.0, 0, 0, 1, 1,
                                      1]).repeat(n, 1).unsqueeze(2)

        a1 = 1
        a2 = 1
        a3 = 1

        Q = 2 * a1 * A + 2 * a2 * Gpad
        p = a1 * b + a2 * fpad + a3 * slack_penalty

        # Constrain lambda and slacks to be >= 0
        R = -torch.eye(6)
        h = torch.zeros((1, 6))

        # Constrain G lambda + f >= 0
        #R = torch.cat((R, -G))
        # Should not have second negative here?
        R = torch.cat((R, -torch.cat((G, -torch.eye(3)), 1)))
        #R = torch.cat((R, -torch.cat((G, torch.zeros(3,3)), 1)))
        h = torch.cat((h.transpose(0, 1), f.unsqueeze(1)))
        h = h.transpose(0, 1)

        Qmod = 0.5 * (Q + Q.transpose(0, 1)) + 0.001 * torch.eye(6)

        z = QPFunction(check_Q_spd=False)(Qmod, p, R, h, torch.tensor([]),
                                          torch.tensor([]))

        #print(self.scipy_optimize(0.5 * (Q + Q.transpose(0, 1)), p, R, h))
        #assert(torch.all(torch.matmul(R, z.transpose(0, 1)) \
        #                <= (h.transpose(0, 1) + torch.ones(h.shape) * 1e-5)))

        lcp_slack = torch.matmul(Gpad, z.transpose(0, 1)).transpose(0,
                                                                    1) + fpad
        #print(z[0])

        cost = 0.5 * torch.matmul(z, torch.matmul(Qmod, z.transpose(0, 1))) \
                + torch.matmul(p, z.transpose(0, 1)) + a1 * beta**2
        return cost
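
A hedged interpretation of this construction: z stacks the contact-impulse variables \lambda with three slacks s; the constraints enforce \lambda >= 0, s >= 0, and G \lambda + f >= s, and the a2 portion of the objective contributes \lambda^T (G \lambda + f), the complementarity product that vanishes at an exact solution of the LCP 0 <= \lambda \perp G \lambda + f >= 0. The QP is thus a penalty relaxation of the LCP, with a1 weighting the prediction-error terms in beta and a3 penalizing the slacks; the returned cost adds a1 * beta^2 to the attained objective.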
Example #30
def MetaOptNetHead_Ridge(query,
                         support,
                         support_labels,
                         n_way,
                         n_shot,
                         device,
                         lambda_reg=100.0,
                         double_precision=False):
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert n_support == n_way * n_shot  # n_support must equal n_way * n_shot

    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(
        tasks_per_batch, n_support, n_support).cuda()

    block_kernel_matrix = kernel_matrix.repeat(
        n_way, 1, 1)  #(n_way * tasks_per_batch, n_support, n_support)

    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support), n_way,
        device)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(
        0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(
        n_way * tasks_per_batch,
        n_support)  # (n_way*tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # This is a fake inequality constraint, as qpth does not support QPs without an inequality constraint.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    dummy = Variable(torch.Tensor()).cuda()

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]

    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())
    #qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    #qp_sol (n_way*tasks_per_batch, n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
    #qp_sol (n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)
    #qp_sol (tasks_per_batch, n_support, n_way)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits