def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, device, lambda_reg=100.0, double_precision=False):
    """
    Fits the support set with ridge regression and returns the
    classification score on the query set.

    The dual problem is solved with qpth's QPFunction as one batched QP made
    of n_way * tasks_per_batch sub-problems (one per class and per task).

    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a Tensor holding tasks_per_batch * n_support labels
        (it is flattened with .view before one-hot encoding).
      n_way: a scalar. Represents the number of classes in a few-shot
        classification task.
      n_shot: a scalar. Represents the number of support examples given
        per class.
      device: forwarded unchanged to one_hot().
      lambda_reg: a scalar added to the diagonal of the support Gram matrix
        (strength of the L2 regularization).
      double_precision: if True the QP is solved in float64, else float32.
    Returns: a (tasks_per_batch, n_query, n_way) float32 Tensor.
    """
    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot

    kernel_matrix = computeGramMatrix(support, support)
    # Build every auxiliary tensor on the Gram matrix's own device instead of
    # a hard-coded .cuda(), so the head also runs on CPU or a non-default GPU.
    work_device = kernel_matrix.device
    # The (n_support, n_support) identity broadcasts over the task dimension.
    kernel_matrix += lambda_reg * torch.eye(n_support, device=work_device)

    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way, device)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support)  # (n_way * tasks_per_batch, n_support)

    qp_dtype = torch.double if double_precision else torch.float
    G = block_kernel_matrix.to(device=work_device, dtype=qp_dtype)
    e = (-2.0 * support_labels_one_hot).to(device=work_device, dtype=qp_dtype)

    # This is a fake inequality constraint (0 * z <= 0) as qpth does not
    # support QP without an inequality constraint.
    C = torch.zeros(tasks_per_batch * n_way, n_support, n_support, device=work_device, dtype=qp_dtype)
    h = torch.zeros(tasks_per_batch * n_way, n_support, device=work_device, dtype=qp_dtype)
    # Empty tensors tell the solver to ignore the equality constraint.
    dummy = torch.empty(0, device=work_device)

    # Solve  \hat z = argmin_z 1/2 z^T G z + e^T z  subject to  Cz <= h.
    # e is detached to prevent backpropagation into the fixed label term;
    # G stays attached so gradients reach the features.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C, h, dummy, dummy)  # (n_way * tasks_per_batch, n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support).permute(1, 2, 0)  # (tasks_per_batch, n_support, n_way)

    # Compute the classification score:
    # logits[t, q, m] = sum_i qp_sol[t, i, m] * <support[t, i], query[t, q]>
    compatibility = computeGramMatrix(support, query).float()  # (tasks_per_batch, n_support, n_query)
    logits = (qp_sol.float().unsqueeze(2) * compatibility.unsqueeze(3)).sum(1)

    return logits
def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False):
    """
    Fits the support set with ridge regression and returns the
    classification score on the query set.

    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot
        classification task.
      n_shot: a scalar. Represents the number of support examples given
        per class.
      lambda_reg: a scalar. Represents the strength of L2 regularization
        (added to the diagonal of the support Gram matrix).
      double_precision: if True the QP is solved in float64, else float32.
    Returns: a (tasks_per_batch, n_query, n_way) float32 Tensor.
    """
    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)
    # Build every auxiliary tensor on the Gram matrix's own device instead of
    # a hard-coded .cuda(), so the head also runs on CPU or a non-default GPU.
    work_device = kernel_matrix.device
    # The (n_support, n_support) identity broadcasts over the task dimension.
    kernel_matrix += lambda_reg * torch.eye(n_support, device=work_device)

    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support)  # (n_way * tasks_per_batch, n_support)

    qp_dtype = torch.double if double_precision else torch.float
    G = block_kernel_matrix.to(device=work_device, dtype=qp_dtype)
    e = (-2.0 * support_labels_one_hot).to(device=work_device, dtype=qp_dtype)

    # This is a fake inequality constraint (0 * z <= 0) as qpth does not
    # support QP without an inequality constraint.
    C = torch.zeros(tasks_per_batch * n_way, n_support, n_support, device=work_device, dtype=qp_dtype)
    h = torch.zeros(tasks_per_batch * n_way, n_support, device=work_device, dtype=qp_dtype)
    # We want to ignore the equality constraint, hence the empty tensors.
    dummy = torch.empty(0, device=work_device)

    # Solve the following QP to fit the ridge-regression dual:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables;
    # G stays attached so gradients reach the features.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C, h, dummy, dummy)  # (n_way * tasks_per_batch, n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support).permute(1, 2, 0)  # (tasks_per_batch, n_support, n_way)

    # Compute the classification score:
    # logits[t, q, m] = sum_i qp_sol[t, i, m] * <support[t, i], query[t, q]>
    compatibility = computeGramMatrix(support, query).float()  # (tasks_per_batch, n_support, n_query)
    logits = (qp_sol.float().unsqueeze(2) * compatibility.unsqueeze(3)).sum(1)

    return logits
def inner_loop_adapt(self, support, support_labels, query, n_way, n_shot, n_query):
    """
    Embeds support and query images with self._model, fits the support
    features with ridge regression and returns the classification score on
    the query set together with support-set diagnostics.

    The regularization strength and QP precision are read from
    self._lambda_reg and self._double_precision; all logits are multiplied
    by self._scale.

    Parameters:
      support:  a (n_tasks_per_batch, n_support, c, h, w) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      query:  a (n_tasks_per_batch, n_query, c, h, w) Tensor.
      n_way: a scalar. Number of classes per task.
      n_shot: a scalar. Number of support examples per class.
      n_query: a scalar. Number of query examples per class.
    Returns: a tuple (logits, measurements_trajectory) where logits is a
      (tasks_per_batch, n_way * n_query, n_way) Tensor and
      measurements_trajectory is a dict whose 'loss' and 'accu' lists each
      hold one entry measured on the support set.
    """
    measurements_trajectory = defaultdict(list)

    assert (query.dim() == 5)
    assert (support.dim() == 5)

    # get features: fold (task, sample) into one batch axis for the model,
    # then restore it with the feature dimensions flattened
    orig_query_shape = query.shape
    orig_support_shape = support.shape
    support = self._model(
        support.reshape(-1, *orig_support_shape[2:]),
        only_features=True).reshape(*orig_support_shape[:2], -1)
    query = self._model(
        query.reshape(-1, *orig_query_shape[2:]),
        only_features=True).reshape(*orig_query_shape[:2], -1)

    lambda_reg = self._lambda_reg
    double_precision = self._double_precision

    tasks_per_batch = query.size(0)
    total_n_support = support.size(1)  # support samples across all classes in a task
    total_n_query = query.size(1)  # query samples across all classes in a task

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
    assert (total_n_support == n_way * n_shot)  # total_n_support must equal n_way * n_shot
    assert (total_n_query == n_way * n_query)  # total_n_query must equal n_way * n_query

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # \alpha is an (total_n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)
    # Build every auxiliary tensor on the Gram matrix's own device instead of
    # a hard-coded .cuda(), so adaptation also runs on CPU or a non-default GPU.
    work_device = kernel_matrix.device
    # The identity broadcasts over the task dimension.
    kernel_matrix += lambda_reg * torch.eye(total_n_support, device=work_device)

    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)  # (n_way * tasks_per_batch, total_n_support, total_n_support)

    support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * total_n_support), n_way)  # (tasks_per_batch * total_n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(0, 1)  # (n_way, tasks_per_batch * total_n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, total_n_support)  # (n_way * tasks_per_batch, total_n_support)

    qp_dtype = torch.double if double_precision else torch.float
    G = block_kernel_matrix.to(device=work_device, dtype=qp_dtype)
    e = (-2.0 * support_labels_one_hot).to(device=work_device, dtype=qp_dtype)

    # This is a fake inequality constraint (0 * z <= 0) as qpth does not
    # support QP without an inequality constraint.
    C = torch.zeros(tasks_per_batch * n_way, total_n_support, total_n_support, device=work_device, dtype=qp_dtype)
    h = torch.zeros(tasks_per_batch * n_way, total_n_support, device=work_device, dtype=qp_dtype)
    # We want to ignore the equality constraint, hence the empty tensors.
    dummy = torch.empty(0, device=work_device)

    # Solve the following QP to fit the ridge-regression dual:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables;
    # G stays attached so gradients reach the feature extractor.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C, h, dummy, dummy)  # (n_way * tasks_per_batch, total_n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, total_n_support).permute(1, 2, 0)  # (tasks_per_batch, total_n_support, n_way)

    # Compute the classification score:
    # logits[t, q, m] = sum_i qp_sol[t, i, m] * <support[t, i], query[t, q]>
    compatibility = computeGramMatrix(support, query).float()  # (tasks_per_batch, total_n_support, total_n_query)
    logits = (qp_sol.float().unsqueeze(2) * compatibility.unsqueeze(3)).sum(1) * self._scale

    # compute loss and acc on support (diagnostics only, so no gradients)
    with torch.no_grad():
        # Recomputed because kernel_matrix above now carries the lambda_reg diagonal.
        support_compatibility = computeGramMatrix(support, support).float()
        logits_support = (qp_sol.float().unsqueeze(2) * support_compatibility.unsqueeze(3)).sum(1)
        logits_support = logits_support.reshape(-1, logits_support.size(-1)) * self._scale
        flat_support_labels = support_labels.reshape(-1)
        loss = self._inner_loss_func(logits_support, flat_support_labels)
        accu = accuracy(logits_support, flat_support_labels) * 100.

    measurements_trajectory['loss'].append(loss.item())
    measurements_trajectory['accu'].append(accu)

    return logits, measurements_trajectory