def forward(self, x):
    """Classify ``x`` with an FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax stack.

    Args:
        x: Batched input; dimension 0 is the batch and the remaining
            dimensions are flattened before ``fc1``.

    Returns:
        ``log_softmax`` over dimension 1 of the QP layer's output.
    """
    nBatch = x.size(0)

    x = x.view(nBatch, -1)
    x = F.relu(self.fc1(x))
    if self.bn:
        x = self.bn1(x)
    x = F.relu(self.fc2(x))
    if self.bn:
        x = self.bn2(x)

    # Q = (M * L)(M * L)^T + eps * I.  The identity is created on L's
    # device/dtype instead of being forced onto CUDA.
    L = self.M * self.L
    eye = torch.eye(self.nCls, dtype=L.dtype, device=L.device)
    Q = L.mm(L.t()) + self.eps * eye
    Q = Q.unsqueeze(0).expand(nBatch, self.nCls, self.nCls)
    G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nCls)

    # The inequality right-hand side is predicted per sample: h = G z0 + s0.
    z0 = self.qp_z0(x)
    s0 = self.qp_s0(x)
    h = z0.mm(self.G.t()) + s0

    # Empty tensor passed twice in the equality-constraint slots.
    e = Variable(torch.Tensor())
    # NOTE(review): the linear term is passed *before* Q here, exactly as in
    # the original call; confirm this matches the argument order of the
    # installed qpth version.
    x = QPFunction()(x.double(), Q.double(), G.double(), h.double(), e, e)
    x = x.float()
    # Explicit dim: the implicit-dimension form of log_softmax is deprecated.
    return F.log_softmax(x, dim=1)
def forward(self, inputs):
    """Compute one weighted centre per batch item by solving a QP.

    A Gram matrix of ``inputs`` (plus ``self._eps`` on the diagonal) defines
    the quadratic term; the QP solution ``alphas`` weights the rows of
    ``inputs`` and the weighted rows are summed over ``self._dim``.

    Args:
        inputs: 3-D tensor suitable for ``torch.bmm``; dimension 1 is the
            number of shots.

    Returns:
        ``sum(alphas * inputs, dim=self._dim)``.
    """
    shot = inputs.shape[1]
    # Build every constant on the inputs' device/dtype so the layer also
    # works for GPU or double-precision inputs (previously CPU/float only).
    opts = {"dtype": inputs.dtype, "device": inputs.device}

    kernel_matrices = torch.bmm(inputs, inputs.transpose(1, 2))
    kernel_matrices += self._eps * torch.eye(shot, **opts)
    kernel_diags = torch.diagonal(kernel_matrices, dim1=-2, dim2=-1)

    Q = 2 * kernel_matrices
    p = -kernel_diags
    # Equality constraint: ones(1, shot) @ alphas == 1.
    A = torch.ones(1, shot, **opts)
    b = torch.ones(1, **opts)
    # Inequality constraint: -alphas <= 0.
    G = -torch.eye(shot, **opts)
    h = torch.zeros(shot, **opts)

    # G, h, A, b are fresh constants, so they carry no gradient.
    alphas = QPFunction(verbose=False)(Q, p, G, h, A, b)
    alphas = alphas.unsqueeze(-1)
    centers = torch.sum(alphas * inputs, dim=self._dim)
    # TODO: `keepdim=True` here to avoid unsqueezing in `CentersDistance`,
    # which could be used for vanilla protonet?
    return centers
def projF(x):
    """Solve the batched QP for ``-x`` and return the log of its solution.

    ``self`` and ``nCls`` are taken from the enclosing scope.  The stored
    ``Q``, ``G``, ``h``, ``A`` and ``b`` are broadcast over the batch.
    """
    batch = x.size(0)

    def tile(tensor, *shape):
        # Add a leading batch axis and broadcast it without copying.
        return tensor.unsqueeze(0).expand(batch, *shape)

    Q = tile(self.Q, nCls, nCls)
    G = tile(self.G, nCls, nCls)
    h = tile(self.h, nCls)
    A = tile(self.A, 1, nCls)
    b = tile(self.b, 1)

    # NOTE(review): the linear term comes first in this call; confirm this
    # matches the installed qpth version's argument order.
    solution = QPFunction()(-x.double(), Q, G, h, A, b)
    return solution.float().log()
def project_action(self, u, x):
    """Project the action ``u`` by solving a QP built from the state ``x``.

    The QP uses an identity quadratic term, ``-u`` as linear term, and
    inequality data ``(G, -h)`` assembled from ``self.P``, ``self.A`` and
    ``self.B``.  The solution is cast to ``TORCH_DTYPE``.
    """
    Px = x @ self.P
    # One broadcast view of Px, shared by the products with B and with A.
    Px_tiled = Px.expand(self.B.shape[0], Px.shape[0], Px.shape[1])

    G = 2 * Px_tiled.bmm(self.B).transpose(0, 1)

    quadratic_part = (Px * x).sum(-1).unsqueeze(1)
    linear_part = 2 * Px_tiled.bmm(self.A).transpose(0, 1).matmul(
        x.unsqueeze(2)).squeeze(2)
    h = quadratic_part + linear_part

    n_batch, n_act = u.shape[0], u.shape[-1]
    Q = torch.eye(n_act, device=x.device).unsqueeze(0).expand(
        n_batch, n_act, n_act)

    solver = QPFunction(verbose=-1)
    res = solver(Q.double(), -u.double(), G.double(), -h.double(),
                 self.e, self.e)
    return res.type(TORCH_DTYPE)
def forward(self, puzzles):
    """Map ``puzzles`` through ``fc_in``, a QP layer and ``fc_out``.

    The output is reshaped to the shape of ``puzzles``.
    """
    batch = puzzles.size(0)
    linear_term = self.fc_in(puzzles.view(batch, -1))

    # Empty tensor passed in both equality-constraint slots.
    empty = Variable(torch.Tensor())
    # Inequality right-hand side: h = G z + s.
    h = self.G.mv(self.z) + self.s

    solver = QPFunction(verbose=False)
    solution = solver(self.Q, linear_term, self.G, h, empty, empty)

    return self.fc_out(solution).view_as(puzzles)
def svm_logits(query, support, labels, ways, shots, C_reg=0.1, max_iters=15):
    """Fit a QP on ``support`` and score ``query`` against it.

    Args:
        query: 2-D tensor of query embeddings (rows are examples).
        support: 2-D tensor of support embeddings (rows are examples).
        labels: Support labels, passed to ``onehot``.
        ways: Number of classes.
        shots: Unused by this function.
        C_reg: Scales the one-hot labels to form the inequality bound.
        max_iters: ``maxIter`` forwarded to ``QPFunction``.

    Returns:
        Tensor of shape ``(num_query, ways)``.
    """
    device = support.device
    n_support = support.size(0)
    n_query = query.size(0)

    # Quadratic term: kron(support Gram matrix, I_ways) + I.
    gram = support @ support.t()
    quad = kronecker(gram, torch.eye(ways).to(device))
    quad.add_(torch.eye(ways * n_support, device=device))

    flat_onehot = onehot(labels, dim=ways).view(1, -1).to(device)

    # Inequality: I z <= C_reg * onehot.
    ineq_lhs = torch.eye(n_support * ways, device=device)
    ineq_rhs = C_reg * flat_onehot
    # Equality: each support example's `ways` entries sum to zero.
    eq_lhs = kronecker(torch.eye(n_support, device=device),
                       torch.ones(1, ways, device=device))
    eq_rhs = torch.zeros(1, n_support, device=device)

    solver = QPFunction(verbose=False, maxIter=max_iters)
    sol = solver(quad, -flat_onehot, ineq_lhs, ineq_rhs, eq_lhs, eq_rhs)

    target_shape = (n_support, n_query, ways)
    sol = sol.reshape(n_support, ways).unsqueeze(1).expand(*target_shape)
    compat = (support @ query.t()).unsqueeze(2).expand(*target_shape)

    # Weight support/query similarities by the dual solution, sum over support.
    return torch.sum(sol * compat, dim=0)
def fit_(self, support, labels, ways=None, C_reg=None, max_iters=None):
    """Solve the QP on ``support`` and cache the solution on ``self``.

    Args:
        support: 2-D tensor of support embeddings (rows are examples).
        labels: Support labels, passed to ``onehot``.
        ways: Number of classes; defaults to the number of unique labels.
        C_reg: Inequality-bound scale; defaults to ``self.C_reg``.
        max_iters: QP iteration cap; defaults to ``self.max_iters``.

    Side effects:
        Sets ``self.qp_solution``, ``self.support``, ``self.num_support``
        and ``self.ways``.
    """
    C_reg = self.C_reg if C_reg is None else C_reg
    max_iters = self.max_iters if max_iters is None else max_iters
    if self._normalize:
        support = self.normalize(support)
    if ways is None:
        ways = len(torch.unique(labels))

    device = support.device
    n_support = support.size(0)

    # Quadratic term: kron(support Gram matrix, I_ways) + I.
    gram = support @ support.t()
    quad = kronecker(gram, torch.eye(ways).to(device))
    quad.add_(torch.eye(ways * n_support, device=device))

    flat_onehot = onehot(labels, dim=ways).view(1, -1).to(device)

    # Inequality: I z <= C_reg * onehot.
    ineq_lhs = torch.eye(n_support * ways, device=device)
    ineq_rhs = C_reg * flat_onehot
    # Equality: each support example's `ways` entries sum to zero.
    eq_lhs = kronecker(torch.eye(n_support, device=device),
                       torch.ones(1, ways, device=device))
    eq_rhs = torch.zeros(1, n_support, device=device)

    solver = QPFunction(verbose=False, maxIter=max_iters)
    sol = solver(quad, -flat_onehot, ineq_lhs, ineq_rhs, eq_lhs, eq_rhs)

    self.qp_solution = sol.reshape(n_support, ways)
    self.support = support
    self.num_support = n_support
    self.ways = ways
def forward(self, x):
    """Run ``fc1`` on ``x``, then solve a QP built only from module state.

    NOTE(review): the ``fc1`` activation is computed but never fed into the
    QP; the returned value depends only on ``self.M``, ``self.L``,
    ``self.p``, ``self.G``, ``self.z0`` and ``self.s0``.  Confirm this is
    intended.

    Returns:
        ``log_softmax`` over dimension 1 of the QP solution.
    """
    batch = x.size(0)
    # FC-ReLU; the batch-norm and second FC stages are disabled here.
    flat = x.view(batch, -1).unsqueeze(0).float()
    hidden = F.relu(self.fc1(flat))
    hidden = hidden.squeeze(2)

    # Q = (M * L)(M * L)^T + eps * I.
    masked_L = self.M * self.L
    Q = masked_L.mm(masked_L.t()) \
        + self.eps * Variable(torch.eye(self.nCls)).cuda()

    p = self.p.double()
    h = (self.G.mv(self.z0) + self.s0).double()
    G = self.G.double()
    Q = Q.double()
    print(Q.size(), p.size(), G.size(), h.size())

    empty = Variable(torch.Tensor())
    solution = QPFunction(verbose=True)(Q, p, G, h, empty, empty).cuda()
    print(solution)
    return F.log_softmax(solution, dim=1)
def forward(self, puzzles):
    """Solve the QP for ``puzzles`` with the configured backend.

    Args:
        puzzles: Batched tensor; it is flattened per sample and negated to
            form the QP's linear term.

    Returns:
        The QP solution as float, reshaped like ``puzzles``.

    Raises:
        AssertionError: If ``self.qp_solver`` is neither ``'qpth'`` nor
            ``'osqpth'``.
    """
    nBatch = puzzles.size(0)
    p = -puzzles.view(nBatch, -1)
    # Equality right-hand side: b = A exp(log_z0).
    b = self.A.mv(self.log_z0.exp())

    if self.qp_solver == 'qpth':
        y = QPFunction(verbose=-1)(
            self.Q, p.double(), self.G, self.h, self.A, b
        ).float().view_as(puzzles)
    elif self.qp_solver == 'osqpth':
        # Stack equalities (l = u = b) on top of inequalities (-inf <= Gz <= h).
        _l = torch.cat(
            (b, torch.full(self.h.shape, float('-inf'),
                           device=self.h.device, dtype=self.h.dtype)),
            dim=0)
        _u = torch.cat((b, self.h), dim=0)
        # Pass only the entries selected by the stored index arrays.
        Q_data = self.Q[self.Q_idx[0], self.Q_idx[1]]
        AG = torch.cat((self.A, self.G), dim=0)
        AG_data = AG[self.AG_idx[0], self.AG_idx[1]]
        y = OSQP(self.Q_idx, self.Q.shape, self.AG_idx, AG.shape,
                 diff_mode=DiffModes.FULL)(
            Q_data, p.double(), AG_data, _l, _u
        ).float().view_as(puzzles)
    else:
        # Explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would fall through to an unbound `y`.
        raise AssertionError(
            'unsupported qp_solver: {!r}'.format(self.qp_solver))
    return y
def forward(self, puzzles):
    """Solve the stored QP with ``-puzzles`` (flattened) as linear term.

    Returns:
        The solution as float, reshaped like ``puzzles``.
    """
    batch = puzzles.size(0)
    linear_term = -puzzles.view(batch, -1)

    solver = QPFunction(verbose=-1)
    solution = solver(self.Q, linear_term.double(),
                      self.G, self.h, self.A, self.b)
    return solution.float().view_as(puzzles)
def forward(self, z0, mu, dg, d2g):
    """Solve a batched QP with a diagonal quadratic term ``diag(d2g + 1)``.

    Args:
        z0: 2-D tensor ``(nBatch, n)``.
        mu, dg, d2g: Tensors combined element-wise with ``z0`` to form the
            linear term ``dg - d2g * z0 - mu``.

    Returns:
        The raw ``QPFunction`` output.
    """
    batch, _ = z0.size()

    # One diagonal matrix per batch row, stacked along a new leading axis.
    diag_blocks = [torch.diag(row + 1) for row in d2g[:batch]]
    Q = torch.stack(diag_blocks, 0).double()
    p = (dg - d2g * z0 - mu).double()

    n_ineq, n_var = self.G.size(0), self.G.size(1)
    G = self.G.unsqueeze(0).expand(batch, n_ineq, n_var)
    h = self.h.unsqueeze(0).expand(batch, self.h.size(0))

    # NOTE(review): the linear term is passed before Q, as in the original
    # call; confirm against the installed qpth version.
    return QPFunction(verbose=False)(p, Q, G, h, self.e, self.e)
def forward(self, puzzles):
    """Solve the stored QP per sample with ``-puzzles`` as linear term.

    All stored QP data is broadcast over the batch.  The solution is
    returned as float, reshaped like ``puzzles``.
    """
    batch = puzzles.size(0)

    def tile(tensor):
        # Prepend a batch axis and broadcast it without copying.
        return tensor.unsqueeze(0).expand(batch, *tensor.size())

    Q, G, h = tile(self.Q), tile(self.G), tile(self.h)
    A, b = tile(self.A), tile(self.b)
    linear_term = -puzzles.view(batch, -1)

    # NOTE(review): the linear term is passed before Q, as in the original
    # call; confirm against the installed qpth version.
    solution = QPFunction(verbose=False)(linear_term.double(), Q, G, h, A, b)
    return solution.float().view_as(puzzles)
def forward(self, puzzles):
    """Solve a QP whose inequalities stack two constraint groups.

    The first group ``(G1, h1)`` is used as stored; the second group's
    right-hand side is computed as ``G2 z2 + s2``.  The solution is returned
    as float, reshaped like ``puzzles``.
    """
    batch = puzzles.size(0)
    linear_term = -puzzles.view(batch, -1)

    second_rhs = self.G2.mv(self.z2) + self.s2
    ineq_lhs = torch.cat((self.G1, self.G2), 0)
    ineq_rhs = torch.cat((self.h1, second_rhs), 0)

    # Empty tensor passed in both equality-constraint slots.
    empty = Variable(torch.Tensor())
    solver = QPFunction(verbose=False)
    solution = solver(self.Q, linear_term.double(),
                      ineq_lhs, ineq_rhs, empty, empty)
    return solution.float().view_as(puzzles)
def forward(self, x, Q, p, G, h, m):
    """Solve the QP ``(Q, p, G, h)`` and return log-softmax of the result.

    Args:
        x: Only its first dimension is read; the QP does not use it.
        Q, p, G, h: QP data; all are cast to float.
        m: When ``m > 1`` the linear term is transposed; when ``m >= 2`` a
            leading axis is added to ``Q``.

    Returns:
        ``log_softmax`` over dimension 1 of the solution viewed as
        ``(10, -1)``.
    """
    print("Cuda current device", torch.cuda.current_device())
    batch = x.size(0)  # read but not used below

    linear_term = p.float().t() if (m > 1) else p.float()
    ineq_lhs = G.float()
    quad = Q.float()
    if (m >= 2):
        quad = quad.unsqueeze(0)
    ineq_rhs = h.float()

    # Empty tensor passed in both equality-constraint slots.
    empty = Variable(torch.Tensor(), requires_grad=True)
    solver = QPFunction(verbose=True)
    solution = solver(quad, linear_term, ineq_lhs, ineq_rhs, empty, empty)
    # The reshape to 10 rows is hard-coded.
    solution = solution.view(10, -1)
    return F.log_softmax(solution, dim=1)
def do_lipschitz_projection(self):
    """Perform the Lipschitz projection step by solving the QP.

    Depending on ``self.QP`` the projection is computed with qpth
    (``"qpth"``) or with the pre-built ``self.qp`` layer (``"cvxpy"``).
    The flattened result overwrites ``self.coefficients_vect_.data``.
    Any other ``self.QP`` value leaves the coefficients untouched.
    """
    with torch.no_grad():
        if self.QP == "qpth":
            # qpth library
            proj_coefficients = QPFunction(verbose=False)(
                nn.Parameter(self.Q),
                -2.0 * self.coefficients,
                nn.Parameter(self.G),
                nn.Parameter(self.h),
                nn.Parameter(self.e),
                nn.Parameter(self.e))
            self.coefficients_vect_.data = proj_coefficients.view(-1)

        elif self.QP == "cvxpy":
            # cvxpylayers library
            #
            # Row-wise verification, kept for reference (was a stray
            # string-literal statement):
            #   proj_coefficients = torch.empty(self.coefficients.data.shape)
            #   for i in range(self.coefficients.data.shape[0]):
            #       proj_coefficient, = self.qp(-2.0*self.coefficients.data[i, :])
            #       proj_coefficients[i, :] = proj_coefficient
            #   self.coefficients_vect_.data = proj_coefficients.view(-1)
            proj_coefficients, = self.qp(-2.0 * self.coefficients.data)
            self.coefficients_vect_.data = proj_coefficients.view(-1)
def forward(self, y):
    """Solve a batched QP whose Q and p are rescaled by ``y``.

    For each row of ``y`` a diagonal scale ``diag([one, y_i, y_i])``
    multiplies the stored ``Q`` element-wise; the stored ``p`` is scaled by
    ``[1, y_i, y_i]``.

    Args:
        y: 2-D tensor ``(nBatch, k)``.

    Returns:
        The first column of the solution, shape ``(nBatch, 1)``.
    """
    batch, _ = y.size()

    scale_blocks = [torch.diag(torch.cat([self.one, row, row])) for row in y]
    Q_scale = torch.stack(scale_blocks, 0)
    Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)

    unit_col = Variable(torch.ones(batch, 1).cuda())
    p_scale = torch.cat([unit_col, y, y], 1)
    p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)

    G = self.G.unsqueeze(0).expand(batch, self.G.size(0), self.G.size(1))
    h = self.h.unsqueeze(0).expand(batch, self.h.size(0))
    # Empty tensor passed in both equality-constraint slots.
    empty = Variable(torch.Tensor().cuda()).double()

    # NOTE(review): the linear term is passed before Q, as in the original
    # call; confirm against the installed qpth version.
    solver = QPFunction(verbose=False)
    out = solver(p.double(), Q.double(), G.double(), h.double(),
                 empty, empty).float()
    return out[:, :1]
def forward(self, x):
    """Apply ``fc1`` then a QP layer; keep the first ``nFeatures`` columns."""
    batch = x.size(0)
    linear_term = self.fc1(x)

    # Q = (M * L)(M * L)^T + eps * I, broadcast over the batch.
    masked_L = self.M * self.L
    Q = masked_L.mm(masked_L.t()) \
        + self.args.eps * Variable(torch.eye(self.nHidden)).cuda()
    Q = Q.unsqueeze(0).expand(batch, self.nHidden, self.nHidden)

    G = self.G.unsqueeze(0).expand(batch, self.nineq, self.nHidden)
    # Inequality right-hand side: h = G z0 + s0, shared by all samples.
    h = (self.G.mv(self.z0) + self.s0).unsqueeze(0).expand(batch, self.nineq)

    # Empty tensor passed in both equality-constraint slots.
    empty = Variable(torch.Tensor())
    solution = QPFunction()(Q, linear_term, G, h, empty, empty)
    return solution[:, :self.nFeatures]
def forward(self, x):
    """Classify ``x`` with an FC-ReLU-QP-FC-Softmax stack.

    Args:
        x: Batched input; dimension 0 is the batch and the rest is
            flattened before ``fc1``.

    Returns:
        ``log_softmax`` over dimension 1 of ``fc2`` applied to the QP
        solution.
    """
    nBatch = x.size(0)

    x = x.view(nBatch, -1)
    x = F.relu(self.fc1(x))

    # Broadcast the stored QP data over the batch.
    Q = self.Q.unsqueeze(0).expand(nBatch, self.Q.size(0), self.Q.size(1))
    p = -x.view(nBatch, -1)
    G = self.G.unsqueeze(0).expand(nBatch, self.G.size(0), self.G.size(1))
    h = self.h.unsqueeze(0).expand(nBatch, self.h.size(0))
    A = self.A.unsqueeze(0).expand(nBatch, self.A.size(0), self.A.size(1))
    b = self.b.unsqueeze(0).expand(nBatch, self.b.size(0))

    # NOTE(review): the linear term is passed before Q, as in the original
    # call; confirm against the installed qpth version.
    x = QPFunction(verbose=False)(p.double(), Q, G, h, A, b).float()
    x = self.fc2(x)

    # Explicit dim: the implicit-dimension form of log_softmax is deprecated.
    return F.log_softmax(x, dim=1)
def forward(self, x):
    """Classify ``x`` with an FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax stack.

    Args:
        x: Batched input; dimension 0 is the batch and the rest is
            flattened before ``fc1``.

    Returns:
        ``log_softmax`` over dimension 1 of the QP solution.
    """
    nBatch = x.size(0)

    x = x.view(nBatch, -1)
    x = F.relu(self.fc1(x))
    if self.bn:
        x = self.bn1(x)
    x = F.relu(self.fc2(x))
    if self.bn:
        x = self.bn2(x)

    # Q = (M * L)(M * L)^T + eps * I, built on L's device/dtype rather than
    # being forced onto CUDA.
    L = self.M * self.L
    eye = torch.eye(self.nCls, dtype=L.dtype, device=L.device)
    Q = L.mm(L.t()) + self.eps * eye
    # Inequality right-hand side: h = G z0 + s0.
    h = self.G.mv(self.z0) + self.s0
    # Empty tensor passed in both equality-constraint slots.
    e = Variable(torch.Tensor())

    # The inequality matrix is the module's `self.G`; the previous bare `G`
    # was never defined in this method.
    x = QPFunction(verbose=False)(Q, x, self.G, h, e, e)

    # Explicit dim: the implicit-dimension form of log_softmax is deprecated.
    return F.log_softmax(x, dim=1)
def forward(self, x):
    """Classify ``x`` with two conv/max-pool stages followed by a QP layer.

    Args:
        x: Batched input accepted by ``self.conv1``.

    Returns:
        ``log_softmax`` over dimension 1 of the first 10 columns of the QP
        solution.
    """
    nBatch = x.size(0)

    x = F.max_pool2d(self.conv1(x), 2)
    x = F.max_pool2d(self.conv2(x), 2)
    x = x.view(nBatch, -1)

    # Q = (M * L)(M * L)^T + eps * I, built on L's device/dtype rather than
    # being forced onto CUDA.
    L = self.M * self.L
    eye = torch.eye(self.nHidden, dtype=L.dtype, device=L.device)
    Q = L.mm(L.t()) + self.eps * eye
    Q = Q.unsqueeze(0).expand(nBatch, self.nHidden, self.nHidden)
    G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nHidden)

    # The inequality right-hand side is predicted per sample: h = G z0 + s0.
    z0 = self.qp_z0(x)
    s0 = self.qp_s0(x)
    h = z0.mm(self.G.t()) + s0

    # Empty tensor passed in both equality-constraint slots.
    e = Variable(torch.Tensor())
    inputs = self.qp_o(x)
    # NOTE(review): the linear term is passed before Q, as in the original
    # call; confirm against the installed qpth version.
    x = QPFunction()(inputs, Q, G, h, e, e)
    x = x[:, :10]

    # Explicit dim: the implicit-dimension form of log_softmax is deprecated.
    return F.log_softmax(x, dim=1)
def forward(self, x):
    """Solve a QP with inequalities built from ``self.D``.

    The inequality matrix stacks ``[D, -I]`` on top of ``[-D, -I]``.  The
    linear term is ``-x`` followed by ``nFeatures - 1`` columns of the
    constant 13.

    Returns:
        The first ``nFeatures`` columns of the solution, as float.
    """
    batch = x.size(0)
    n_aux = self.nFeatures - 1

    # Q = (M * L)(M * L)^T + eps * I, broadcast over the batch.
    masked_L = self.M * self.L
    Q = masked_L.mm(masked_L.t()) \
        + self.args.eps * Variable(torch.eye(self.nHidden)).cuda()
    Q = Q.unsqueeze(0).expand(batch, self.nHidden, self.nHidden)

    neg_eye = Variable(-torch.eye(n_aux).type_as(Q.data))
    upper = torch.cat((self.D, neg_eye), 1)
    lower = torch.cat((-self.D, neg_eye), 1)
    G = torch.cat((upper, lower))
    G = G.unsqueeze(0).expand(batch, self.nineq, self.nHidden)
    h = self.h.unsqueeze(0).expand(batch, self.nineq)

    # Empty tensor passed in both equality-constraint slots.
    empty = Variable(torch.Tensor())
    # A fresh constant block (value 13) is created on every call; the
    # learnable `self.lam` variant is not used.
    aux_cost = Parameter(13. * torch.ones(batch, n_aux).cuda())
    p = torch.cat((-x, aux_cost), 1)

    solution = QPFunction()(Q.double(), p.double(), G.double(), h.double(),
                            empty, empty).float()
    return solution[:, :self.nFeatures]
def forward(self, log_prices):
    """Solve the batched QP whose linear cost is built from prices.

    The linear term concatenates ``prices``, ``-prices`` and a constant
    block ``-(lam * B)`` of width ``self.T``, where
    ``prices = exp(log_prices)``.

    Returns:
        The raw ``QPFunction`` output (double precision inputs).
    """
    prices = torch.exp(log_prices)
    batch = prices.size(0)
    horizon = self.T

    def tile(tensor):
        # Prepend a batch axis and broadcast it without copying.
        return tensor.unsqueeze(0).expand(batch, *tensor.size())

    Q = tile(self.Q)
    penalty = -(self.lam * self.B * torch.ones(horizon, device=DEVICE))
    c = torch.cat(
        [prices, -prices, penalty.unsqueeze(0).expand(batch, horizon)], 1)

    # (A, b) fill the inequality slots, (Ae, be) the equality slots.
    A, b = tile(self.A), tile(self.b)
    Ae, be = tile(self.Ae), tile(self.be)

    solver = QPFunction(verbose=True)
    return solver(Q.double(), c.double(), A.double(), b.double(),
                  Ae.double(), be.double())
def forward(self, y):
    """Solve a batched QP whose Q and p are rescaled by ``y``.

    For each row of ``y`` a diagonal scale ``diag([one, y_i + 1e-8,
    y_i + 1e-8])`` multiplies the stored ``Q`` element-wise; the stored
    ``p`` is scaled by ``[1, y_i, y_i]``.

    Args:
        y: 2-D tensor ``(nBatch, k)``.

    Returns:
        The first column of the solution, shape ``(nBatch, 1)``.
    """
    batch, _ = y.size()
    jitter = 1e-8  # added to y on the Q diagonal only

    scale_blocks = [
        torch.diag(torch.cat([self.one, row + jitter, row + jitter]))
        for row in y
    ]
    Q_scale = torch.stack(scale_blocks, 0)
    Q = self.Q.unsqueeze(0).expand_as(Q_scale).mul(Q_scale)

    p_scale = torch.cat([torch.ones(batch, 1, device=DEVICE), y, y], 1)
    p = self.p.unsqueeze(0).expand_as(p_scale).mul(p_scale)

    G = self.G.unsqueeze(0).expand(batch, self.G.size(0), self.G.size(1))
    h = self.h.unsqueeze(0).expand(batch, self.h.size(0))

    # Empty tensor passed in both equality-constraint slots.
    empty = torch.DoubleTensor()
    if USE_GPU:
        empty = empty.cuda()

    solver = QPFunction(verbose=False)
    out = solver(Q.double(), p.double(), G.double(), h.double(),
                 empty, empty).float()
    return out[:, :1]
def MetaOptNetHead_SVM(query, support, support_labels, n_way, n_shot, C_reg=0.1, double_precision=False, maxIter=15):
    """Fit a multi-class SVM on the support set and score the query set.

    This is the multi-class SVM presented in "On the Algorithmic
    Implementation of Multiclass Kernel-based Vector Machines" (Crammer and
    Singer, JMLR 2001), solved in its dual form with qpth.

    Parameters:
      query: a (tasks_per_batch, n_query, d) Tensor.
      support: a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: number of classes in a few-shot classification task.
      n_shot: number of support examples given per class.
      C_reg: the cost parameter C in SVM.
      double_precision: solve the QP in float64 instead of float32.
      maxIter: iteration cap forwarded to QPFunction.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    n_tasks = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # one block of n_shot per class

    n_vars = n_way * n_support

    # Dual problem (classes indexed by m, samples by i):
    #   min_alpha 0.5 sum_m ||w_m(alpha)||^2 + sum_i sum_m e^m_i alpha^m_i
    #   s.t. alpha^m_i <= C^m_i for all m, i;  sum_m alpha^m_i = 0 for all i
    # with w_m(alpha) = sum_i alpha^m_i x_i and C^m_i = C if m == y_i else 0.
    # alpha is an (n_support, n_way) matrix, flattened for the solver.
    gram = computeGramMatrix(support, support)
    eye_way = torch.eye(n_way).expand(n_tasks, n_way, n_way).cuda()
    quad = batched_kronecker(gram, eye_way)
    quad += 1.0 * torch.eye(n_vars).expand(n_tasks, n_vars, n_vars).cuda()

    # One-hot labels: (n_tasks * n_support, n_way) -> (n_tasks, n_vars).
    labels_1hot = one_hot(support_labels.view(n_tasks * n_support), n_way)
    labels_1hot = labels_1hot.view(n_tasks, n_support, n_way)
    labels_1hot = labels_1hot.reshape(n_tasks, n_vars)

    G = quad
    e = -1.0 * labels_1hot

    # Inequality constraints: I alpha <= C_reg * one_hot.
    C = Variable(torch.eye(n_vars).expand(n_tasks, n_vars, n_vars))
    h = Variable(C_reg * labels_1hot)

    # Equality constraints: sum_m alpha^m_i = 0 for every support example i.
    eye_support = torch.eye(n_support).expand(
        n_tasks, n_support, n_support).cuda()
    A = Variable(batched_kronecker(
        eye_support, torch.ones(n_tasks, 1, n_way).cuda()))
    b = Variable(torch.zeros(n_tasks, n_support))

    if double_precision:
        G, e, C, h, A, b = [t.double().cuda() for t in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [t.float().cuda() for t in [G, e, C, h, A, b]]

    # Solve  z_hat = argmin_z 1/2 z^T G z + e^T z  s.t. Cz <= h, Az = b.
    # Everything except G is detached so no gradient flows into it.
    solver = QPFunction(verbose=False, maxIter=maxIter)
    qp_sol = solver(G, e.detach(), C.detach(), h.detach(),
                    A.detach(), b.detach())

    # Classification score: weight support/query similarities by the dual
    # solution and sum over the support axis.
    out_shape = (n_tasks, n_support, n_query, n_way)
    compatibility = computeGramMatrix(support, query).float()
    compatibility = compatibility.unsqueeze(3).expand(*out_shape)

    qp_sol = qp_sol.reshape(n_tasks, n_support, n_way)
    weights = qp_sol.float().unsqueeze(2).expand(*out_shape)

    return torch.sum(weights * compatibility, 1)
def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False):
    """Fit ridge regression on the support set and score the query set.

    Parameters:
      query: a (tasks_per_batch, n_query, d) Tensor.
      support: a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: number of classes in a few-shot classification task.
      n_shot: number of support examples given per class.
      lambda_reg: the strength of L2 regularization.
      double_precision: solve the QP in float64 instead of float32.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    n_tasks = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # one block of n_shot per class

    # Dual problem (classes indexed by m, samples by i):
    #   min_alpha 0.5 sum_m ||w_m(alpha)||^2 + sum_i sum_m e^m_i alpha^m_i
    # with w_m(alpha) = sum_i alpha^m_i x_i; alpha is (n_support, n_way).
    gram = computeGramMatrix(support, support)
    gram += lambda_reg * torch.eye(n_support).expand(
        n_tasks, n_support, n_support).cuda()

    # One independent QP per (class, task): (n_way * n_tasks, S, S).
    G = gram.repeat(n_way, 1, 1)

    # One-hot labels laid out class-major: (n_way * n_tasks, n_support).
    labels_1hot = one_hot(support_labels.view(n_tasks * n_support), n_way)
    labels_1hot = labels_1hot.transpose(0, 1)
    labels_1hot = labels_1hot.reshape(n_way * n_tasks, n_support)
    e = -2.0 * labels_1hot

    # Fake inequality constraint (0 z <= 0): qpth does not support a QP
    # without an inequality constraint.
    C = Variable(torch.zeros(n_tasks * n_way, n_support, n_support))
    h = Variable(torch.zeros((n_tasks * n_way, n_support)))
    # Empty tensor: the equality constraint is ignored.
    dummy = Variable(torch.Tensor()).cuda()

    if double_precision:
        G, e, C, h = [t.double().cuda() for t in [G, e, C, h]]
    else:
        G, e, C, h = [t.float().cuda() for t in [G, e, C, h]]

    # Solve  z_hat = argmin_z 1/2 z^T G z + e^T z  s.t. Cz <= h.
    # Everything except G is detached so no gradient flows into it.
    qp_sol = QPFunction(verbose=False)(
        G, e.detach(), C.detach(), h.detach(),
        dummy.detach(), dummy.detach())

    # (n_way * n_tasks, S) -> (n_way, n_tasks, S) -> (n_tasks, S, n_way).
    qp_sol = qp_sol.reshape(n_way, n_tasks, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)

    # Classification score: weight support/query similarities by the
    # solution and sum over the support axis.
    out_shape = (n_tasks, n_support, n_query, n_way)
    compatibility = computeGramMatrix(support, query).float()
    compatibility = compatibility.unsqueeze(3).expand(*out_shape)

    qp_sol = qp_sol.reshape(n_tasks, n_support, n_way)
    weights = qp_sol.float().unsqueeze(2).expand(*out_shape)

    return torch.sum(weights * compatibility, 1)
def prof_instance(nz, nBatch, nTrials, cuda=True):
    """Time forward/backward passes of ``nn.Linear`` versus ``QPFunction``.

    A random batched QP with ``nz`` variables and ``nz`` inequality
    constraints (no equalities) is generated.  Each operation runs
    ``nTrials + 1`` times and the first measurement is dropped.

    Args:
        nz: Number of QP variables (also the linear layer's width).
        nBatch: Batch size.
        nTrials: Number of timed trials kept per operation.
        cuda: Run on the GPU when true, otherwise on the CPU.

    Returns:
        Four ``np.ndarray`` of length ``nTrials``: linear forward times,
        qpth forward times, linear backward times, qpth backward times
        (seconds).
    """
    nineq, neq = nz, 0
    assert (neq == 0)

    L = npr.rand(nBatch, nz, nz)
    Q = np.matmul(L, L.transpose((0, 2, 1))) + 1e-3 * np.eye(nz, nz)
    G = npr.randn(nBatch, nineq, nz)
    z0 = npr.randn(nBatch, nz)
    s0 = npr.rand(nBatch, nineq)
    p = npr.randn(nBatch, nz)
    # h = G z0 + s0 with s0 drawn from [0, 1).
    h = np.matmul(G, np.expand_dims(z0, axis=(2))).squeeze(2) + s0
    A = npr.randn(nBatch, neq, nz)
    b = np.matmul(A, np.expand_dims(z0, axis=(2))).squeeze(2)

    lm = nn.Linear(nz, nz)

    p, L, Q, G, z0, s0, h = [torch.Tensor(x) for x in [p, L, Q, G, z0, s0, h]]
    if cuda:
        p, L, Q, G, z0, s0, h = [x.cuda() for x in [p, L, Q, G, z0, s0, h]]
        lm = lm.cuda()

    if neq > 0:
        A = torch.Tensor(A)
        b = torch.Tensor(b)
    else:
        # Empty tensors are passed in the equality-constraint slots.
        A, b = [torch.Tensor()] * 2
    if cuda:
        A = A.cuda()
        b = b.cuda()

    p, L, Q, G, z0, s0, h, A, b = [
        Variable(x) for x in [p, L, Q, G, z0, s0, h, A, b]
    ]
    p.requires_grad = True

    # Gradient seed for backward(); honours the `cuda` flag (it used to be
    # moved to the GPU unconditionally, which broke cuda=False).  Built once
    # so its allocation is not included in the backward timings.
    grad_seed = torch.ones(nBatch, nz)
    if cuda:
        grad_seed = grad_seed.cuda()

    linearf_times = []
    linearb_times = []
    for _ in range(nTrials + 1):
        start = time.time()
        zhat_l = lm(p)
        linearf_times.append(time.time() - start)

        start = time.time()
        zhat_l.backward(grad_seed)
        linearb_times.append(time.time() - start)
    # Drop the first (warm-up) measurement.
    linearf_times = linearf_times[1:]
    linearb_times = linearb_times[1:]

    qpthf_times = []
    qpthb_times = []
    for _ in range(nTrials + 1):
        start = time.time()
        qpf = QPFunction()
        zhat_b = qpf(Q, p, G, h, A, b)
        qpthf_times.append(time.time() - start)

        start = time.time()
        zhat_b.backward(grad_seed)
        qpthb_times.append(time.time() - start)
    # Drop the first (warm-up) measurement.
    qpthf_times = qpthf_times[1:]
    qpthb_times = qpthb_times[1:]

    return np.array(linearf_times), np.array(qpthf_times), \
        np.array(linearb_times), np.array(qpthb_times)
def MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot, C_reg=0.00001, double_precision=False):
    """Fit a multi-class SVM on the support set and score the query set.

    This is the multi-class SVM presented in "Support Vector Machines for
    Multi Class Pattern Recognition" (Weston and Watkins, ESANN 1999).

    Parameters:
      query: a (tasks_per_batch, n_query, d) Tensor.
      support: a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: number of classes in a few-shot classification task.
      n_shot: number of support examples given per class.
      C_reg: the cost parameter C in SVM.
      double_precision: solve the QP in float64.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot

    # In theory, alpha is an (n_support, n_way) matrix.
    # NOTE: this implementation solves for a flattened vector of size
    # (n_way * n_support).  To turn it into a matrix, first reshape it into
    # an (n_way, n_support) matrix, then transpose it to (n_support, n_way).
    kernel_matrix = computeGramMatrix(support, support) + torch.ones(
        tasks_per_batch, n_support, n_support).cuda()

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)

    # Mask of support pairs that share a label.
    kernel_matrix_mask_x = support_labels.reshape(
        tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support)
    kernel_matrix_mask_y = support_labels.reshape(
        tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support, n_support)
    kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()

    block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
    block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)

    # Mask comparing each support label against the class index of every
    # column block.
    kernel_matrix_mask_second_term = support_labels.reshape(
        tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support * n_way)
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(
        n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(
            1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()

    block_kernel_matrix -= (
        2.0 - 1e-4) * (kernel_matrix_mask_second_term *
                       kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)

    Y_support = one_hot(support_labels.view(tasks_per_batch * n_support),
                        n_way)
    Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
    Y_support = Y_support.transpose(1, 2)  # (tasks_per_batch, n_way, n_support)
    Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)

    G = block_kernel_matrix

    e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
    id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch,
                                                    n_way * n_support,
                                                    n_way * n_support)

    # Upper bound per variable: C_reg, or 0 where the one-hot label is 1.
    C_mat = C_reg * torch.ones(tasks_per_batch, n_way * n_support).cuda() - C_reg * Y_support

    # Box constraints: z <= C_mat and -z <= 0.
    C = Variable(torch.cat((id_matrix, -id_matrix), 1))
    zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda()
    h = Variable(torch.cat((C_mat, zer), 1))

    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # Only the empty equality placeholders are detached here; e, C and h are
    # passed as-is.
    qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(),
                                       dummy.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query) + torch.ones(
        tasks_per_batch, n_support, n_query).cuda()
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way,
                                                      n_support, n_query)
    qp_sol = qp_sol.float()
    qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support)
    A_i = torch.sum(qp_sol, 1)  # (tasks_per_batch, n_support)
    A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support)
    qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way,
                                                n_support, n_query)
    Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support)
    Y_support_reshaped = A_i * Y_support_reshaped
    Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(
        tasks_per_batch, n_way, n_support, n_query)
    logits = (Y_support_reshaped - qp_sol) * compatibility

    logits = torch.sum(logits, 2)

    return logits.transpose(1, 2)
def MetaOptNetHead_SVM_He(query, support, support_labels, n_way, n_shot, C_reg=0.01, double_precision=False):
    """Fit a multi-class SVM on the support set and score the query set.

    This is the multi-class SVM presented in "A simplified multi-class
    support vector machine with reduced dual optimization" (He et al.,
    Pattern Recognition Letter 2012).  Its dual variable has size n_support
    (as opposed to n_way * n_support in the Weston & Watkins or Crammer &
    Singer multi-class SVM).

    This head was the one initially used for the project and was dropped
    because it performs suboptimally on the meta-learning scenarios.

    Parameters:
      query: a (tasks_per_batch, n_query, d) Tensor.
      support: a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: number of classes in a few-shot classification task.
      n_shot: number of support examples given per class.
      C_reg: the cost parameter C in SVM.
      double_precision: solve the QP in float64.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    n_tasks = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # one block of n_shot per class

    gram = computeGramMatrix(support, support)

    # NOTE(review): support_labels is combined with an
    # (n_tasks, n_support, n_way) tensor here, which suggests a one-hot
    # encoding is expected at this point -- confirm against callers.
    ones_block = torch.ones(n_tasks, n_support, n_way).cuda()
    V = (support_labels * n_way - ones_block) / (n_way - 1)

    # Quadratic term: support Gram matrix times the (detached) Gram matrix
    # of the label codes, element-wise.
    G = gram * computeGramMatrix(V, V).detach()
    e = Variable(-1.0 * torch.ones(n_tasks, n_support))

    # Box constraints: z <= C_reg and -z <= 0.
    eye = torch.eye(n_support).expand(n_tasks, n_support, n_support)
    C = Variable(torch.cat((eye, -eye), 1))
    upper = C_reg * torch.ones(n_tasks, n_support)
    lower = torch.zeros(n_tasks, n_support)
    h = Variable(torch.cat((upper, lower), 1))

    # Empty tensor: the equality constraint is ignored.
    dummy = Variable(torch.Tensor()).cuda()

    if double_precision:
        G, e, C, h = [t.double().cuda() for t in [G, e, C, h]]
    else:
        G, e, C, h = [t.cuda() for t in [G, e, C, h]]

    # Solve  z_hat = argmin_z 1/2 z^T G z + e^T z  s.t. Cz <= h.
    # Everything except G is detached so no gradient flows into it.
    qp_sol = QPFunction(verbose=False)(
        G, e.detach(), C.detach(), h.detach(),
        dummy.detach(), dummy.detach())

    # Classification score: weight query/support similarities by the
    # solution, then sum within groups of the support axis.
    compatibility = computeGramMatrix(query, support).float()
    weights = qp_sol.float().unsqueeze(1).expand(n_tasks, n_query, n_support)

    logits = weights * compatibility
    logits = logits.view(n_tasks, n_query, n_shot, n_way)
    return torch.sum(logits, 2)
def forward(self, data):
    """Build a 6-variable QP from ``data`` and return its cost.

    Columns 0, 1 and 2 of ``data`` are read as ``vs``, ``vnexts`` and
    ``us``.  The QP mixes a prediction-error term (``A``, ``b``), a padded
    ``G``/``f`` term and a penalty on the last three variables, with all
    weights ``a1 = a2 = a3 = 1``.

    Args:
        data: 2-D tensor with at least three columns.

    Returns:
        ``0.5 z Qmod z^T + p z^T + a1 * beta**2`` for the QP solution ``z``.
    """
    n = data.shape[0]
    vs = torch.unsqueeze(data[:, 0], 1)
    vnexts = torch.unsqueeze(data[:, 1], 1)
    us = torch.unsqueeze(data[:, 2], 1)

    mu = self.mu

    beta = vnexts - vs - us

    G = torch.tensor([[1.0, -1, 1], [-1, 1, 1], [-1, -1, 0]])
    # G embedded in the top-left corner of a 6x6 zero matrix.
    Gpad = torch.tensor([[1.0, -1, 1, 0, 0, 0],
                         [-1, 1, 1, 0, 0, 0],
                         [-1, -1, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0]])

    fmats = torch.tensor([[1.0, 1, 0], [-1, -1, 0],
                          [0, 0, 1]]).unsqueeze(0).repeat(n, 1, 1)
    fvecs = torch.cat((vs, us, mu * torch.ones(vs.shape)), 1).unsqueeze(2)
    f = torch.bmm(fmats, fvecs)
    batch_zeros = torch.zeros(f.shape)
    fpad = torch.cat((f, batch_zeros), 1)

    # For prediction error
    A = torch.tensor([[1.0, -1, 0, 0, 0, 0],
                      [-1, 1, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0]]).repeat(n, 1, 1)
    beta_zeros = torch.zeros(beta.shape)
    b = torch.cat((-2 * beta, 2 * beta, beta_zeros, beta_zeros, beta_zeros,
                   beta_zeros), 1).unsqueeze(2)

    slack_penalty = torch.tensor([0.0, 0, 0, 1, 1,
                                  1]).repeat(n, 1).unsqueeze(2)

    a1 = 1
    a2 = 1
    a3 = 1

    Q = 2 * a1 * A + 2 * a2 * Gpad
    # (A leftover pdb.set_trace() breakpoint used to sit here; it halted
    # every forward pass and has been removed.)
    p = a1 * b + a2 * fpad + a3 * slack_penalty

    # Constrain lambda and slacks to be >= 0
    R = -torch.eye(6)
    h = torch.zeros((1, 6))

    # Constrain G lambda + f >= 0
    # NOTE(review): the original author asked whether the second negation
    # (on the identity block) belongs here.
    R = torch.cat((R, -torch.cat((G, -torch.eye(3)), 1)))
    h = torch.cat((h.transpose(0, 1), f.unsqueeze(1)))
    h = h.transpose(0, 1)

    # Symmetrise Q and add a small multiple of the identity.
    Qmod = 0.5 * (Q + Q.transpose(0, 1)) + 0.001 * torch.eye(6)

    z = QPFunction(check_Q_spd=False)(Qmod, p, R, h, torch.tensor([]),
                                      torch.tensor([]))

    # Currently unused; kept as in the original.
    lcp_slack = torch.matmul(Gpad, z.transpose(0, 1)).transpose(0, 1) + fpad

    cost = 0.5 * torch.matmul(z, torch.matmul(Qmod, z.transpose(0, 1))) \
        + torch.matmul(p, z.transpose(0, 1)) + a1 * beta**2
    return cost
def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, device, lambda_reg=100.0, double_precision=False):
    """Fit ridge regression on the support set and score the query set.

    Parameters:
      query: a (tasks_per_batch, n_query, d) Tensor.
      support: a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: number of classes in a few-shot classification task.
      n_shot: number of support examples given per class.
      device: device on which the QP tensors are placed; also forwarded to
        ``one_hot``.
      lambda_reg: the strength of L2 regularization.
      double_precision: solve the QP in float64 instead of float32.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot

    # Tensors are placed on `device` throughout; they used to be moved with
    # a hard-coded .cuda(), which ignored the `device` argument.
    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(
        tasks_per_batch, n_support, n_support).to(device)

    # One independent QP per (class, task):
    # (n_way * tasks_per_batch, n_support, n_support)
    block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1)

    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support), n_way,
        device)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(
        0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(
        n_way * tasks_per_batch,
        n_support)  # (n_way * tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # Fake inequality constraint (0 z <= 0): qpth does not support a QP
    # without an inequality constraint.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    # Empty tensor: the equality constraint is ignored.
    dummy = Variable(torch.Tensor()).to(device)

    if double_precision:
        G, e, C, h = [x.double().to(device) for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.float().to(device) for x in [G, e, C, h]]

    # Solve  z_hat = argmin_z 1/2 z^T G z + e^T z  s.t. Cz <= h.
    # Everything except G is detached so no gradient flows into it.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())

    # (n_way * tasks_per_batch, n_support)
    #   -> (n_way, tasks_per_batch, n_support)
    #   -> (tasks_per_batch, n_support, n_way)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits