import torch
from torch.autograd import Variable
from collections import defaultdict

from qpth.qp import QPFunction


def MetaOptNetHead_SVM(query, support, support_labels, n_way, n_shot, C_reg=0.1, double_precision=False, maxIter=15):
    """
    Fits the support set with a multi-class SVM and returns the classification score on the query set.

    This is the multi-class SVM presented in:
    On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
    (Crammer and Singer, Journal of Machine Learning Research 2001).

    This model is the classification head that we use for the final version.

    Parameters:
        query: a (tasks_per_batch, n_query, d) Tensor.
        support: a (tasks_per_batch, n_support, d) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns:
        a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot.

    # Here we solve the dual problem.
    # The classes are indexed by m and the samples are indexed by i.
    #   min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
    #   s.t. \alpha^m_i <= C^m_i  \forall m, i    and    \sum_m \alpha^m_i = 0  \forall i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m = y_i, C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.

    # \alpha is an (n_support, n_way) matrix.
    kernel_matrix = computeGramMatrix(support, support)

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
    # Adding the identity seems to help avoid PSD errors from the QP solver.
    block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support).cuda()

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot

    # This part is for the inequality constraints:
    #   \alpha^m_i <= C^m_i  \forall m, i
    # where C^m_i = C if m = y_i, C^m_i = 0 if m != y_i.
    id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
    C = Variable(id_matrix_1)
    h = Variable(C_reg * support_labels_one_hot)

    # This part is for the equality constraints:
    #   \sum_m \alpha^m_i = 0  \forall i
    id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
    A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
    b = Variable(torch.zeros(tasks_per_batch, n_support))

    if double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

    # Solve the following QP to fit the SVM:
    #   \hat z = argmin_z 1/2 z^T G z + e^T z   subject to   Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits

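# The heads in this file rely on helpers (computeGramMatrix, batched_kronecker, one_hot)
# that are defined elsewhere in the repository. The sketches below are minimal
# placeholders inferred from how the helpers are called here, so the snippets stay
# self-contained; the repo's own implementations may differ.
def computeGramMatrix(A, B):
    # Batched linear kernel: (tasks, n, d) x (tasks, m, d) -> (tasks, n, m).
    assert A.dim() == 3 and B.dim() == 3
    assert A.size(0) == B.size(0) and A.size(2) == B.size(2)
    return torch.bmm(A, B.transpose(1, 2))


def batched_kronecker(A, B):
    # Batched Kronecker product: (tasks, n, m) kron (tasks, p, q) -> (tasks, n*p, m*q).
    tasks, n, m = A.shape
    _, p, q = B.shape
    return (A.unsqueeze(2).unsqueeze(4) * B.unsqueeze(1).unsqueeze(3)).reshape(tasks, n * p, m * q)


def one_hot(indices, depth, device=None):
    # (N,) integer labels -> (N, depth) one-hot float matrix. The optional `device`
    # argument mirrors the ridge-head variant below that passes it explicitly.
    if device is None:
        device = indices.device
    out = torch.zeros(indices.size(0), depth, device=device)
    return out.scatter_(1, indices.long().view(-1, 1).to(device), 1.0)
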
def MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot, C_reg=0.00001, double_precision=False):
    """
    Fits the support set with a multi-class SVM and returns the classification score on the query set.

    This is the multi-class SVM presented in:
    Support Vector Machines for Multi Class Pattern Recognition
    (Weston and Watkins, ESANN 1999).

    Parameters:
        query: a (tasks_per_batch, n_query, d) Tensor.
        support: a (tasks_per_batch, n_support, d) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns:
        a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot.

    # In theory, \alpha is an (n_support, n_way) matrix.
    # NOTE: In this implementation we solve for a flattened vector of size (n_way * n_support).
    # To turn it into a matrix, first reshape it into an (n_way, n_support) matrix,
    # then transpose it, resulting in an (n_support, n_way) matrix.
    kernel_matrix = computeGramMatrix(support, support) + torch.ones(tasks_per_batch, n_support, n_support).cuda()

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)

    kernel_matrix_mask_x = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support)
    kernel_matrix_mask_y = support_labels.reshape(tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support, n_support)
    kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()

    block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
    block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)

    kernel_matrix_mask_second_term = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support * n_way)
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()

    block_kernel_matrix -= (2.0 - 1e-4) * (kernel_matrix_mask_second_term * kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)

    Y_support = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
    Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
    Y_support = Y_support.transpose(1, 2)  # (tasks_per_batch, n_way, n_support)
    Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)

    G = block_kernel_matrix
    e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
    id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)

    C_mat = C_reg * torch.ones(tasks_per_batch, n_way * n_support).cuda() - C_reg * Y_support
    C = Variable(torch.cat((id_matrix, -id_matrix), 1))
    zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda()
    h = Variable(torch.cat((C_mat, zer), 1))

    dummy = Variable(torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit the SVM:
    #   \hat z = argmin_z 1/2 z^T G z + e^T z   subject to   Cz <= h
    # qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
    qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(), dummy.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query) + torch.ones(tasks_per_batch, n_support, n_query).cuda()
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way, n_support, n_query)

    qp_sol = qp_sol.float()
    qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support)
    A_i = torch.sum(qp_sol, 1)  # (tasks_per_batch, n_support)
    A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support)
    qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query)

    Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support)
    Y_support_reshaped = A_i * Y_support_reshaped
    Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query)

    logits = (Y_support_reshaped - qp_sol) * compatibility
    logits = torch.sum(logits, 2)

    return logits.transpose(1, 2)

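# A minimal shape check for the two SVM heads above (hypothetical helper name and
# made-up sizes; requires CUDA and qpth, and relies on the placeholder helpers sketched
# earlier or the repo's own implementations):
def _demo_svm_heads(tasks_per_batch=2, n_way=5, n_shot=1, n_query_per_class=3, d=16):
    n_support = n_way * n_shot
    n_query = n_way * n_query_per_class
    support = torch.randn(tasks_per_batch, n_support, d).cuda()
    query = torch.randn(tasks_per_batch, n_query, d).cuda()
    support_labels = torch.arange(n_way).repeat(n_shot).repeat(tasks_per_batch, 1).cuda()
    logits_cs = MetaOptNetHead_SVM(query, support, support_labels, n_way, n_shot)
    logits_ww = MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot)
    # Both heads return (tasks_per_batch, n_query, n_way) scores.
    return logits_cs.shape, logits_ww.shape
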
def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False):
    """
    Fits the support set with ridge regression and returns the classification score on the query set.

    Parameters:
        query: a (tasks_per_batch, n_query, d) Tensor.
        support: a (tasks_per_batch, n_support, d) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        lambda_reg: a scalar. Represents the strength of L2 regularization.
    Returns:
        a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot.

    # Here we solve the dual problem.
    # The classes are indexed by m and the samples are indexed by i.
    #   min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i.

    # \alpha is an (n_support, n_way) matrix.
    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()

    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support)  # (n_way * tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # This is a fake inequality constraint, since qpth does not support a QP without one.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    dummy = Variable(torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit the ridge regression:
    #   \hat z = argmin_z 1/2 z^T G z + e^T z   subject to   Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
    # qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)  # (n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)  # (tasks_per_batch, n_support, n_way)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits

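# Since C and h above are all zeros, the inequality constraint is vacuous and the QP is
# an unconstrained quadratic with the closed-form solution z = -G^{-1} e, i.e.
# alpha = 2 (K + lambda*I)^{-1} Y. The sketch below (hypothetical helper name, using the
# placeholder helpers defined earlier) computes the same logits with a direct batched
# solve; the factor of 2 coming from e = -2*Y only rescales the scores.
def ridge_head_closed_form(query, support, support_labels, n_way, n_shot, lambda_reg=50.0):
    tasks_per_batch, n_support, _ = support.shape
    kernel = computeGramMatrix(support, support)
    kernel = kernel + lambda_reg * torch.eye(n_support, device=support.device).expand_as(kernel)
    Y = one_hot(support_labels.view(-1), n_way).view(tasks_per_batch, n_support, n_way)
    alpha = 2.0 * torch.linalg.solve(kernel, Y)  # (tasks_per_batch, n_support, n_way)
    return torch.bmm(computeGramMatrix(query, support), alpha)  # (tasks_per_batch, n_query, n_way)
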
def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, device, lambda_reg=100.0, double_precision=False):
    """
    Variant of the ridge-regression head above that takes an explicit device argument
    (forwarded to one_hot) and defaults to lambda_reg=100.0.

    Returns:
        a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot.

    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()

    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way, device)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support)  # (n_way * tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # This is a fake inequality constraint, since qpth does not support a QP without one.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    dummy = Variable(torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
    # qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)  # (n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)  # (tasks_per_batch, n_support, n_way)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits

def inner_loop_adapt(self, support, support_labels, query, n_way, n_shot, n_query):
    """
    Fits the support set with ridge regression and returns the classification score on the query set.

    Parameters:
        query: a (tasks_per_batch, n_query, c, h, w) Tensor.
        support: a (tasks_per_batch, n_support, c, h, w) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        n_query: a scalar. Represents the number of query examples given per class.
    Returns:
        a (tasks_per_batch, n_query, n_way) Tensor and a dict of support-set measurements.
    """
    measurements_trajectory = defaultdict(list)

    assert (query.dim() == 5)
    assert (support.dim() == 5)

    # Extract features.
    orig_query_shape = query.shape
    orig_support_shape = support.shape
    support = self._model(support.reshape(-1, *orig_support_shape[2:]), only_features=True).reshape(*orig_support_shape[:2], -1)
    query = self._model(query.reshape(-1, *orig_query_shape[2:]), only_features=True).reshape(*orig_query_shape[:2], -1)

    lambda_reg = self._lambda_reg
    double_precision = self._double_precision

    tasks_per_batch = query.size(0)
    total_n_support = support.size(1)  # support samples across all classes in a task
    total_n_query = query.size(1)      # query samples across all classes in a task
    d = query.size(2)                  # feature dimension

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (total_n_support == n_way * n_shot)  # total_n_support must equal n_way * n_shot.
    assert (total_n_query == n_way * n_query)   # total_n_query must equal n_way * n_query.

    # Here we solve the dual problem.
    # The classes are indexed by m and the samples are indexed by i.
    #   min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i.
    # \alpha is a (total_n_support, n_way) matrix.
    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(total_n_support).expand(tasks_per_batch, total_n_support, total_n_support).cuda()

    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, total_n_support, total_n_support)

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * total_n_support), n_way)  # (tasks_per_batch * total_n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * total_n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, total_n_support)  # (n_way * tasks_per_batch, total_n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # This is a fake inequality constraint, since qpth does not support a QP without one.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, total_n_support, total_n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, total_n_support)))
    dummy = Variable(torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit the ridge regression:
    #   \hat z = argmin_z 1/2 z^T G z + e^T z   subject to   Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
    # qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, total_n_support)  # (n_way, tasks_per_batch, total_n_support)
    qp_sol = qp_sol.permute(1, 2, 0)  # (tasks_per_batch, total_n_support, n_way)

    # Compute the classification score on the query set.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, total_n_support, total_n_query, n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, total_n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, total_n_support, total_n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1) * self._scale

    # Compute loss and accuracy on the support set.
    with torch.no_grad():
        compatibility = computeGramMatrix(support, support)
        compatibility = compatibility.float()
        compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, total_n_support, total_n_support, n_way)
        logits_support = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, total_n_support, total_n_support, n_way)
        logits_support = logits_support * compatibility
        logits_support = torch.sum(logits_support, 1)
        logits_support = logits_support.reshape(-1, logits_support.size(-1)) * self._scale
        loss = self._inner_loss_func(logits_support, support_labels.reshape(-1))
        accu = accuracy(logits_support, support_labels.reshape(-1)) * 100.
        measurements_trajectory['loss'].append(loss.item())
        measurements_trajectory['accu'].append(accu)

    return logits, measurements_trajectory

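# The adaptation method above records support-set accuracy via an `accuracy` helper
# defined elsewhere in the repository. A minimal sketch consistent with how it is
# called here (fraction of correct argmax predictions; the real helper may differ):
def accuracy(logits, labels):
    with torch.no_grad():
        preds = logits.argmax(dim=-1)
        return (preds == labels).float().mean().item()
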
def inner_loop_adapt(self, support, support_labels, query, n_way, n_shot, n_query):
    """
    Fits the support set with a multi-class SVM and returns the classification score on the query set.

    This is the multi-class SVM presented in:
    On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
    (Crammer and Singer, Journal of Machine Learning Research 2001).

    This model is the classification head that we use for the final version.

    Parameters:
        query: a (tasks_per_batch, n_query, c, h, w) Tensor.
        support: a (tasks_per_batch, n_support, c, h, w) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        n_query: a scalar. Represents the number of query examples given per class.
    Returns:
        a (tasks_per_batch, n_query, n_way) Tensor and a dict of support-set measurements.
    """
    measurements_trajectory = defaultdict(list)

    assert (query.dim() == 5)
    assert (support.dim() == 5)

    # Extract features.
    orig_query_shape = query.shape
    orig_support_shape = support.shape
    support = self._model(support.reshape(-1, *orig_support_shape[2:]), only_features=True).reshape(*orig_support_shape[:2], -1)
    query = self._model(query.reshape(-1, *orig_query_shape[2:]), only_features=True).reshape(*orig_query_shape[:2], -1)

    tasks_per_batch = query.size(0)
    total_n_support = support.size(1)  # support samples across all classes in a task
    total_n_query = query.size(1)      # query samples across all classes in a task
    d = query.size(2)                  # feature dimension

    C_reg = self._C_reg
    maxIter = self._max_iter

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (total_n_support == n_way * n_shot)  # total_n_support must equal n_way * n_shot.
    assert (total_n_query == n_way * n_query)   # total_n_query must equal n_way * n_query.

    # Here we solve the dual problem.
    # The classes are indexed by m and the samples are indexed by i.
    #   min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
    #   s.t. \alpha^m_i <= C^m_i  \forall m, i    and    \sum_m \alpha^m_i = 0  \forall i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m = y_i, C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.
    # \alpha is a (total_n_support, n_way) matrix.
    kernel_matrix = computeGramMatrix(support, support)

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
    # This seems to help avoid PSD errors from the QP solver.
    block_kernel_matrix += 1.0 * torch.eye(n_way * total_n_support).expand(tasks_per_batch, n_way * total_n_support, n_way * total_n_support).cuda()

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * total_n_support), n_way)  # (tasks_per_batch * total_n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, total_n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, total_n_support * n_way)  # (tasks_per_batch, total_n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot

    # This part is for the inequality constraints:
    #   \alpha^m_i <= C^m_i  \forall m, i
    # where C^m_i = C if m = y_i, C^m_i = 0 if m != y_i.
    id_matrix_1 = torch.eye(n_way * total_n_support).expand(tasks_per_batch, n_way * total_n_support, n_way * total_n_support)
    C = Variable(id_matrix_1)
    h = Variable(C_reg * support_labels_one_hot)

    # This part is for the equality constraints:
    #   \sum_m \alpha^m_i = 0  \forall i
    id_matrix_2 = torch.eye(total_n_support).expand(tasks_per_batch, total_n_support, total_n_support).cuda()
    A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
    b = Variable(torch.zeros(tasks_per_batch, total_n_support))

    if self._double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

    # Solve the following QP to fit the SVM:
    #   \hat z = argmin_z 1/2 z^T G z + e^T z   subject to   Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    # G is not detached; it is the only argument that needs gradients, since it is a function of phi(x).
    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
    qp_sol = qp_sol.reshape(tasks_per_batch, total_n_support, n_way)

    # Compute the classification score for the query set.
    compatibility_query = computeGramMatrix(support, query)
    compatibility_query = compatibility_query.float()
    compatibility_query = compatibility_query.unsqueeze(3).expand(tasks_per_batch, total_n_support, total_n_query, n_way)
    logits_query = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, total_n_support, total_n_query, n_way)
    logits_query = logits_query * compatibility_query
    logits_query = torch.sum(logits_query, 1) * self._scale

    # Compute the classification score, loss, and accuracy for the support set.
    with torch.no_grad():
        compatibility_support = computeGramMatrix(support, support)
        compatibility_support = compatibility_support.float()
        compatibility_support = compatibility_support.unsqueeze(3).expand(tasks_per_batch, total_n_support, total_n_support, n_way)
        logits_support = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, total_n_support, total_n_support, n_way)
        logits_support = logits_support * compatibility_support
        logits_support = torch.sum(logits_support, 1) * self._scale

        logits_support = logits_support.reshape(-1, logits_support.size(-1))
        labels_support = support_labels.reshape(-1)
        loss = self._inner_loss_func(logits_support, labels_support)
        accu = accuracy(logits_support, labels_support)
        measurements_trajectory['loss'].append(loss.item())
        measurements_trajectory['accu'].append(accu)

    return logits_query, measurements_trajectory

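# To see what the equality-constraint matrix A built with batched_kronecker encodes,
# consider a tiny hypothetical task (1 task, total_n_support = 2, n_way = 3). Each row
# of A picks out the n_way dual variables of one support sample, so A @ alpha = 0
# enforces \sum_m \alpha^m_i = 0 for every sample i:
#
#   id2 = torch.eye(2).expand(1, 2, 2)
#   A_demo = batched_kronecker(id2, torch.ones(1, 1, 3))
#   # A_demo[0] == [[1., 1., 1., 0., 0., 0.],
#   #               [0., 0., 0., 1., 1., 1.]]
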
def CMloss_Fnorm(self, query, support, support_labels, n_way, n_shot):
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot.

    # Here we solve the dual problem.
    # The classes are indexed by m and the samples are indexed by i.
    #   min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
    #   s.t. \alpha^m_i <= C^m_i  \forall m, i    and    \sum_m \alpha^m_i = 0  \forall i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m = y_i, C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.

    # \alpha is an (n_support, n_way) matrix.
    kernel_matrix = computeGramMatrix(support, support)

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
    # This seems to help avoid PSD errors from the QP solver.
    block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support).cuda()

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot

    # This part is for the inequality constraints:
    #   \alpha^m_i <= C^m_i  \forall m, i
    # where C^m_i = C if m = y_i, C^m_i = 0 if m != y_i.
    id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
    C = Variable(id_matrix_1)
    h = Variable(self.C_reg * support_labels_one_hot)

    # This part is for the equality constraints:
    #   \sum_m \alpha^m_i = 0  \forall i
    id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
    A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
    b = Variable(torch.zeros(tasks_per_batch, n_support))

    if self.double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

    # Solve the following QP to fit the SVM:
    #   \hat z = argmin_z 1/2 z^T G z + e^T z   subject to   Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False, maxIter=self.maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())

    # Recover the dual variables and the classifier weights implied by the query-side SVM.
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    w_query = torch.bmm(qp_sol.transpose(1, 2), support)  # (tasks_per_batch, n_way, d)
    dis_matrix = self.support_w - w_query

    # Average Frobenius norm of the per-task weight difference.
    reverse_loss = 0
    for i in range(tasks_per_batch):
        reverse_loss += (torch.trace(torch.mm(dis_matrix[i, :, :], dis_matrix[i, :, :].t()))).sqrt()
    reverse_loss = 1.0 * reverse_loss / (n_way * tasks_per_batch)

    return reverse_loss

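# Note: trace(D @ D^T) equals ||D||_F^2, so the loop above just accumulates per-task
# Frobenius norms. A vectorized equivalent (sketch, same result up to numerics):
#   reverse_loss = torch.linalg.matrix_norm(dis_matrix, ord='fro').sum() / (n_way * tasks_per_batch)
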
def MetaOptNetHead_OC_SVM_batched(query, support, support_labels, n_way, n_shot, nu=0.1, double_precision=True, maxIter=40):
    """
    Fits the support set with a one-class SVM (OC-SVM) and returns the classification score on the query set.

    Parameters:
        query: a (tasks_per_batch, n_query, d) Tensor.
        support: a (tasks_per_batch, n_support, d) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        nu: a scalar. The OC-SVM nu parameter.
    Returns:
        a (tasks_per_batch, n_query) Tensor of one-class decision scores.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (n_support == n_shot)  # n_support must equal n_shot.

    # We solve the dual problem of the OC-SVM.
    n_way = 1  # one-class classification

    kernel_matrix = computeGramMatrix(support, support)
    Q = kernel_matrix.cuda()
    # This seems to help avoid PSD errors from the QP solver (as done in the original MetaOptNet SVM head).
    Q_spd = 1.0 * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
    Q += Q_spd

    p = Variable(torch.zeros((tasks_per_batch, n_support)))
    A = Variable(torch.ones((tasks_per_batch, 1, n_support)))
    b = Variable(torch.ones(tasks_per_batch, 1))
    G = Variable(torch.cat([-1.0 * torch.eye(n_support), torch.eye(n_support)], dim=0).expand(tasks_per_batch, 2 * n_support, n_support))
    h_task = torch.cat([torch.zeros((n_support, 1)).cuda(), (1 / (nu * n_support)) * torch.ones((n_support, 1)).cuda()], dim=0)
    h = torch.squeeze(Variable(h_task.expand(tasks_per_batch, 2 * n_support, 1)))

    if double_precision:
        Q, p, G, h, A, b = [x.double().cuda() for x in [Q, p, G, h, A, b]]
    else:
        Q, p, G, h, A, b = [x.float().cuda() for x in [Q, p, G, h, A, b]]

    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(Q, p.detach(), G.detach(), h.detach(), A.detach(), b.detach())

    # Compute the classification score.
    if double_precision:
        qp_sol = qp_sol.float()

    w = torch.bmm(qp_sol.reshape((tasks_per_batch, 1, n_support)), support)
    logits_tmp = torch.bmm(query, w.transpose(1, 2))

    logits = []
    for i in range(tasks_per_batch):
        S = (qp_sol[i] > 1e-3).flatten()  # support vectors with non-negligible dual weight
        all_ro = torch.mm(support[i][S], w[i].transpose(0, 1))
        ro = all_ro[0]  # offset rho taken from the first such support vector
        logit = torch.squeeze(logits_tmp[i] - ro)
        logits.append(logit)
    logits = torch.stack(logits)

    return logits
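# Hypothetical usage of the OC-SVM head above (sizes made up for illustration; requires
# CUDA and qpth since the head moves tensors to .cuda() internally). support_labels is
# only checked for size here, so dummy labels suffice:
#
#   tasks_per_batch, n_shot, n_query, d = 2, 5, 7, 16
#   support = torch.randn(tasks_per_batch, n_shot, d).cuda()
#   query = torch.randn(tasks_per_batch, n_query, d).cuda()
#   labels = torch.zeros(tasks_per_batch, n_shot).long().cuda()
#   scores = MetaOptNetHead_OC_SVM_batched(query, support, labels, n_way=1, n_shot=n_shot)
#   # scores: (tasks_per_batch, n_query); larger values mean the query point lies further
#   # inside the learned one-class decision boundary.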