Exemplo n.º 1
0
    def CMloss_Fnorm(self, query, support, support_labels, n_way, n_shot):
        """
        Re-fits the multi-class SVM on the support set and returns the mean
        Frobenius-norm distance between ``self.support_w`` and the newly
        fitted per-class weights.

        ``self.support_w`` must already be set when this is called (presumably
        by an earlier forward() pass on the same object -- confirm against the
        caller).

        Parameters:
          query:  a 3-D Tensor. Only its sizes are used here (batch size and
                  feature dimension are checked against ``support``).
          support:  a (tasks_per_batch, n_support, d) Tensor.
          support_labels: a Tensor with tasks_per_batch * n_support class indices.
          n_way: a scalar. Represents the number of classes in a few-shot classification task.
          n_shot: a scalar. Represents the number of support examples given per class.
        Returns: a scalar Tensor equal to
          sum_t ||self.support_w[t] - w[t]||_F / (n_way * tasks_per_batch),
          where w = alpha^T @ support are the freshly fitted class weights.
        """

        tasks_per_batch = query.size(0)
        n_support = support.size(1)

        assert (query.dim() == 3)
        assert (support.dim() == 3)
        assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
        assert (n_support == n_way * n_shot)  # n_support must equal to n_way * n_shot

        # Here we solve the dual problem:
        # Note that the classes are indexed by m & samples are indexed by i.
        # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
        # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i

        # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
        # and C^m_i = C if m  = y_i,
        # C^m_i = 0 if m != y_i.
        # This borrows the notation of liblinear.

        # \alpha is an (n_support, n_way) matrix
        kernel_matrix = computeGramMatrix(support, support)

        id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
        block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
        # This seems to help avoid PSD error from the QP solver.
        block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support,
                                                                         n_way * n_support).cuda()

        support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support),
                                         n_way)  # (tasks_per_batch * n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

        G = block_kernel_matrix
        e = -1.0 * support_labels_one_hot
        # This part is for the inequality constraints:
        # \alpha^m_i <= C^m_i \forall m,i
        # where C^m_i = C if m  = y_i,
        # C^m_i = 0 if m != y_i.
        id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
        C = Variable(id_matrix_1)
        h = Variable(self.C_reg * support_labels_one_hot)
        # This part is for the equality constraints:
        # \sum_m \alpha^m_i=0 \forall i
        id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()

        A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
        b = Variable(torch.zeros(tasks_per_batch, n_support))
        if self.double_precision:
            G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
        else:
            G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

        # Solve the following QP to fit SVM:
        #        \hat z =   argmin_z 1/2 z^T G z + e^T z
        #                 subject to Cz <= h, Az = b
        # We use detach() to prevent backpropagation to fixed variables.
        qp_sol = QPFunction(verbose=False, maxIter=self.maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(),
                                                                 b.detach())

        # Class weights w_m = \sum_i \alpha^m_i x_i, one (n_way, d) matrix per task.
        # The QP inputs are cast to double or float above, so the solution's dtype
        # may differ from support's; torch.bmm needs matching dtypes, hence the cast
        # (a no-op when they already agree).
        qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
        w_query = torch.bmm(qp_sol.transpose(1, 2).to(support.dtype), support)

        dis_matrix = self.support_w - w_query
        # trace(D D^T) is the sum of squared entries of D, so the per-task
        # Frobenius norm is computed directly without forming D D^T in a loop.
        per_task_norm = dis_matrix.pow(2).sum(dim=2).sum(dim=1).sqrt()
        revese_loss = per_task_norm.sum()

        return 1.0 * revese_loss / (n_way * tasks_per_batch)
Exemplo n.º 2
0
    def forward(self, data):
        """
        Builds a 6-variable QP from the rows of ``data``, solves it with
        QPFunction and returns the resulting cost.

        Parameters:
          data: a 2-D Tensor indexed as data[:, 0], data[:, 1], data[:, 2];
                the three columns are read as ``vs``, ``vnexts`` and ``us``.
        Returns: ``cost`` = 0.5 z^T Qmod z + p^T z + a1 * beta^2, where
          beta = vnexts - vs - us and z is the QP solution.

        Reads ``self.mu``.

        NOTE(review): several tensor shapes in this method look inconsistent
        (see the inline notes); they are left untouched because the intended
        formulation cannot be determined from this method alone.
        """
        n = data.shape[0]
        # Each column becomes an (n, 1) tensor.
        vs = torch.unsqueeze(data[:, 0], 1)
        vnexts = torch.unsqueeze(data[:, 1], 1)
        us = torch.unsqueeze(data[:, 2], 1)

        mu = self.mu
        #mu = torch.tensor([1.0])

        beta = vnexts - vs - us

        G = torch.tensor([[1.0, -1, 1], [-1, 1, 1], [-1, -1, 0]])

        # G embedded in the top-left 3x3 corner of a 6x6 zero matrix
        # (the last three variables are the slacks).
        Gpad = torch.tensor([[1.0, -1, 1, 0, 0, 0], [-1, 1, 1, 0, 0, 0],
                             [-1, -1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                             [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]])

        # f = fmats @ [vs, us, mu]^T per row of data; result is (n, 3, 1).
        fmats = torch.tensor([[1.0, 1, 0], [-1, -1, 0],
                              [0, 0, 1]]).unsqueeze(0).repeat(n, 1, 1)
        fvecs = torch.cat((vs, us, mu * torch.ones(vs.shape)), 1).unsqueeze(2)
        f = torch.bmm(fmats, fvecs)
        batch_zeros = torch.zeros(f.shape)
        fpad = torch.cat((f, batch_zeros), 1)

        # For prediction error
        A = torch.tensor([[1.0, -1, 0, 0, 0, 0], [-1, 1, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0,
                                               0]]).repeat(n, 1, 1)
        beta_zeros = torch.zeros(beta.shape)
        b = torch.cat((-2 * beta, 2 * beta, beta_zeros, beta_zeros, beta_zeros,
                       beta_zeros), 1).unsqueeze(2)

        # Linear penalty applied only to the three slack variables.
        slack_penalty = torch.tensor([0.0, 0, 0, 1, 1,
                                      1]).repeat(n, 1).unsqueeze(2)

        # Relative weights of the three objective terms.
        a1 = 1
        a2 = 1
        a3 = 1

        Q = 2 * a1 * A + 2 * a2 * Gpad
        # A leftover pdb.set_trace() breakpoint was removed here: it stopped
        # every call in the debugger.
        p = a1 * b + a2 * fpad + a3 * slack_penalty

        # Constrain lambda and slacks to be >= 0
        R = -torch.eye(6)
        h = torch.zeros((1, 6))

        # Constrain G lambda + f >= 0
        #R = torch.cat((R, -G))
        # Should not have second negative here?
        R = torch.cat((R, -torch.cat((G, -torch.eye(3)), 1)))
        #R = torch.cat((R, -torch.cat((G, torch.zeros(3,3)), 1)))
        # NOTE(review): f is the 3-D bmm output, so f.unsqueeze(1) is 4-D while
        # h.transpose(0, 1) is 2-D; torch.cat normally needs matching ranks --
        # verify the intended shape of h.
        h = torch.cat((h.transpose(0, 1), f.unsqueeze(1)))
        h = h.transpose(0, 1)

        # Symmetrise Q and add a small diagonal term.
        # NOTE(review): Q is 3-D here (A is repeated n times), so
        # transpose(0, 1) swaps the batch axis with the row axis rather than
        # the two matrix axes -- confirm whether transpose(1, 2) was intended.
        Qmod = 0.5 * (Q + Q.transpose(0, 1)) + 0.001 * torch.eye(6)

        # Empty tensors are passed in the equality-constraint slots
        # (presumably meaning "no equality constraints" -- confirm against
        # the QPFunction documentation).
        z = QPFunction(check_Q_spd=False)(Qmod, p, R, h, torch.tensor([]),
                                          torch.tensor([]))

        #print(self.scipy_optimize(0.5 * (Q + Q.transpose(0, 1)), p, R, h))
        #assert(torch.all(torch.matmul(R, z.transpose(0, 1)) \
        #                <= (h.transpose(0, 1) + torch.ones(h.shape) * 1e-5)))

        # NOTE(review): lcp_slack is computed but never used below.
        lcp_slack = torch.matmul(Gpad, z.transpose(0, 1)).transpose(0,
                                                                    1) + fpad
        #print(z[0])

        cost = 0.5 * torch.matmul(z, torch.matmul(Qmod, z.transpose(0, 1))) \
                + torch.matmul(p, z.transpose(0, 1)) + a1 * beta**2
        return cost
Exemplo n.º 3
0
    def forward(self, query, support, support_labels, n_way, n_shot):
        """
        Fits the support set with multi-class SVM and
        returns the classification score on the query set.

        This is the multi-class SVM presented in:
        On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
        (Crammer and Singer, Journal of Machine Learning Research 2001).

        This model is the classification head that we use for the final version.
        Parameters:
          query:  a (tasks_per_batch, n_query, d) Tensor.
          support:  a (tasks_per_batch, n_support, d) Tensor.
          support_labels: a (tasks_per_batch, n_support) Tensor.
          n_way: a scalar. Represents the number of classes in a few-shot classification task.
          n_shot: a scalar. Represents the number of support examples given per class.
        Returns: a (tasks_per_batch, n_query, n_way) Tensor.

        The SVM cost parameter C is not an argument; it is read from
        ``self.C_reg``. ``self.maxIter`` and ``self.double_precision`` configure
        the QP solve.

        Side effect: stores the fitted class weights alpha^T @ support, a
        (tasks_per_batch, n_way, d) Tensor, in ``self.support_w``.
        """

        tasks_per_batch = query.size(0)
        n_support = support.size(1)
        n_query = query.size(1)

        assert (query.dim() == 3)
        assert (support.dim() == 3)
        assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
        assert (n_support == n_way * n_shot)  # n_support must equal to n_way * n_shot

        # Here we solve the dual problem:
        # Note that the classes are indexed by m & samples are indexed by i.
        # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
        # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i

        # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
        # and C^m_i = C if m  = y_i,
        # C^m_i = 0 if m != y_i.
        # This borrows the notation of liblinear.

        # \alpha is an (n_support, n_way) matrix
        kernel_matrix = computeGramMatrix(support, support)

        id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
        block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
        # This seems to help avoid PSD error from the QP solver.
        block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support,
                                                                         n_way * n_support).cuda()

        support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support),
                                         n_way)  # (tasks_per_batch * n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

        G = block_kernel_matrix
        e = -1.0 * support_labels_one_hot
        # This part is for the inequality constraints:
        # \alpha^m_i <= C^m_i \forall m,i
        # where C^m_i = C if m  = y_i,
        # C^m_i = 0 if m != y_i.
        id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
        C = Variable(id_matrix_1)
        h = Variable(self.C_reg * support_labels_one_hot)
        # This part is for the equality constraints:
        # \sum_m \alpha^m_i=0 \forall i
        id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()

        A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
        b = Variable(torch.zeros(tasks_per_batch, n_support))
        if self.double_precision:
            G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
        else:
            G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

        # Solve the following QP to fit SVM:
        #        \hat z =   argmin_z 1/2 z^T G z + e^T z
        #                 subject to Cz <= h, Az = b
        # We use detach() to prevent backpropagation to fixed variables.
        qp_sol = QPFunction(verbose=False, maxIter=self.maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(),
                                                            b.detach())

        # Compute the classification score:
        # logits[t, q, m] = \sum_i \alpha^m_i * <x_i, query_q>
        compatibility = computeGramMatrix(support, query)
        compatibility = compatibility.float()
        compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
        qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
        logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
        logits = logits * compatibility
        logits = torch.sum(logits, 1)

        # Cache the fitted class weights w_m = \sum_i \alpha^m_i x_i.
        # The QP inputs are cast to double or float above, so the solution's dtype
        # may differ from support's; torch.bmm needs matching dtypes, hence the cast
        # (a no-op when they already agree).
        self.support_w = torch.bmm(qp_sol.transpose(1, 2).to(support.dtype), support)

        return logits