Example #1
    def forward(self, x):
        nBatch = x.size(0)

        # FC-ReLU-(BN)-FC-ReLU-(BN)-QP-Softmax
        x = x.view(nBatch, -1)
        x = F.relu(self.fc1(x))
        if self.bn:
            x = self.bn1(x)
        x = F.relu(self.fc2(x))
        if self.bn:
            x = self.bn2(x)

        # Build a positive-definite quadratic term Q = L L^T + eps*I
        # from the lower-triangular masked factor L.
        L = self.M * self.L
        Q = L.mm(L.t()) + self.eps * Variable(torch.eye(self.nCls)).cuda()
        Q = Q.unsqueeze(0).expand(nBatch, self.nCls, self.nCls)
        G = self.G.unsqueeze(0).expand(nBatch, self.nineq, self.nCls)
        z0 = self.qp_z0(x)
        s0 = self.qp_s0(x)
        h = z0.mm(self.G.t()) + s0
        e = Variable(torch.Tensor())
        inputs = x
        x = QPFunction()(inputs.double(), Q.double(), G.double(), h.double(),
                         e, e)
        x = x.float()
        # x = x[:,:10].float()

        return F.log_softmax(x, dim=1)
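For context, here is a minimal sketch of the module state this forward pass assumes. The attribute names are inferred from the call sites above; the constructor signature and initializations are assumptions, not the original definition:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from qpth.qp import QPFunction

class OptNetClassifier(nn.Module):  # hypothetical name
    def __init__(self, nFeatures, nHidden, nCls, bn=False, nineq=200, eps=1e-4):
        super(OptNetClassifier, self).__init__()
        self.nCls, self.nineq, self.eps, self.bn = nCls, nineq, eps, bn

        self.fc1 = nn.Linear(nFeatures, nHidden)
        self.fc2 = nn.Linear(nHidden, nCls)
        if bn:
            self.bn1 = nn.BatchNorm1d(nHidden)
            self.bn2 = nn.BatchNorm1d(nCls)

        # M masks the learnable factor L to stay lower-triangular, so that
        # Q = (M*L)(M*L)^T + eps*I in forward() is positive definite.
        self.M = Variable(torch.tril(torch.ones(nCls, nCls)).cuda())
        self.L = Parameter(torch.tril(torch.rand(nCls, nCls)).cuda())
        self.G = Parameter(torch.Tensor(nineq, nCls).uniform_(-1, 1).cuda())
        # forward() builds the inequality offset as h = z0 G^T + s0,
        # with z0 and s0 predicted from the features.
        self.qp_z0 = nn.Linear(nCls, nCls)
        self.qp_s0 = nn.Linear(nCls, nineq)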
def MetaOptNetHead_SVM(query,
                       support,
                       support_labels,
                       n_way,
                       n_shot,
                       C_reg=0.1,
                       double_precision=False,
                       maxIter=15):
    """
    Fits the support set with multi-class SVM and
    returns the classification score on the query set.

    This is the multi-class SVM presented in:
    On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
    (Crammer and Singer, Journal of Machine Learning Research 2001).
    This model is the classification head that we use for the final version.
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)
    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot
            )  # n_support must equal n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
    # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i

    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
    # and C^m_i = C if m  = y_i,
    # C^m_i = 0 if m != y_i.
    # This borrows the notation of liblinear.

    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)
    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)

    # This seems to help avoid a PSD error from the QP solver.
    block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(
        tasks_per_batch, n_way * n_support, n_way * n_support).cuda()
    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support),
        n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.view(
        tasks_per_batch, n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.reshape(
        tasks_per_batch, n_support * n_way)

    G = block_kernel_matrix
    e = -1.0 * support_labels_one_hot
    # print (G.size())
    # This part is for the inequality constraints:
    # \alpha^m_i <= C^m_i \forall m,i
    # where C^m_i = C if m  = y_i,
    # C^m_i = 0 if m != y_i.
    id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch,
                                                      n_way * n_support,
                                                      n_way * n_support)
    C = Variable(id_matrix_1)
    h = Variable(C_reg * support_labels_one_hot)
    # print (C.size(), h.size())
    # This part is for the equality constraints:
    # \sum_m \alpha^m_i=0 \forall i
    id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support,
                                              n_support).cuda()
    A = Variable(
        batched_kronecker(id_matrix_2,
                          torch.ones(tasks_per_batch, 1, n_way).cuda()))
    b = Variable(torch.zeros(tasks_per_batch, n_support))

    if double_precision:
        G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
    else:
        G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]
    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h, Az = b
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(),
                                                        C.detach(), h.detach(),
                                                        A.detach(), b.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)
    return logits
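These heads rely on three helper functions that are not shown on this page. Below is a minimal sketch that is consistent with how they are called (signatures and shapes inferred from the call sites; the real implementations may differ):

import torch

def computeGramMatrix(A, B):
    # Batched linear kernel: (tasks, n, d) x (tasks, m, d) -> (tasks, n, m).
    assert A.dim() == 3 and B.dim() == 3 and A.size(0) == B.size(0)
    return torch.bmm(A, B.transpose(1, 2))

def one_hot(indices, depth, device=None):
    # (N,) index tensor -> (N, depth) float one-hot matrix.
    # (The variant used in Example #6 also takes a device argument.)
    if device is None:
        device = indices.device
    encoded = torch.zeros(indices.size(0), depth, device=device)
    return encoded.scatter_(1, indices.view(-1, 1).long(), 1.0)

def batched_kronecker(A, B):
    # Per-task Kronecker product:
    # (tasks, n, m) kron (tasks, p, q) -> (tasks, n*p, m*q).
    tasks, n, m = A.size()
    _, p, q = B.size()
    return (A.unsqueeze(2).unsqueeze(4) * B.unsqueeze(1).unsqueeze(3)).reshape(
        tasks, n * p, m * q)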
def MetaOptNetHead_Ridge(query,
                         support,
                         support_labels,
                         n_way,
                         n_shot,
                         lambda_reg=50.0,
                         double_precision=False):
    """
    Fits the support set with ridge regression and
    returns the classification score on the query set.
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      lambda_reg: a scalar. Represents the strength of L2 regularization.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot
            )  # n_support must equal n_way * n_shot

    # Here we solve the dual problem:
    # Note that the classes are indexed by m & samples are indexed by i.
    # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i

    # where w_m(\alpha) = \sum_i \alpha^m_i x_i,

    # \alpha is an (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(
        tasks_per_batch, n_support, n_support).cuda()
    block_kernel_matrix = kernel_matrix.repeat(
        n_way, 1, 1)  # (n_way * tasks_per_batch, n_support, n_support)
    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support),
        n_way)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(
        0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(
        n_way * tasks_per_batch,
        n_support)  # (n_way*tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    # This is a fake inequality constraint, as qpth does not support a QP without an inequality constraint.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]

    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit the ridge regression:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())
    # qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    # qp_sol (n_way*tasks_per_batch, n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
    # qp_sol (n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)
    # qp_sol (tasks_per_batch, n_support, n_way)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits
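A hypothetical smoke test for the ridge head, checking shapes only. It assumes the helper sketches above, qpth's QPFunction on the import path, and a CUDA device (the heads call .cuda() internally):

tasks_per_batch, n_way, n_shot, n_query_per_class, d = 2, 5, 1, 15, 64
support = torch.randn(tasks_per_batch, n_way * n_shot, d).cuda()
query = torch.randn(tasks_per_batch, n_way * n_query_per_class, d).cuda()
# Labels 0..n_way-1, repeated n_shot times per task; kept contiguous so the
# head's .view() calls work.
support_labels = torch.arange(n_way).repeat(n_shot).unsqueeze(0).repeat(
    tasks_per_batch, 1).cuda()

logits = MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot)
print(logits.shape)  # expected: torch.Size([2, 75, 5])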
Example #4
def MetaOptNetHead_SVM_WW(query,
                          support,
                          support_labels,
                          n_way,
                          n_shot,
                          C_reg=0.00001,
                          double_precision=False):
    """
    Fits the support set with multi-class SVM and 
    returns the classification score on the query set.
    
    This is the multi-class SVM presented in:
    Support Vector Machines for Multi Class Pattern Recognition
    (Weston and Watkins, ESANN 1999).
    
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    """
    Fits the support set with multi-class SVM and 
    returns the classification score on the query set.
    
    This is the multi-class SVM presented in:
    Support Vector Machines for Multi Class Pattern Recognition
    (Weston and Watkins, ESANN 1999).
    
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot
            )  # n_support must equal n_way * n_shot

    #In theory, \alpha is an (n_support, n_way) matrix
    #NOTE: In this implementation, we solve for a flattened vector of size (n_way*n_support)
    #In order to turn it into a matrix, you must first reshape it into an (n_way, n_support) matrix
    #then transpose it, resulting in (n_support, n_way) matrix
    kernel_matrix = computeGramMatrix(support, support) + torch.ones(
        tasks_per_batch, n_support, n_support).cuda()

    id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
    block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)

    kernel_matrix_mask_x = support_labels.reshape(
        tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support,
                                              n_support)
    kernel_matrix_mask_y = support_labels.reshape(
        tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support,
                                              n_support)
    kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()

    block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
    block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)

    kernel_matrix_mask_second_term = support_labels.reshape(
        tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support,
                                              n_support * n_way)
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(
        n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(
            1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
    kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()

    block_kernel_matrix -= (
        2.0 - 1e-4) * (kernel_matrix_mask_second_term *
                       kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)

    Y_support = one_hot(support_labels.view(tasks_per_batch * n_support),
                        n_way)
    Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
    Y_support = Y_support.transpose(1,
                                    2)  # (tasks_per_batch, n_way, n_support)
    Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)

    G = block_kernel_matrix

    e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
    id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch,
                                                    n_way * n_support,
                                                    n_way * n_support)

    C_mat = C_reg * torch.ones(tasks_per_batch,
                               n_way * n_support).cuda() - C_reg * Y_support

    C = Variable(torch.cat((id_matrix, -id_matrix), 1))
    #C = Variable(torch.cat((id_matrix_masked, -id_matrix_masked), 1))
    zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda()

    h = Variable(torch.cat((C_mat, zer), 1))

    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    #qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
    qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(),
                                       dummy.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query) + torch.ones(
        tasks_per_batch, n_support, n_query).cuda()
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way,
                                                      n_support, n_query)
    qp_sol = qp_sol.float()
    qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support)
    A_i = torch.sum(qp_sol, 1)  # (tasks_per_batch, n_support)
    A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support)
    qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way,
                                                n_support, n_query)
    Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support)
    Y_support_reshaped = A_i * Y_support_reshaped
    Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(
        tasks_per_batch, n_way, n_support, n_query)
    logits = (Y_support_reshaped - qp_sol) * compatibility

    logits = torch.sum(logits, 2)

    return logits.transpose(1, 2)
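For readability, here is the decision rule that the final block above implements, written in the same comment notation used elsewhere in this file. The +1 inside the kernel comes from the all-ones matrix added to the Gram matrices (an implicit bias feature):

# For a query point q and class m:
#   f_m(q) = \sum_i ( 1[y_i = m] * \sum_{m'} \alpha^{m'}_i  -  \alpha^m_i ) * (<x_i, q> + 1)
# where A_i = \sum_{m'} \alpha^{m'}_i is the per-sample total computed above,
# and \alpha^m_i is read off qp_sol reshaped to (tasks_per_batch, n_way, n_support).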
Example #5
def MetaOptNetHead_SVM_He(query,
                          support,
                          support_labels,
                          n_way,
                          n_shot,
                          C_reg=0.01,
                          double_precision=False):
    """
    Fits the support set with multi-class SVM and 
    returns the classification score on the query set.
    
    This is the multi-class SVM presented in:
    A simplified multi-class support vector machine with reduced dual optimization
    (He et al., Pattern Recognition Letters 2012).
    
    This SVM is desirable because its dual variable is of size n_support
    (as opposed to n_way * n_support in the Weston & Watkins or Crammer & Singer multi-class SVMs).
    This model is the classification head that we initially used for our project.
    It was dropped after it turned out to perform suboptimally in meta-learning scenarios.
    
    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      C_reg: a scalar. Represents the cost parameter C in SVM.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor.
    """

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot
            )  # n_support must equal n_way * n_shot

    kernel_matrix = computeGramMatrix(support, support)

    V = (support_labels * n_way -
         torch.ones(tasks_per_batch, n_support, n_way).cuda()) / (n_way - 1)
    G = computeGramMatrix(V, V).detach()
    G = kernel_matrix * G

    e = Variable(-1.0 * torch.ones(tasks_per_batch, n_support))
    id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support,
                                            n_support)
    C = Variable(torch.cat((id_matrix, -id_matrix), 1))
    h = Variable(
        torch.cat((C_reg * torch.ones(tasks_per_batch, n_support),
                   torch.zeros(tasks_per_batch, n_support)), 1))
    dummy = Variable(
        torch.Tensor()).cuda()  # We want to ignore the equality constraint.

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
    else:
        G, e, C, h = [x.cuda() for x in [G, e, C, h]]

    # Solve the following QP to fit SVM:
    #        \hat z =   argmin_z 1/2 z^T G z + e^T z
    #                 subject to Cz <= h
    # We use detach() to prevent backpropagation to fixed variables.
    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())

    # Compute the classification score.
    compatibility = computeGramMatrix(query, support)
    compatibility = compatibility.float()

    logits = qp_sol.float().unsqueeze(1).expand(tasks_per_batch, n_query,
                                                n_support)
    logits = logits * compatibility
    logits = logits.view(tasks_per_batch, n_query, n_shot, n_way)
    logits = torch.sum(logits, 2)

    return logits
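For reference, assuming support_labels here is already one-hot (its shape must match the (tasks_per_batch, n_support, n_way) ones tensor it is subtracted against), the V computed above is the simplified SVM's class coding:

# V^m_i = (n_way * Y^m_i - 1) / (n_way - 1)
#       =  1                  if m == y_i
#       = -1 / (n_way - 1)    otherwise,
# so G_{ij} = <V_i, V_j> * <x_i, x_j> couples samples through their class codes.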
Example #6
def MetaOptNetHead_Ridge(query,
                         support,
                         support_labels,
                         n_way,
                         n_shot,
                         device,
                         lambda_reg=100.0,
                         double_precision=False):
    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_way * n_shot
            )  # n_support must equal n_way * n_shot

    kernel_matrix = computeGramMatrix(support, support)
    kernel_matrix += lambda_reg * torch.eye(n_support).expand(
        tasks_per_batch, n_support, n_support).cuda()

    block_kernel_matrix = kernel_matrix.repeat(
        n_way, 1, 1)  #(n_way * tasks_per_batch, n_support, n_support)

    support_labels_one_hot = one_hot(
        support_labels.view(tasks_per_batch * n_support), n_way,
        device)  # (tasks_per_batch * n_support, n_way)
    support_labels_one_hot = support_labels_one_hot.transpose(
        0, 1)  # (n_way, tasks_per_batch * n_support)
    support_labels_one_hot = support_labels_one_hot.reshape(
        n_way * tasks_per_batch,
        n_support)  # (n_way*tasks_per_batch, n_support)

    G = block_kernel_matrix
    e = -2.0 * support_labels_one_hot

    #This is a fake inequality constraint, as qpth does not support a QP without an inequality constraint.
    id_matrix_1 = torch.zeros(tasks_per_batch * n_way, n_support, n_support)
    C = Variable(id_matrix_1)
    h = Variable(torch.zeros((tasks_per_batch * n_way, n_support)))
    dummy = Variable(torch.Tensor()).cuda()

    if double_precision:
        G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]

    else:
        G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

    qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(),
                                       dummy.detach(), dummy.detach())
    #qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

    #qp_sol (n_way*tasks_per_batch, n_support)
    qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
    #qp_sol (n_way, tasks_per_batch, n_support)
    qp_sol = qp_sol.permute(1, 2, 0)
    #qp_sol (tasks_per_batch, n_support, n_way)

    # Compute the classification score.
    compatibility = computeGramMatrix(support, query)
    compatibility = compatibility.float()
    compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch,
                                                      n_support, n_query,
                                                      n_way)
    qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
    logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support,
                                                n_query, n_way)
    logits = logits * compatibility
    logits = torch.sum(logits, 1)

    return logits
    def inner_loop_adapt(self, support, support_labels, query, n_way, n_shot,
                         n_query):
        """
        Fits the support set with ridge regression and 
        returns the classification score on the query set.
        Parameters:
        query:  a (tasks_per_batch, n_query, c, h, w) Tensor.
        support:  a (tasks_per_batch, n_support, c, h, w) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        lambda_reg: a scalar. Represents the strength of L2 regularization.
        Returns: a (tasks_per_batch, n_query, n_way) Tensor.
        """

        measurements_trajectory = defaultdict(list)

        assert (query.dim() == 5)
        assert (support.dim() == 5)

        # get features
        orig_query_shape = query.shape
        orig_support_shape = support.shape

        support = self._model(support.reshape(-1, *orig_support_shape[2:]),
                              only_features=True).reshape(
                                  *orig_support_shape[:2], -1)
        query = self._model(query.reshape(-1, *orig_query_shape[2:]),
                            only_features=True).reshape(
                                *orig_query_shape[:2], -1)

        lambda_reg = self._lambda_reg
        double_precision = self._double_precision
        tasks_per_batch = query.size(0)
        total_n_support = support.size(
            1)  # support samples across all classes in a task
        total_n_query = query.size(
            1)  # query samples across all classes in a task
        d = query.size(2)  # dimension

        assert (query.dim() == 3)
        assert (support.dim() == 3)
        assert (query.size(0) == support.size(0)
                and query.size(2) == support.size(2))
        assert (total_n_support == n_way * n_shot
                )  # total_n_support must equal n_way * n_shot
        assert (total_n_query == n_way * n_query
                )  # total_n_query must equal n_way * n_query

        #Here we solve the dual problem:
        #Note that the classes are indexed by m & samples are indexed by i.
        #min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i

        #where w_m(\alpha) = \sum_i \alpha^m_i x_i,

        #\alpha is an (total_n_support, n_way) matrix
        kernel_matrix = computeGramMatrix(support, support)
        kernel_matrix += lambda_reg * torch.eye(total_n_support).expand(
            tasks_per_batch, total_n_support, total_n_support).cuda()

        block_kernel_matrix = kernel_matrix.repeat(
            n_way, 1,
            1)  #(n_way * tasks_per_batch, total_n_support, total_n_support)

        support_labels_one_hot = one_hot(
            support_labels.view(tasks_per_batch * total_n_support),
            n_way)  # (tasks_per_batch * total_n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.transpose(
            0, 1)  # (n_way, tasks_per_batch * total_n_support)
        support_labels_one_hot = support_labels_one_hot.reshape(
            n_way * tasks_per_batch,
            total_n_support)  # (n_way*tasks_per_batch, total_n_support)

        G = block_kernel_matrix
        e = -2.0 * support_labels_one_hot

        #This is a fake inequality constraint, as qpth does not support a QP without an inequality constraint.
        id_matrix_1 = torch.zeros(tasks_per_batch * n_way, total_n_support,
                                  total_n_support)
        C = Variable(id_matrix_1)
        h = Variable(torch.zeros((tasks_per_batch * n_way, total_n_support)))
        dummy = Variable(torch.Tensor()).cuda(
        )  # We want to ignore the equality constraint.

        if double_precision:
            G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]

        else:
            G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]

        # Solve the following QP to fit the ridge regression:
        #        \hat z =   argmin_z 1/2 z^T G z + e^T z
        #                 subject to Cz <= h
        # We use detach() to prevent backpropagation to fixed variables.
        qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(),
                                           h.detach(), dummy.detach(),
                                           dummy.detach())
        #qp_sol = QPFunction(verbose=False)(G, e.detach(), dummy.detach(), dummy.detach(), dummy.detach(), dummy.detach())

        #qp_sol (n_way*tasks_per_batch, total_n_support)
        qp_sol = qp_sol.reshape(n_way, tasks_per_batch, total_n_support)
        #qp_sol (n_way, tasks_per_batch, total_n_support)
        qp_sol = qp_sol.permute(1, 2, 0)
        #qp_sol (tasks_per_batch, total_n_support, n_way)

        # Compute the classification score.
        compatibility = computeGramMatrix(support, query)
        compatibility = compatibility.float()
        compatibility = compatibility.unsqueeze(3).expand(
            tasks_per_batch, total_n_support, total_n_query, n_way)
        qp_sol = qp_sol.reshape(tasks_per_batch, total_n_support, n_way)
        logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch,
                                                    total_n_support,
                                                    total_n_query, n_way)
        logits = logits * compatibility
        logits = torch.sum(logits, 1) * self._scale

        # compute loss and acc on support
        with torch.no_grad():
            compatibility = computeGramMatrix(support, support)
            compatibility = compatibility.float()
            compatibility = compatibility.unsqueeze(3).expand(
                tasks_per_batch, total_n_support, total_n_support, n_way)
            logits_support = qp_sol.float().unsqueeze(2).expand(
                tasks_per_batch, total_n_support, total_n_support, n_way)
            logits_support = logits_support * compatibility
            logits_support = torch.sum(logits_support, 1)
            logits_support = logits_support.reshape(
                -1, logits_support.size(-1)) * self._scale
            loss = self._inner_loss_func(logits_support,
                                         support_labels.reshape(-1))
            accu = accuracy(logits_support, support_labels.reshape(-1)) * 100.
            measurements_trajectory['loss'].append(loss.item())
            measurements_trajectory['accu'].append(accu)

        return logits, measurements_trajectory
    def inner_loop_adapt(self, support, support_labels, query, n_way, n_shot,
                         n_query):
        """
        Fits the support set with multi-class SVM and 
        returns the classification score on the query set.
        
        This is the multi-class SVM presented in:
        On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
        (Crammer and Singer, Journal of Machine Learning Research 2001).
        This model is the classification head that we use for the final version.
        Parameters:
        query:  a (tasks_per_batch, n_query, c, h, w) Tensor.
        support:  a (tasks_per_batch, n_support, c, h, w) Tensor.
        support_labels: a (tasks_per_batch, n_support) Tensor.
        n_way: a scalar. Represents the number of classes in a few-shot classification task.
        n_shot: a scalar. Represents the number of support examples given per class.
        C_reg: a scalar. Represents the cost parameter C in SVM.
        Returns: a (tasks_per_batch, n_query, n_way) Tensor.
        """

        measurements_trajectory = defaultdict(list)

        assert (query.dim() == 5)
        assert (support.dim() == 5)

        # get features
        orig_query_shape = query.shape
        orig_support_shape = support.shape
        support = self._model(support.reshape(-1, *orig_support_shape[2:]),
                              only_features=True).reshape(
                                  *orig_support_shape[:2], -1)
        query = self._model(query.reshape(-1, *orig_query_shape[2:]),
                            only_features=True).reshape(
                                *orig_query_shape[:2], -1)

        tasks_per_batch = query.size(0)
        total_n_support = support.size(
            1)  # support samples across all classes in a task
        total_n_query = query.size(
            1)  # query samples across all classes in a task
        d = query.size(2)  # dimension

        C_reg = self._C_reg
        maxIter = self._max_iter

        assert (query.dim() == 3)
        assert (support.dim() == 3)
        assert (query.size(0) == support.size(0)
                and query.size(2) == support.size(2))
        assert (total_n_support == n_way * n_shot
                )  # total_n_support must equal n_way * n_shot
        assert (total_n_query == n_way * n_query
                )  # total_n_query must equal n_way * n_query

        #Here we solve the dual problem:
        #Note that the classes are indexed by m & samples are indexed by i.
        #min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
        #s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i

        #where w_m(\alpha) = \sum_i \alpha^m_i x_i,
        #and C^m_i = C if m  = y_i,
        #C^m_i = 0 if m != y_i.
        #This borrows the notation of liblinear.

        #\alpha is an (total_n_support, n_way) matrix
        kernel_matrix = computeGramMatrix(support, support)

        id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way,
                                              n_way).cuda()
        block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
        #This seems to help avoid PSD error from the QP solver.
        block_kernel_matrix += 1.0 * torch.eye(n_way * total_n_support).expand(
            tasks_per_batch, n_way * total_n_support,
            n_way * total_n_support).cuda()

        support_labels_one_hot = one_hot(
            support_labels.view(tasks_per_batch * total_n_support), n_way)
        # (tasks_per_batch * total_n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.view(
            tasks_per_batch, total_n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.reshape(
            tasks_per_batch, total_n_support * n_way)
        # (tasks_per_batch, total_n_support * n_way)

        G = block_kernel_matrix
        e = -1.0 * support_labels_one_hot
        #This part is for the inequality constraints:
        #\alpha^m_i <= C^m_i \forall m,i
        #where C^m_i = C if m  = y_i,
        #C^m_i = 0 if m != y_i.
        id_matrix_1 = torch.eye(n_way * total_n_support).expand(
            tasks_per_batch, n_way * total_n_support, n_way * total_n_support)
        C = Variable(id_matrix_1)
        h = Variable(C_reg * support_labels_one_hot)
        #print (C.size(), h.size())
        #This part is for the equality constraints:
        #\sum_m \alpha^m_i=0 \forall i
        id_matrix_2 = torch.eye(total_n_support).expand(
            tasks_per_batch, total_n_support, total_n_support).cuda()

        A = Variable(
            batched_kronecker(id_matrix_2,
                              torch.ones(tasks_per_batch, 1, n_way).cuda()))
        b = Variable(torch.zeros(tasks_per_batch, total_n_support))

        if self._double_precision:
            G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
        else:
            G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

        # Solve the following QP to fit SVM:
        #        \hat z =   argmin_z 1/2 z^T G z + e^T z
        #                 subject to Cz <= h, Az = b
        # We use detach() to prevent backpropagation to fixed variables.
        qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(),
                                                            C.detach(),
                                                            h.detach(),
                                                            A.detach(),
                                                            b.detach())
        # G is not detached; it is the only input that needs gradients, since it is a function of phi(x).

        qp_sol = qp_sol.reshape(tasks_per_batch, total_n_support, n_way)

        # Compute the classification score for query.
        compatibility_query = computeGramMatrix(support, query)
        compatibility_query = compatibility_query.float()
        compatibility_query = compatibility_query.unsqueeze(3).expand(
            tasks_per_batch, total_n_support, total_n_query, n_way)
        logits_query = qp_sol.float().unsqueeze(2).expand(
            tasks_per_batch, total_n_support, total_n_query, n_way)
        logits_query = logits_query * compatibility_query
        logits_query = torch.sum(logits_query, 1) * self._scale

        # Compute the classification score for support.
        with torch.no_grad():
            compatibility_support = computeGramMatrix(support, support)
            compatibility_support = compatibility_support.float()
            compatibility_support = compatibility_support.unsqueeze(3).expand(
                tasks_per_batch, total_n_support, total_n_support, n_way)
            logits_support = qp_sol.float().unsqueeze(2).expand(
                tasks_per_batch, total_n_support, total_n_support, n_way)
            logits_support = logits_support * compatibility_support
            logits_support = torch.sum(logits_support, 1) * self._scale

        # compute loss and acc on support
        logits_support = logits_support.reshape(-1, logits_support.size(-1))
        labels_support = support_labels.reshape(-1)

        loss = self._inner_loss_func(logits_support, labels_support)
        accu = accuracy(logits_support, labels_support)
        measurements_trajectory['loss'].append(loss.item())
        measurements_trajectory['accu'].append(accu)

        return logits_query, measurements_trajectory
Example #9
    def CMloss_Fnorm(self, query, support, support_labels, n_way, n_shot):

        tasks_per_batch = query.size(0)
        n_support = support.size(1)
        n_query = query.size(1)

        assert (query.dim() == 3)
        assert (support.dim() == 3)
        assert (query.size(0) == support.size(0) and query.size(2) == support.size(2))
        assert (n_support == n_way * n_shot)  # n_support must equal n_way * n_shot

        # Here we solve the dual problem:
        # Note that the classes are indexed by m & samples are indexed by i.
        # min_{\alpha}  0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i
        # s.t.  \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i

        # where w_m(\alpha) = \sum_i \alpha^m_i x_i,
        # and C^m_i = C if m  = y_i,
        # C^m_i = 0 if m != y_i.
        # This borrows the notation of liblinear.

        # \alpha is an (n_support, n_way) matrix
        kernel_matrix = computeGramMatrix(support, support)

        id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
        block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
        # This seems to help avoid PSD error from the QP solver.
        block_kernel_matrix += 1.0 * torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support,
                                                                         n_way * n_support).cuda()

        support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support),
                                         n_way)  # (tasks_per_batch * n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
        support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)

        G = block_kernel_matrix
        e = -1.0 * support_labels_one_hot
        # print (G.size())
        # This part is for the inequality constraints:
        # \alpha^m_i <= C^m_i \forall m,i
        # where C^m_i = C if m  = y_i,
        # C^m_i = 0 if m != y_i.
        id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
        C = Variable(id_matrix_1)
        h = Variable(self.C_reg * support_labels_one_hot)
        # print (C.size(), h.size())
        # This part is for the equality constraints:
        # \sum_m \alpha^m_i=0 \forall i
        id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()

        A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
        b = Variable(torch.zeros(tasks_per_batch, n_support))
        # print (A.size(), b.size())
        if self.double_precision:
            G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
        else:
            G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]

        # Solve the following QP to fit SVM:
        #        \hat z =   argmin_z 1/2 z^T G z + e^T z
        #                 subject to Cz <= h, Az = b
        # We use detach() to prevent backpropagation to fixed variables.
        qp_sol = QPFunction(verbose=False, maxIter=self.maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(),
                                                                 b.detach())

        # Rebuild the per-class weights w_m = \sum_i \alpha^m_i x_i from the dual solution.
        qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
        w_query = torch.bmm(qp_sol.transpose(1, 2), support)

        dis_matrix = self.support_w - w_query
        reverse_loss = 0
        for i in range(tasks_per_batch):
            # sqrt(trace(D D^T)) is the Frobenius norm of D.
            reverse_loss += (torch.trace(torch.mm(dis_matrix[i, :, :], dis_matrix[i, :, :].t()))).sqrt()

        reverse_loss = reverse_loss / (n_way * tasks_per_batch)

        return reverse_loss
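In words, CMloss_Fnorm rebuilds the per-class weights from the dual solution and penalizes their Frobenius distance to self.support_w (assumed to hold the weights fitted on the support set):

# w_query[t] = \alpha[t]^T X[t], i.e. w_m = \sum_i \alpha^m_i x_i, and
# loss = 1/(n_way * tasks_per_batch) * \sum_t || self.support_w[t] - w_query[t] ||_F,
# using sqrt(trace(D D^T)) = ||D||_F.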
Example #10
def MetaOptNetHead_OC_SVM_batched(query,
                                  support,
                                  support_labels,
                                  n_way,
                                  n_shot,
                                  nu=0.1,
                                  double_precision=True,
                                  maxIter=40):
    """
    Fits the support set with OC-SVM and
    returns the classification score on the query set.

    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor.
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      nu: a scalar in (0, 1]. An upper bound on the fraction of margin errors in the OC-SVM.
    Returns: a (tasks_per_batch, n_query) Tensor of one-class scores.
    """

    tasks_per_batch = query.size(0)
    n_support = support.size(1)
    n_query = query.size(1)

    assert (query.dim() == 3)
    assert (support.dim() == 3)
    assert (query.size(0) == support.size(0)
            and query.size(2) == support.size(2))
    assert (n_support == n_shot)  # n_support must equal n_shot

    # we solve the dual problem

    # One-class classification (OCC): a single decision function per task.
    n_way = 1

    kernel_matrix = computeGramMatrix(support, support)
    Q = kernel_matrix.cuda()
    # This seems to help avoid a PSD error from the QP solver (as done in the original MetaOptNet SVM head).

    Q_spd = 1.0 * torch.eye(n_support).expand(tasks_per_batch, n_support,
                                              n_support).cuda()
    Q += Q_spd

    p = Variable(torch.zeros((tasks_per_batch, n_support)))
    A = Variable(torch.ones((tasks_per_batch, 1, n_support)))
    b = Variable(torch.ones(tasks_per_batch, 1))
    G = Variable(
        torch.cat([-1.0 * torch.eye(n_support),
                   torch.eye(n_support)],
                  dim=0).expand(tasks_per_batch, 2 * n_support, n_support))
    h_task = torch.cat([
        torch.zeros((n_support, 1)).cuda(),
        (1 / (nu * n_support)) * torch.ones((n_support, 1)).cuda()
    ],
                       dim=0)
    h = torch.squeeze(
        Variable(h_task.expand(tasks_per_batch, 2 * n_support, 1)))

    if double_precision:
        Q, p, G, h, A, b = [x.double().cuda() for x in [Q, p, G, h, A, b]]
    else:
        Q, p, G, h, A, b = [x.float().cuda() for x in [Q, p, G, h, A, b]]

    qp_sol = QPFunction(verbose=False, maxIter=maxIter)(Q, p.detach(),
                                                        G.detach(), h.detach(),
                                                        A.detach(), b.detach())

    # Compute the classification score.

    if double_precision:
        qp_sol = qp_sol.float()
    w = torch.bmm(qp_sol.reshape((tasks_per_batch, 1, n_support)), support)

    logits_tmp = torch.bmm(query, w.transpose(1, 2))
    logits = []
    for i in range(tasks_per_batch):
        # Support vectors: dual variables that are numerically non-zero.
        S = (qp_sol[i] > 1e-3).flatten()
        # Recover the offset rho from the first support vector.
        all_ro = torch.mm(support[i][S], w[i].transpose(0, 1))
        ro = all_ro[0]
        logit = torch.squeeze(logits_tmp[i] - ro)
        logits.append(logit)
    logits = torch.stack(logits)
    return logits
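For reference, the QP above is the standard one-class SVM dual (Schölkopf et al.), written in the comment notation used elsewhere in this file:

# min_{\alpha} 0.5 \alpha^T K \alpha
# s.t.  0 <= \alpha_i <= 1/(nu * n_support),  \sum_i \alpha_i = 1
# with w = \sum_i \alpha_i x_i. The offset rho = <w, x_s> is taken from a
# support vector x_s; the loop above uses the first index with \alpha_i > 1e-3.
# The score for a query q is <w, q> - rho.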