def CMloss_Fnorm(self, query, support, support_labels, n_way, n_shot):
    """
    Frobenius-norm distance between SVM weights fitted on `support` and
    the weights stored in ``self.support_w``.

    A multi-class SVM dual QP is solved on `support`, the per-class weight
    matrix ``alpha^T @ support`` is formed for every task, and the
    Frobenius norms of ``self.support_w - weights`` are summed over tasks
    and divided by ``n_way * tasks_per_batch``.

    Parameters:
        query:  a 3-D Tensor. Only its batch size and feature size are used
            here (for the shape checks).
        support:  a (tasks_per_batch, n_support, d) Tensor.
        support_labels: a Tensor viewable as (tasks_per_batch * n_support).
        n_way: a scalar. Number of classes per task.
        n_shot: a scalar. Number of support examples per class.

    Reads ``self.C_reg``, ``self.double_precision``, ``self.maxIter`` and
    ``self.support_w``; the latter must already be set and be broadcastable
    against a (tasks_per_batch, n_way, d) Tensor.

    Returns: a 0-dim Tensor holding the averaged distance.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)      # n_support must equal to n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
    # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m  = y_i,
    #     C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.
    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
    # This seems to help avoid PSD error from the QP solver.
    block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(
        tasks_per_batch, n_way * n_support, n_way * n_support).cuda()

    # (tasks_per_batch * n_support, n_way) -> (tasks_per_batch, n_support * n_way)
    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot

    # This part is for the inequality constraints:
    # \alpha^m_i <= C^m_i \forall m,i
    # where C^m_i = C if m  = y_i,
    #       C^m_i = 0 if m != y_i.
    C = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
    h = self.C_reg * support_labels_one_hot

    # This part is for the equality constraints:
    # \sum_m \alpha^m_i=0 \forall i
    id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
    A = batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda())
    b = torch.zeros(tasks_per_batch, n_support)

    if self.double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    #                            Az  = b
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False, maxIter=self.maxIter)(
        G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)

    # torch.bmm needs both operands in the same dtype; the QP solution is
    # double when self.double_precision is set, so align it with `support`.
    w_query = torch.bmm(qp_sol.transpose(1, 2).to(support.dtype), support)

    dis_matrix = self.support_w - w_query

    # trace(D @ D^T) is the sum of squared entries of D, so the per-task
    # Frobenius norm can be computed for the whole batch at once. Only the
    # first `tasks_per_batch` slices are used, as in a loop over tasks.
    # NOTE(review): sqrt has an unbounded gradient where a task's difference
    # is exactly zero -- confirm this cannot happen during training.
    squared = dis_matrix[:tasks_per_batch].reshape(tasks_per_batch, -1).pow(2).sum(dim=1)
    revese_loss = squared.sqrt().sum() / (n_way * tasks_per_batch)

    return revese_loss
def forward(self, data):
    """
    Solve one 6-variable QP per row of `data` and return its objective value.

    Columns 0, 1 and 2 of `data` are read as ``vs``, ``vnexts`` and ``us``
    (presumably current value, next value and input -- confirm with the
    caller). The QP variable is ``z = [lambda (3 entries), slack (3 entries)]``
    and the objective is

        a1 * (lambda_1 - lambda_2 - beta)^2      (prediction error)
      + a2 * lambda^T (G lambda + f)             (complementarity-style term)
      + a3 * sum(slack)                          (slack penalty)

    with ``beta = vnexts - vs - us`` and ``f`` built from ``vs``, ``us`` and
    ``self.mu``. The quadratic part is symmetrised and regularised with
    ``0.001 * I`` before it is handed to the solver, and the same
    regularised matrix is used when the cost is evaluated.

    Parameters:
        data: a 2-D float Tensor with at least 3 columns, one sample per row.

    Returns: an (n, 1) Tensor with the per-sample cost, n = data.shape[0].
    """
    n = data.shape[0]
    vs = torch.unsqueeze(data[:, 0], 1)
    vnexts = torch.unsqueeze(data[:, 1], 1)
    us = torch.unsqueeze(data[:, 2], 1)

    mu = self.mu

    beta = vnexts - vs - us                                   # (n, 1)

    G = torch.tensor([[1.0, -1, 1],
                      [-1, 1, 1],
                      [-1, -1, 0]])

    # G embedded in the top-left corner of a 6x6 matrix (slack rows/cols are 0).
    Gpad = torch.tensor([[1.0, -1, 1, 0, 0, 0],
                         [-1, 1, 1, 0, 0, 0],
                         [-1, -1, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0]])

    fmats = torch.tensor([[1.0, 1, 0],
                          [-1, -1, 0],
                          [0, 0, 1]]).unsqueeze(0).repeat(n, 1, 1)
    fvecs = torch.cat((vs, us, mu * torch.ones(vs.shape)), 1).unsqueeze(2)
    f = torch.bmm(fmats, fvecs)                               # (n, 3, 1)
    fpad = torch.cat((f, torch.zeros(f.shape)), 1)            # (n, 6, 1)

    # For prediction error:
    # z^T A z + b^T z + beta^2 == (lambda_1 - lambda_2 - beta)^2
    A = torch.tensor([[1.0, -1, 0, 0, 0, 0],
                      [-1, 1, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0]]).repeat(n, 1, 1)
    beta_zeros = torch.zeros(beta.shape)
    b = torch.cat((-2 * beta, 2 * beta,
                   beta_zeros, beta_zeros, beta_zeros, beta_zeros), 1).unsqueeze(2)

    slack_penalty = torch.tensor([0.0, 0, 0, 1, 1, 1]).repeat(n, 1).unsqueeze(2)

    # Weights of the three objective terms.
    a1 = 1
    a2 = 1
    a3 = 1

    Q = 2 * a1 * A + 2 * a2 * Gpad                            # (n, 6, 6)
    # The solver takes the linear term as (n, nz), so drop the trailing axis.
    p = (a1 * b + a2 * fpad + a3 * slack_penalty).squeeze(2)  # (n, 6)

    # Inequalities are passed as R z <= h.
    # Rows 0-5: constrain lambda and slacks to be >= 0.
    # Rows 6-8: -(G lambda - slack) <= f, i.e. G lambda + f >= slack.
    # NOTE(review): the original author questioned the sign of the slack
    # identity below ("Should not have second negative here?"); the sign is
    # kept as written -- confirm the intended relaxation direction.
    R = torch.cat((-torch.eye(6),
                   -torch.cat((G, -torch.eye(3)), 1)))        # (9, 6)
    h = torch.cat((torch.zeros(n, 6), f.squeeze(2)), 1)       # (n, 9)

    # Symmetrise each sample's matrix (dim 0 is the batch, so transpose the
    # last two dims) and add a small ridge so the solver sees a PD matrix.
    Qmod = 0.5 * (Q + Q.transpose(1, 2)) + 0.001 * torch.eye(6)

    # No equality constraints: empty tensors are passed for them.
    z = QPFunction(check_Q_spd=False)(Qmod, p, R, h,
                                      torch.tensor([]), torch.tensor([]))  # (n, 6)

    # Per-sample objective value: 0.5 z^T Qmod z + p^T z + a1 * beta^2.
    quad = torch.bmm(z.unsqueeze(1), torch.bmm(Qmod, z.unsqueeze(2))).squeeze(2)
    lin = (p * z).sum(1, keepdim=True)
    cost = 0.5 * quad + lin + a1 * beta ** 2

    return cost
def forward(self, query, support, support_labels, n_way, n_shot):
    """
    Fits the support set with multi-class SVM and
    returns the classification score on the query set.

    This is the multi-class SVM presented in:
    On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
    (Crammer and Singer, Journal of Machine Learning Research 2001).

    This model is the classification head that we use for the final version.

    Parameters:
        query:  a (tasks_per_batch, n_query, d) Tensor.
        support:  a (tasks_per_batch, n_support, d) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.

    The SVM cost parameter C is read from ``self.C_reg``; the solver
    precision and iteration limit come from ``self.double_precision`` and
    ``self.maxIter``.

    Side effect: stores the per-class weight matrix ``alpha^T @ support``,
    a (tasks_per_batch, n_way, d) Tensor, in ``self.support_w``.

    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)      # n_support must equal to n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
    # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m  = y_i,
    #     C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.
    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
    # This seems to help avoid PSD error from the QP solver.
    block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(
        tasks_per_batch, n_way * n_support, n_way * n_support).cuda()

    # (tasks_per_batch * n_support, n_way) -> (tasks_per_batch, n_support * n_way)
    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot

    # This part is for the inequality constraints:
    # \alpha^m_i <= C^m_i \forall m,i
    # where C^m_i = C if m  = y_i,
    #       C^m_i = 0 if m != y_i.
    C = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
    h = self.C_reg * support_labels_one_hot

    # This part is for the equality constraints:
    # \sum_m \alpha^m_i=0 \forall i
    id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
    A = batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda())
    b = torch.zeros(tasks_per_batch, n_support)

    if self.double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    #                            Az  = b
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False, maxIter=self.maxIter)(
        G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)

    # Compute the classification score:
    # logits[t, q, m] = \sum_i \alpha^m_i <x_i, query_q>.
    compatibility = computeGramMatrix(support, query).float()   # (tasks, n_support, n_query)
    # Broadcast (tasks, n_support, 1, n_way) against (tasks, n_support, n_query, 1)
    # and reduce over the support axis.
    logits = (qp_sol.float().unsqueeze(2) * compatibility.unsqueeze(3)).sum(1)

    # torch.bmm needs both operands in the same dtype; the QP solution is
    # double when self.double_precision is set, so align it with `support`.
    self.support_w = torch.bmm(qp_sol.transpose(1, 2).to(support.dtype), support)

    return logits