Example #1
0
    def loss(self,
             logits,
             label,
             support,
             class_name,
             NPM,
             isQ=False,
             support_weights=None):
        """Combined episode loss: cross-entropy plus an optional NPM term.

        Args:
            logits: [N*K, N] (support) or [L*N, N] (query) class scores.
            label: flat ground-truth class indices matching ``logits`` rows.
            support: instance embeddings; reshaped below to
                [n_way, K-or-L, 256] — assumes 256-d embeddings.
            class_name: class-name (prototype) embeddings, one row per class.
            NPM: when True, add the NPM term scaled by ``self.lam``.
            isQ: True when scoring the query set (``self.L`` samples per
                class) instead of the support set (``self.k_shot`` per class).
            support_weights: optional numpy array of per-instance weights
                ([N*K]); when given, cross-entropy is weight-scaled via NLL.

        Returns:
            Scalar loss tensor: ``loss_ce`` (+ ``self.lam * loss_npm``
            when ``NPM`` is True).
        """
        if support_weights is None:
            # Unweighted cross-entropy, normalized by samples-per-class.
            if isQ is True:
                loss_ce = self.cost(logits, label) / self.L
            else:
                loss_ce = self.cost(logits, label) / self.k_shot
        else:
            # Weighted cross-entropy: scale each row's log-probabilities by
            # that instance's support weight, then apply NLL loss.
            logits_softmax = F.softmax(logits, dim=-1).log()  # [N*K, N]
            support_weights_tensor = torch.from_numpy(support_weights).view(
                -1, 1).cuda()  # [N*K, 1]
            logits_times_weights = logits_softmax * support_weights_tensor
            loss_ce = self.nllloss(logits_times_weights, label)

        if NPM is True:
            loss_npm = torch.tensor(0.0, requires_grad=True)
            # Group instance embeddings by class: [n_way, per-class, 256].
            if isQ is True:
                support_N = support.view((self.n_way, self.L, 256))
            else:
                if support_weights is not None:
                    support_weights = support_weights.reshape(
                        (self.n_way, self.k_shot))
                support_N = support.view((self.n_way, self.k_shot, 256))
            for i, s in enumerate(support_N):
                # Mean-normalized (negated) distance of class i's instances
                # to every class-name embedding.
                dist = -neg_dist(s, class_name) / torch.mean(
                    -neg_dist(s, class_name), dim=0)  # [K, N]
                for j, d in enumerate(dist):
                    loss_npm_temp = torch.tensor(0.0, requires_grad=True)
                    for k, di in enumerate(d):
                        # Accumulate a log-sum-exp style margin: d[i] is the
                        # entry for the true class i of this group.
                        # NOTE(review): d[i] (not d[j]) looks intentional,
                        # but confirm against the intended NPM formula.
                        loss_npm_temp = loss_npm_temp + torch.exp(d[i] - di)
                    if support_weights is not None:
                        # Weight the per-instance NPM term by its support weight.
                        loss_npm = loss_npm + support_weights[i][
                            j] * torch.log(loss_npm_temp)
                    else:
                        # Uniform averaging over all instances.
                        if isQ is True:
                            loss_npm = loss_npm + torch.log(
                                loss_npm_temp) / self.n_way / self.L
                        else:
                            loss_npm = loss_npm + torch.log(
                                loss_npm_temp) / self.n_way / self.k_shot

            # print("loss_ce: ", loss_ce, "loss_npm: ", loss_npm * self.lam)
            return loss_ce + self.lam * loss_npm
        else:
            # print("loss_ce: ", loss_ce)
            return loss_ce
Example #2
0
def train_q(args,
            class_name0,
            query0,
            query_label,
            mymodel_clone,
            zero_shot=False):
    """Run the query pass of one episode on the cloned model.

    Returns:
        (loss_q, right_q): the query loss and the accuracy value.
    """
    n_way = mymodel_clone.n_way
    k_shot = 0 if zero_shot else mymodel_clone.k_shot

    # Encode raw queries and class names with the sentence encoder.
    encoded_query = mymodel_clone.coder(query0)  # [L*N, 768*2]
    encoded_names = mymodel_clone.coder(class_name0,
                                        is_classname=True)  # [N, 768]

    # Project both into the shared 256-d metric space.
    prototypes = mymodel_clone(encoded_names, is_classname=True)  # [N, 256]
    query_emb = mymodel_clone(encoded_query)  # [L*N, 256]

    # Column-mean-normalized negative distances serve as logits.
    scores = neg_dist(query_emb, prototypes)  # [L*N, N]
    scores = -scores / torch.mean(scores, dim=0)
    _, predictions = torch.max(scores, 1)

    loss_q = mymodel_clone.loss(scores,
                                query_label.view(-1),
                                query_emb,
                                prototypes,
                                NPM=args.NPM_Loss,
                                isQ=True)
    right_q = mymodel_clone.accuracy(predictions, query_label)

    return loss_q, right_q
Example #3
0
def SupportWeight(mymodel, support, query1):
    """Weight each support instance by its aggregate closeness to the queries.

    Args:
        mymodel: projection model mapping encoder output to the 256-d space.
        support: projected support embeddings  [N*K, 256].
        query1: raw encoder output for the queries (projected here).

    Returns:
        numpy array [N*K] of softmax weights over the support instances.
    """
    with torch.no_grad():
        query = mymodel(query1)  # ->[N*L, 256]  support:[N*K, 256]
        dist = neg_dist(support, query)  # ->[N*K, N*L]
        # Total (negative) distance of each support instance to all queries.
        dist = torch.sum(dist, dim=1).view(-1)  # [N*K, ]
        dist = -dist / torch.mean(dist)
        # Explicit dim: implicit-dim softmax is deprecated (and an error in
        # recent PyTorch); dim=0 reproduces the old behavior for 1-D input.
        weights = F.softmax(dist, dim=0)  # [N*K, ]
        weights = weights.cpu().numpy()

    return weights
Example #4
0
def pre_calculate(mymodel, args):
    """Precompute a softmax distribution between class-name embeddings.

    Encodes every class name ("<name>it means <description>"), computes the
    pairwise neg_dist matrix, drops each row's diagonal (self) entry, and
    softmaxes each row.

    Args:
        mymodel: model with a ``.coder`` sentence encoder.
        args: needs ``args.train`` (training JSON) and
            ``args.class_name_file`` ({"P931": [...], ...}).

    Returns:
        numpy array [N, N-1] of row-wise probabilities (diagonal removed).
    """
    if torch.cuda.is_available():
        mymodel = mymodel.cuda()

    mymodel.eval()
    with torch.no_grad():
        # Close the JSON files deterministically instead of leaking handles.
        with open(args.train, 'r') as f:
            json_data = json.load(f)
        with open(args.class_name_file, 'r') as f:
            json_class_name = json.load(f)  # {"P931":[...], "P903":[...], ...}

        # Encode each class name once, then concatenate in a single cat()
        # instead of growing the tensor pairwise.
        encoded = []
        for class_name in json_data:
            val = json_class_name[class_name]
            class_name_val = val[0] + "it means " + val[1]
            class_name_val = [class_name_val.split()]
            encoded.append(mymodel.coder(class_name_val,
                                         is_classname=True))  # [1, 768]
        class_name_val_list = torch.cat(encoded, dim=0)  # [N, 768]

        dist_metrix = neg_dist(class_name_val_list,
                               class_name_val_list)  # [N, N]

        # Remove each row's self-distance (diagonal) entry.
        rows = [
            del_tensor_ele(d, i).view((1, -1))
            for i, d in enumerate(dist_metrix)
        ]
        dist_metrix_nodiag = torch.cat(rows, dim=0)  # [N, N-1]

        prob_metrix = F.softmax(dist_metrix_nodiag, dim=1)
        return prob_metrix.cpu().numpy()
Example #5
0
def pre_calculate(mymodel, args):
    """Precompute, per class, a softmax distribution over its training examples.

    For each class, scores every training example against the encoded class
    name ("<name>it means <description>") with negated neg_dist and softmaxes
    the scores.

    Args:
        mymodel: model with a ``.coder`` sentence encoder.
        args: needs ``args.train`` (training JSON) and
            ``args.class_name_file`` ({"P931": [...], ...}).

    Returns:
        numpy array [N classes, num_examples] of per-class probabilities.
    """
    if torch.cuda.is_available():
        mymodel = mymodel.cuda()

    mymodel.eval()
    with torch.no_grad():
        # Close the JSON files deterministically instead of leaking handles.
        with open(args.train, 'r') as f:
            json_data = json.load(f)
        with open(args.class_name_file, 'r') as f:
            json_class_name = json.load(f)  # {"P931":[...], "P903":[...], ...}

        class_name_list = []
        encoded = []
        for class_name in json_data:
            class_name_list.append(class_name)
            val = json_class_name[class_name]
            class_name_val = val[0] + "it means " + val[1]
            class_name_val = [class_name_val.split()]
            encoded.append(mymodel.coder(class_name_val,
                                         is_classname=True))  # [1, 768]
        class_name_val_list = torch.cat(encoded, dim=0)  # [N, 768]

        per_class = []
        for i in range(class_name_val_list.shape[0]):
            rel = json_data[class_name_list[i]]
            examples = mymodel.coder(rel)  # [700, 768*2]
            # Average head/tail halves into one 768-d sentence embedding.
            examples = (examples[:, :768] + examples[:, 768:]) / 2  # [700, 768]
            cnv = class_name_val_list[i].view((1, -1))  # [1, 768]
            dist = -neg_dist(cnv, examples)  # [1, 700]
            # Explicit dim=1 matches the deprecated implicit-dim softmax
            # behavior for 2-D input.
            dist = F.softmax(dist, dim=1)
            per_class.append(dist.cpu().numpy())

        return np.concatenate(per_class, axis=0)  # [N classes, 700]
Example #6
0
def train_one_batch(args,
                    class_name0,
                    support0,
                    support_label,
                    query0,
                    query_label,
                    mymodel,
                    task_lr,
                    it,
                    zero_shot=False):
    """Support pass of one episode.

    Returns:
        (loss_s, right_s, query1, class_name1): support loss, support
        accuracy, raw query encoding, and raw class-name encoding (the last
        two are reused by the query pass).
    """
    n_way = mymodel.n_way
    k_shot = 0 if zero_shot else mymodel.k_shot

    # Encode raw inputs with the sentence encoder.
    raw_support = mymodel.coder(support0)  # [N*K, 768*2]
    query1 = mymodel.coder(query0)  # [L*N, 768*2]
    class_name1 = mymodel.coder(class_name0, is_classname=True)  # [N, 768]

    # Project into the shared 256-d metric space.
    prototypes = mymodel(class_name1, is_classname=True)  # ->[N, 256]
    support_emb = mymodel(raw_support)  # ->[N*K, 256]

    # Column-mean-normalized negative distances serve as logits.
    scores = neg_dist(support_emb, prototypes)  # -> [N*K, N]
    scores = -scores / torch.mean(scores, dim=0)
    _, predictions = torch.max(scores, 1)

    if args.SW is True:
        # Optionally weight support instances by similarity to the queries.
        loss_s = mymodel.loss(scores,
                              support_label.view(-1),
                              support_emb,
                              prototypes,
                              NPM=args.NPM_Loss,
                              support_weights=SupportWeight(
                                  mymodel, support_emb, query1))
    else:
        loss_s = mymodel.loss(scores,
                              support_label.view(-1),
                              support_emb,
                              prototypes,
                              NPM=args.NPM_Loss)
    right_s = mymodel.accuracy(predictions, support_label)

    return loss_s, right_s, query1, class_name1