Example #1
    def test_process(self, testset):
        args = self.args
        self.model.eval()  # evaluation mode (freeze BatchNorm/Dropout)
        record = np.zeros((args.num_test_episodes, 2))  # loss and acc
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        test_sampler = CategoriesSampler(testset.label,
                                         args.num_test_episodes,
                                         args.eval_way,
                                         args.eval_shot + args.eval_query)
        test_loader = DataLoader(dataset=testset,
                                 batch_sampler=test_sampler,
                                 num_workers=args.num_workers,
                                 pin_memory=True)
        with torch.no_grad():  # no gradients needed at test time
            for i, batch in tqdm(enumerate(test_loader, 1),
                                 total=len(test_loader)):
                data = batch[0].to(self.args.device)
                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        print('{} way {} shot, Test acc={:.4f} + {:.4f}\n'.format(
            args.eval_way, args.eval_shot, va, vap))
        return vl, va, vap
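
Every example on this page calls count_acc and compute_confidence_interval without defining them. Below is a minimal sketch of what these helpers plausibly compute, inferred from the call sites alone; the actual implementations in the source repositories may differ.

import numpy as np
import torch

def count_acc(logits, label):
    # fraction of query samples whose argmax prediction matches the label
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()

def compute_confidence_interval(data):
    # mean and 95% confidence half-width (1.96 * standard error) over episodes
    a = np.asarray(data, dtype=np.float64)
    m = float(a.mean())
    pm = 1.96 * float(a.std(ddof=1)) / np.sqrt(len(a))
    return m, pm
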
Example #2
    def evaluate_fsl(self, data_loader):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()
        record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
        label = torch.arange(args.eval_way,
                             dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('{} best epoch {}, best val acc={:.4f} + {:.4f}'.format(
            args.test_mode, self.trlog['max_acc_epoch'], self.trlog['max_acc'],
            self.trlog['max_acc_interval']))
        with torch.no_grad():
            for i, batch in tqdm(enumerate(data_loader, 1)):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]

                p = args.eval_shot * args.eval_way
                data_shot, data_query = data[:p], data[p:]
                logits = self.model.forward_fsl(data_shot, data_query)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])

        # train mode
        self.model.train()
        return vl, va, vap
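
The label tensor built at the top of these methods only lines up with the data if the episodic sampler emits all support images first and then all query images, with the class index cycling fastest. A toy illustration of that assumed layout for a 5-way 1-shot 15-query episode:

import torch

way, shot, query = 5, 1, 15
p = shot * way                           # the first p items form the support set
label = torch.arange(way).repeat(query)  # tensor([0, 1, 2, 3, 4, 0, 1, ...]), length 75
# data[:p] is the support set; data[p:] is the query set, and `label`
# indexes the query set only, matching F.cross_entropy(logits, label).
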
Example #3
    def evaluate_gfsl(self, fsl_loader, gfsl_loader, gfsl_dataset):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()
        record = np.zeros((args.num_eval_episodes, 5))  # loss, acc, HM, seen acc, unseen acc
        label_unseen_query = torch.arange(
            min(args.eval_way,
                self.valset.num_class)).repeat(args.eval_query).long()
        if torch.cuda.is_available():
            label_unseen_query = label_unseen_query.cuda()
        print('{} best epoch {}, best val acc={:.4f} + {:.4f}'.format(
            args.test_mode, self.trlog['max_acc_epoch'], self.trlog['max_acc'],
            self.trlog['max_acc_interval']))
        with torch.no_grad():
            for i, batch in tqdm(enumerate(zip(gfsl_loader, fsl_loader), 1)):
                if torch.cuda.is_available():
                    data_seen, seen_label = batch[0][0].cuda(), batch[0][1].cuda()
                    data_unseen, unseen_label = batch[1][0].cuda(), batch[1][1].cuda()
                else:
                    data_seen, seen_label = batch[0][0], batch[0][1]
                    data_unseen, unseen_label = batch[1][0], batch[1][1]

                p2 = args.eval_shot * args.eval_way
                data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
                label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
                whole_query = torch.cat([data_seen, data_unseen_query], 0)
                whole_label = torch.cat(
                    [seen_label, label_unseen_query + gfsl_dataset.num_class])
                logits_s, logits_u = self.model.forward_generalized(
                    data_unseen_shot, whole_query)
                # compute unbiased accuracy over the joint label space
                new_logits = torch.cat([logits_s, logits_u], 1)
                record[i - 1, 0] = F.cross_entropy(new_logits, whole_label).item()
                record[i - 1, 1] = count_acc(new_logits, whole_label)
                # compute harmonic mean of seen/unseen accuracy
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                    new_logits, whole_label, seen_label.shape[0])
                record[i - 1, 2:] = np.array([HM_nobias, SA_nobias, UA_nobias])
                del logits_s, logits_u, new_logits
                torch.cuda.empty_cache()

        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 2])  # validate on the harmonic mean

        # train mode
        self.model.train()
        return vl, va, vap
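
count_acc_harmonic_low_shot_joint is another undefined helper. Judging from its arguments (joint logits, joint labels, and the number of seen-class queries), a plausible reimplementation is sketched below; treat it as an illustration of the metric, not the original code.

import torch

def count_acc_harmonic_low_shot_joint(logits, label, num_seen):
    # assumes queries are ordered seen-class first, unseen-class after
    pred = torch.argmax(logits, dim=1)
    correct = (pred == label).float()
    sa = correct[:num_seen].mean().item()  # seen-class accuracy
    ua = correct[num_seen:].mean().item()  # unseen-class accuracy
    hm = 2 * sa * ua / (sa + ua + 1e-12)   # harmonic mean of the two
    return hm, sa, ua
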
Example #4
    def evaluate_fsl(self):
        args = self.args
        record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
        label = torch.arange(args.eval_way,
                             dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()

        for i, batch in tqdm(enumerate(self.test_fsl_loader, 1)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            p = args.eval_shot * args.eval_way
            data_shot, data_query = data[:p], data[p:]
            with torch.no_grad():
                logits = self.model.forward_fsl(data_shot, data_query)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            record[i - 1, 0] = loss.item()
            record[i - 1, 1] = acc

            del loss, logits
            torch.cuda.empty_cache()
        assert i == record.shape[0]

        print('-'.join([args.model_class, args.model_path]))
        self.trlog['acc_mean'], self.trlog['acc_interval'] = \
            compute_confidence_interval(record[:, 1])
        print('FSL {}-way Acc {:.5f} + {:.5f}'.format(
            args.eval_way, self.trlog['acc_mean'], self.trlog['acc_interval']))
Example #5
    def evaluate_test(self):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
        self.model.eval()
        record = np.zeros((10000, 2))  # loss and acc over 10,000 test episodes
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.test_loader, 1)):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]

                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        
        self.trlog['test_acc'] = va
        self.trlog['test_acc_interval'] = vap
        self.trlog['test_loss'] = vl

        print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
        print('Test acc={:.4f} + {:.4f}\n'.format(
                self.trlog['test_acc'],
                self.trlog['test_acc_interval']))

        return vl, va, vap
Example #6
    def evaluate(self, data_loader):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()
        record = np.zeros((args.num_eval_episodes, 2))  # loss and acc
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        with torch.no_grad():
            for i, batch in tqdm(enumerate(data_loader, 1),
                                 total=len(data_loader),
                                 desc='eval procedure'):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]

                logits = self.model(data)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc

        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])

        # back to train mode
        self.model.train()
        if self.args.fix_BN:
            # keep the encoder's BatchNorm statistics frozen even while training
            self.model.encoder.eval()
        return vl, va, vap
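
The fix_BN branch relies on the fact that Module.eval() only switches layer behaviour (BatchNorm running statistics, Dropout masks); it does not freeze parameters or stop gradients. A toy illustration with a hypothetical encoder:

import torch.nn as nn

encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
encoder.train()    # put the whole module in train mode
encoder[1].eval()  # the BatchNorm layer alone keeps using its running stats
# encoder[1].weight still receives gradients during backpropagation
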
Example #7
    def evaluate_test(self, use_max_tst=False):
        # restore model args
        args = self.args
        # evaluation mode

        if use_max_tst:
            assert args.tst_criterion != '', 'Please specify a criterion'
            fname = osp.join(self.args.save_path, 'max_tst_criterion.pth')
            criterion = args.tst_criterion
            max_acc_epoch = 'max_tst_criterion_epoch'
            max_acc = 'max_tst_criterion'
            max_acc_interval = 'max_tst_criterion_interval'
            test_acc = 'test_acc_at_max_criterion'
            test_acc_interval = 'test_acc_interval_at_max_criterion'
            test_loss = 'test_loss_at_max_criterion'
        else:
            fname = osp.join(self.args.save_path, 'max_acc.pth')
            criterion = 'SupervisedAcc'
            max_acc_epoch = 'max_acc_epoch'
            max_acc = 'max_acc'
            max_acc_interval = 'max_acc_interval'
            test_acc = 'test_acc'
            test_acc_interval = 'test_acc_interval'
            test_loss = 'test_loss'
        print('\nCriterion selected: {}'.format(criterion))
        print('Reloading model from {}'.format(fname))
        self.model.load_state_dict(torch.load(fname)['params'])

        self.model.eval()
        record = np.zeros((10000, 2)) # loss and acc
        metrics = defaultdict(list)  # all other metrics
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        all_labels = torch.arange(args.eval_way, device=label.device).repeat(args.eval_shot + args.eval_query)

        max_validation_str = 'Maximum value of valid_{} {:.4f} + {:.4f} reached at Epoch {}\n'.format(
            criterion, self.trlog[max_acc], self.trlog[max_acc_interval],
            self.trlog[max_acc_epoch])
        print(max_validation_str)

        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.test_loader, 1)):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]

                embeddings, logits = self.model(data, return_feature=True)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc

                if args.tst_free:
                    embeddings_dict = self.model.get_embeddings_dict(embeddings, all_labels)

                    # TST-free part
                    clustering_losses = tst_free.clustering_loss(embeddings_dict, args.sinkhorn_reg, 'wasserstein',
                                                                 sqrt_temperature=np.sqrt(args.temperature),
                                                                 normalize_by_dim=False,
                                                                 clustering_iterations=20, sinkhorn_iterations=20,
                                                                 sinkhorn_iterations_warmstart=4,
                                                                 sanity_check=False)

                    for key, val in clustering_losses.items():
                        metrics[key].append(val)

                if args.debug_fast:
                    print('Debug fast, breaking TEST after 1 mini-batch')
                    record = record[:1]
                    break

        assert i == record.shape[0]
        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        metric_summaries = {key: compute_confidence_interval(val) for key, val in metrics.items()}

        self.trlog[test_acc] = va
        self.trlog[test_acc_interval] = vap
        self.trlog[test_loss] = vl

        summary_lines = []
        summary_lines.append(max_validation_str)
        summary_lines.append('test_SupervisedAcc {:.4f} + {:.4f} (ep{})'.format(
                self.trlog[test_acc],
                self.trlog[test_acc_interval],
                self.trlog[max_acc_epoch]))
        for key, (mean, std) in metric_summaries.items():
            summary_lines.append('test_{} {:.4f} + {:.4f} (ep{})'.format(key, mean, std, self.trlog[max_acc_epoch]))

        self.trlog['TST'] = metric_summaries

        summary_lines_str = '\n'.join(summary_lines)
        print('\n{}'.format(summary_lines_str))

        with open(osp.join(self.args.save_path, 'summary_max_{}.txt'.format(criterion)), 'w') as f:
            f.write(summary_lines_str)
Example #8
    def evaluate(self, data_loader):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()

        accuracies = []
        losses = []

        metrics = OrderedDict()
        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        all_labels = torch.arange(args.eval_way, device=label.device).repeat(args.eval_shot + args.eval_query)
        with torch.no_grad():
            for i, batch in enumerate(data_loader, 1):
                if torch.cuda.is_available():
                    data, _ = [_.cuda() for _ in batch]
                else:
                    data = batch[0]

                embeddings, logits = self.model(data, return_feature=True)

                if args.tst_free:
                    embeddings_logits_dict = self.model.get_embeddings_dict(embeddings, all_labels, logits)

                    for sinkhorn_reg_str in args.sinkhorn_reg:  # loop over all possible regularizations
                        sinkhorn_reg_float = float(sinkhorn_reg_str)

                        transductive_losses = tst_free.transductive_from_logits(
                            embeddings_logits_dict, regularization=sinkhorn_reg_float)

                        for key, val in transductive_losses.items():
                            key += '_reg{}'.format(sinkhorn_reg_str)
                            metrics.setdefault(key, [])
                            metrics[key].append(val)

                # data contains both support and query sets (typically 25+75 for 5-shot 5-way 15-query)
                loss = F.cross_entropy(logits, label)
                acc = count_acc(logits, label)
                losses.append(loss.item())
                accuracies.append(acc)

                if args.tst_free:
                    # also run the clustering-based TST-free evaluation on the raw embeddings
                    embeddings_dict = self.model.get_embeddings_dict(embeddings, all_labels)
                    for sinkhorn_reg_str in args.sinkhorn_reg:  # loop over all possible regularizations
                        sinkhorn_reg_float = float(sinkhorn_reg_str)
                        clustering_losses = tst_free.clustering_loss(embeddings_dict, sinkhorn_reg_float, 'wasserstein',
                                                                     sqrt_temperature=np.sqrt(args.temperature),
                                                                     normalize_by_dim=False,
                                                                     clustering_iterations=20, sinkhorn_iterations=20,
                                                                     sinkhorn_iterations_warmstart=4,
                                                                     sanity_check=False)

                        for key, val in clustering_losses.items():
                            key += '_reg{}'.format(sinkhorn_reg_str)
                            metrics.setdefault(key, [])
                            metrics[key].append(val)

                if args.debug_fast:
                    print('Debug fast, breaking eval after 1 mini-batch')
                    break

        assert i == len(losses) == len(accuracies)
        vl, _ = compute_confidence_interval(losses)
        va, vap = compute_confidence_interval(accuracies)
        metric_summaries = {key: compute_confidence_interval(val) for key, val in metrics.items()}

        # train mode
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        if args.tst_free:
            return vl, va, vap, metric_summaries
        else:
            return vl, va, vap
Example #9
    def evaluate_test(self):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.load_state_dict(
            torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
        self.model.eval()

        if args.test_mode == 'FSL':
            record = np.zeros((10000, 2))  # loss and acc
            label = torch.arange(args.eval_way).repeat(args.eval_query).type(
                torch.LongTensor)
            if torch.cuda.is_available():
                label = label.cuda()
            with torch.no_grad():
                for i, batch in enumerate(self.test_fsl_loader, 1):
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data = batch[0]

                    p = args.eval_shot * args.eval_way
                    data_shot, data_query = data[:p], data[p:]
                    logits = self.model.forward_fsl(data_shot, data_query)
                    loss = F.cross_entropy(logits, label)
                    acc = count_acc(logits, label)
                    record[i - 1, 0] = loss.item()
                    record[i - 1, 1] = acc
            assert i == record.shape[0]
            vl, _ = compute_confidence_interval(record[:, 0])
            va, vap = compute_confidence_interval(record[:, 1])

            self.trlog['test_acc'] = va
            self.trlog['test_acc_interval'] = vap
            self.trlog['test_loss'] = vl

            print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
            print('Test acc={:.4f} + {:.4f}\n'.format(
                self.trlog['test_acc'], self.trlog['test_acc_interval']))

        else:
            record = np.zeros((10000, 5))  # loss, acc, HM, seen acc, unseen acc
            label_unseen_query = torch.arange(
                min(args.eval_way,
                    self.valset.num_class)).repeat(args.eval_query).long()
            if torch.cuda.is_available():
                label_unseen_query = label_unseen_query.cuda()
            with torch.no_grad():
                for i, batch in tqdm(
                        enumerate(
                            zip(self.test_gfsl_loader, self.test_fsl_loader),
                            1)):
                    if torch.cuda.is_available():
                        data_seen, seen_label = batch[0][0].cuda(), batch[0][1].cuda()
                        data_unseen, unseen_label = batch[1][0].cuda(), batch[1][1].cuda()
                    else:
                        data_seen, seen_label = batch[0][0], batch[0][1]
                        data_unseen, unseen_label = batch[1][0], batch[1][1]
                    p2 = args.eval_shot * args.eval_way
                    data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
                    label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
                    whole_query = torch.cat([data_seen, data_unseen_query], 0)
                    whole_label = torch.cat([
                        seen_label,
                        label_unseen_query + self.traintestset.num_class
                    ])
                    logits_s, logits_u = self.model.forward_generalized(
                        data_unseen_shot, whole_query)
                    # compute unbiased accuracy over the joint label space
                    new_logits = torch.cat([logits_s, logits_u], 1)
                    record[i - 1, 0] = F.cross_entropy(new_logits,
                                                       whole_label).item()
                    record[i - 1, 1] = count_acc(new_logits, whole_label)
                    # compute harmonic mean of seen/unseen accuracy
                    HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                        new_logits, whole_label, seen_label.shape[0])
                    record[i - 1, 2:] = np.array([HM_nobias, SA_nobias, UA_nobias])
                    del logits_s, logits_u, new_logits
                    torch.cuda.empty_cache()

            m_list = []
            p_list = []
            for i in range(5):
                m1, p1 = compute_confidence_interval(record[:, i])
                m_list.append(m1)
                p_list.append(p1)

            self.trlog['test_loss'] = m_list[0]
            self.trlog['test_acc'] = m_list[1]
            self.trlog['test_acc_interval'] = p_list[1]
            self.trlog['test_HM_acc'] = m_list[2]
            self.trlog['test_HM_acc_interval'] = p_list[2]
            self.trlog['test_HMSeen_acc'] = m_list[3]
            self.trlog['test_HMSeen_acc_interval'] = p_list[3]
            self.trlog['test_HMUnseen_acc'] = m_list[4]
            self.trlog['test_HMUnseen_acc_interval'] = p_list[4]

            print('best epoch {}, best val acc={:.4f} + {:.4f}\n'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'],
                self.trlog['max_acc_interval']))
            print('Test HM acc={:.4f} + {:.4f}\n'.format(
                self.trlog['test_HM_acc'], self.trlog['test_HM_acc_interval']))
            print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[1], p_list[1]))
            print('GFSL {}-way HM  w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[2], p_list[2]))
            print('GFSL {}-way HMSeen  w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[3], p_list[3]))
            print('GFSL {}-way HMUnseen  w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, m_list[4], p_list[4]))
Example #10
    def evaluate_gfsl(self):
        args = self.args
        label_unseen_query = torch.arange(args.eval_way).repeat(
            args.eval_query).long()
        if torch.cuda.is_available():
            label_unseen_query = label_unseen_query.cuda()

        generalized_few_shot_acc = np.zeros((args.num_eval_episodes, 2))
        generalized_few_shot_delta = np.zeros((args.num_eval_episodes, 4))
        generalized_few_shot_hmeanacc = np.zeros((args.num_eval_episodes, 6))
        generalized_few_shot_hmeanmap = np.zeros((args.num_eval_episodes, 6))
        generalized_few_shot_ausuc = np.zeros((args.num_eval_episodes, 1))
        AUC_record = []

        for i, batch in tqdm(
                enumerate(zip(self.test_gfsl_loader, self.test_fsl_loader),
                          1)):
            if torch.cuda.is_available():
                data_seen, seen_label = batch[0][0].cuda(), batch[0][1].cuda()
                data_unseen, unseen_label = batch[1][0].cuda(), batch[1][1].cuda()
            else:
                data_seen, seen_label = batch[0][0], batch[0][1]
                data_unseen, unseen_label = batch[1][0], batch[1][1]

            p2 = args.eval_shot * args.eval_way
            data_unseen_shot, data_unseen_query = data_unseen[:p2], data_unseen[p2:]
            label_unseen_shot, _ = unseen_label[:p2], unseen_label[p2:]
            whole_query = torch.cat([data_seen, data_unseen_query], 0)
            whole_label = torch.cat(
                [seen_label, label_unseen_query + self.trainset.num_class])
            if args.model_class in ['CLS', 'Castle', 'ACastle']:
                with torch.no_grad():
                    logits_s, logits_u = self.model.forward_generalized(
                        data_unseen_shot, whole_query)
            elif args.model_class in ['ProtoNet']:
                with torch.no_grad():
                    logits_s, logits_u = self.model.forward_generalized(
                        data_unseen_shot, whole_query, self.model.seen_proto)
            else:
                raise ValueError(
                    'unsupported model class: {}'.format(args.model_class))
            # compute unbiased accuracy
            new_logits = torch.cat([logits_s, logits_u], 1)
            if 'acc' in self.criteria or 'hmeanacc' in self.criteria or 'delta' in self.criteria:
                new_logits_acc_biased = torch.cat(
                    [logits_s - self.best_bias_acc, logits_u], 1)
            if 'hmeanmap' in self.criteria:
                new_logits_map_biased = torch.cat(
                    [logits_s - self.best_bias_map, logits_u], 1)
            # Criterion: Acc
            if 'acc' in self.criteria:
                generalized_few_shot_acc[i - 1, 0] = count_acc(
                    new_logits, whole_label)
                # compute biased accuracy
                generalized_few_shot_acc[i - 1, 1] = count_acc(
                    new_logits_acc_biased, whole_label)

            if 'delta' in self.criteria:
                # compute delta value for unbiased logits
                unbiased_delta1, unbiased_delta2 = count_delta_value(
                    new_logits, whole_label, seen_label.shape[0],
                    self.trainset.num_class)
                # compute delta value for bias-calibrated logits
                biased_delta1, biased_delta2 = count_delta_value(
                    new_logits_acc_biased, whole_label, seen_label.shape[0],
                    self.trainset.num_class)
                generalized_few_shot_delta[i - 1, :] = np.array([
                    unbiased_delta1, unbiased_delta2, biased_delta1,
                    biased_delta2
                ])

            if 'hmeanacc' in self.criteria:
                # compute harmonic mean
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_low_shot_joint(
                    new_logits, whole_label, seen_label.shape[0])
                HM, SA, UA = count_acc_harmonic_low_shot_joint(
                    new_logits_acc_biased, whole_label, seen_label.shape[0])
                generalized_few_shot_hmeanacc[i - 1, :] = np.array(
                    [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])

            if 'hmeanmap' in self.criteria:
                # compute harmonic mean
                HM_nobias, SA_nobias, UA_nobias = count_acc_harmonic_MAP(
                    new_logits, whole_label, seen_label.shape[0], 'macro')
                HM, SA, UA = count_acc_harmonic_MAP(new_logits_map_biased,
                                                    whole_label,
                                                    seen_label.shape[0],
                                                    'macro')
                generalized_few_shot_hmeanmap[i - 1, :] = np.array(
                    [HM_nobias, SA_nobias, UA_nobias, HM, SA, UA])

            if 'ausuc' in self.criteria:
                # compute AUSUC (area under the seen-unseen accuracy curve)
                ausuc, temp_auc_record = Compute_AUSUC(
                    logits_s.detach().cpu().numpy(),
                    logits_u.detach().cpu().numpy(),
                    whole_label.cpu().numpy(),
                    np.arange(self.trainset.num_class),
                    self.trainset.num_class + np.arange(args.eval_way))
                generalized_few_shot_ausuc[i - 1, 0] = ausuc
                AUC_record.append(temp_auc_record)

            del logits_s, logits_u, new_logits
            torch.cuda.empty_cache()

        self.AUC_record = AUC_record
        print('-'.join([args.model_class, args.model_path]))
        if 'acc' in self.criteria:
            self.trlog['acc_mean'], self.trlog['acc_interval'] = \
                compute_confidence_interval(generalized_few_shot_acc[:, 0])
            self.trlog['acc_biased_mean'], self.trlog['acc_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_acc[:, 1])
            print('GFSL {}-way Acc w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['acc_mean'],
                self.trlog['acc_interval']))
            print('GFSL {}-way Acc w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['acc_biased_mean'],
                self.trlog['acc_biased_interval']))

        if 'delta' in self.criteria:
            self.trlog['delta1_mean'], self.trlog['delta1_interval'] = \
                compute_confidence_interval(generalized_few_shot_delta[:, 0])
            self.trlog['delta2_mean'], self.trlog['delta2_interval'] = \
                compute_confidence_interval(generalized_few_shot_delta[:, 1])
            self.trlog['delta1_biased_mean'], self.trlog['delta1_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_delta[:, 2])
            self.trlog['delta2_biased_mean'], self.trlog['delta2_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_delta[:, 3])
            print('GFSL {}-way Delta1 w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta1_mean'],
                self.trlog['delta1_interval']))
            print('GFSL {}-way Delta1 w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta1_biased_mean'],
                self.trlog['delta1_biased_interval']))
            print('GFSL {}-way Delta2 w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta2_mean'],
                self.trlog['delta2_interval']))
            print('GFSL {}-way Delta2 w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['delta2_biased_mean'],
                self.trlog['delta2_biased_interval']))

        if 'hmeanacc' in self.criteria:
            self.trlog['HM_mean'], self.trlog['HM_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanacc[:, 0])
            self.trlog['S2All_mean'], self.trlog['S2All_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanacc[:, 1])
            self.trlog['U2All_mean'], self.trlog['U2All_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanacc[:, 2])
            self.trlog['HM_biased_mean'], self.trlog['HM_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanacc[:, 3])
            self.trlog['S2All_biased_mean'], self.trlog['S2All_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanacc[:, 4])
            self.trlog['U2All_biased_mean'], self.trlog['U2All_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanacc[:, 5])
            print('GFSL {}-way HM_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_mean'],
                self.trlog['HM_interval']))
            print('GFSL {}-way HM_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_biased_mean'],
                self.trlog['HM_biased_interval']))
            print('GFSL {}-way S2All_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_mean'],
                self.trlog['S2All_interval']))
            print('GFSL {}-way S2All_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_biased_mean'],
                self.trlog['S2All_biased_interval']))
            print('GFSL {}-way U2All_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_mean'],
                self.trlog['U2All_interval']))
            print('GFSL {}-way U2All_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_biased_mean'],
                self.trlog['U2All_biased_interval']))

        if 'hmeanmap' in self.criteria:
            self.trlog['HM_map_mean'], self.trlog['HM_map_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanmap[:, 0])
            self.trlog['S2All_map_mean'], self.trlog['S2All_map_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanmap[:, 1])
            self.trlog['U2All_map_mean'], self.trlog['U2All_map_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanmap[:, 2])
            self.trlog['HM_map_biased_mean'], self.trlog['HM_map_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanmap[:, 3])
            self.trlog['S2All_map_biased_mean'], self.trlog['S2All_map_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanmap[:, 4])
            self.trlog['U2All_map_biased_mean'], self.trlog['U2All_map_biased_interval'] = \
                compute_confidence_interval(generalized_few_shot_hmeanmap[:, 5])
            print('GFSL {}-way HM_map_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_map_mean'],
                self.trlog['HM_map_interval']))
            print('GFSL {}-way HM_map_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['HM_map_biased_mean'],
                self.trlog['HM_map_biased_interval']))
            print('GFSL {}-way S2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_map_mean'],
                self.trlog['S2All_map_interval']))
            print('GFSL {}-way S2All_map_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['S2All_map_biased_mean'],
                self.trlog['S2All_map_biased_interval']))
            print('GFSL {}-way U2All_map_mean w/o Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_map_mean'],
                self.trlog['U2All_map_interval']))
            print('GFSL {}-way U2All_map_mean w/  Bias {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['U2All_map_biased_mean'],
                self.trlog['U2All_map_biased_interval']))

        if 'ausuc' in self.criteria:
            self.trlog['AUSUC_mean'], self.trlog['AUSUC_interval'] = \
                compute_confidence_interval(generalized_few_shot_ausuc[:, 0])
            print('GFSL {}-way AUSUC {:.5f} + {:.5f}'.format(
                args.eval_way, self.trlog['AUSUC_mean'],
                self.trlog['AUSUC_interval']))
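
The "w/ Bias" variants above subtract a scalar calibration value from the seen-class logits before the joint prediction, countering the tendency of generalized few-shot models to over-score seen classes. A minimal sketch of the mechanics with toy shapes and a made-up bias value (best_bias_acc and best_bias_map would normally be tuned on validation data):

import torch

logits_s = torch.randn(8, 64)  # seen-class logits for 8 queries (toy shapes)
logits_u = torch.randn(8, 5)   # unseen-class logits for the same queries
bias = 0.5                     # calibration scalar; illustrative value only
joint_unbiased = torch.cat([logits_s, logits_u], dim=1)
joint_biased = torch.cat([logits_s - bias, logits_u], dim=1)
# argmax over joint_biased shifts predictions toward the unseen classes
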
Example #11
                    if torch.cuda.is_available():
                        data, _ = [_.cuda() for _ in batch]
                    else:
                        data = batch[0]
                    data_emb = model(data, is_emb=True)

                    if args.centralize:
                        data_emb = data_emb - class_mean
                    if args.normalize:
                        data_emb = F.normalize(data_emb, dim=1, p=2)

                    split_index = args.way * args.shot
                    data_shot, data_query = data_emb[:split_index], data_emb[split_index:]

                    SVM = LinearSVC(C=best_c,
                                    multi_class='crammer_singer',
                                    dual=False,
                                    max_iter=5000).fit(data_shot.cpu().numpy(),
                                                       shot_label)
                    prediction = SVM.predict(data_query.cpu().numpy())
                    acc = np.mean(prediction == query_label)
                    test_acc_record[i - 1] = acc
                    # print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

            m, pm = compute_confidence_interval(test_acc_record)
            ensemble_result.append('{:.4f} + {:.4f}'.format(m, pm))
            print('{} way {} shot, Test acc={:.4f} + {:.4f}, best C: {}'.format(
                args.way, args.shot, m, pm, best_c))
    print('ensemble result: {}'.format(','.join(ensemble_result)))
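
Example #11 is a fragment of a larger loop: names such as model, class_mean, best_c, shot_label, and query_label come from the surrounding code. A self-contained sketch of the same SVM-on-embeddings evaluation, with random toy embeddings standing in for the frozen encoder's output:

import numpy as np
from sklearn.svm import LinearSVC

way, shot, query, dim = 5, 5, 15, 64
rng = np.random.RandomState(0)
shot_emb = rng.randn(way * shot, dim)    # support embeddings
query_emb = rng.randn(way * query, dim)  # query embeddings
shot_label = np.repeat(np.arange(way), shot)
query_label = np.repeat(np.arange(way), query)
svm = LinearSVC(C=1.0, multi_class='crammer_singer',
                max_iter=5000).fit(shot_emb, shot_label)
acc = np.mean(svm.predict(query_emb) == query_label)
print('toy accuracy: {:.4f}'.format(acc))  # ~chance (0.2) on random embeddings
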
Example #12
    def open_evaluate(self, data_loader):
        # restore model args
        args = self.args
        # evaluation mode
        self.model.eval()
        record = np.zeros((args.num_test_episodes, 4))  # loss, acc, SnaTCHer AUROC, distance AUROC

        label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
        label = label.type(torch.LongTensor)
        if torch.cuda.is_available():
            label = label.cuda()
        print('Evaluating ... best epoch {}, SnaTCHer={:.4f} + {:.4f}  acc={:.4f} + {:.4f}'.format(
            self.trlog['max_auc_epoch'],
            self.trlog['max_auc'],
            self.trlog['max_auc_interval'],
            self.trlog['acc'],
            self.trlog['acc_interval']))

        with torch.no_grad():
            for i, batch in enumerate(tqdm(data_loader), 1):  # start at 1 so record[i - 1] is zero-based

                data, _ = [_.cuda() for _ in batch]

                logits = self.para_model(data)
                logits = logits.reshape([-1, args.eval_way + args.open_eval_way, args.way])
                klogits = logits[:, :args.eval_way, :].reshape(-1, args.way)
                ulogits = logits[:, args.eval_way:, :].reshape(-1, args.way)
                loss = F.cross_entropy(klogits, label)
                acc = count_acc(klogits, label)

                """ Distance """
                kdist = -(klogits.max(1)[0])
                udist = -(ulogits.max(1)[0])
                kdist = kdist.cpu().detach().numpy()
                udist = udist.cpu().detach().numpy()
                dist_auroc = calc_auroc(kdist, udist)

                """ Snatcher """
                with torch.no_grad():
                    instance_embs = self.para_model.instance_embs
                    support_idx = self.para_model.support_idx
                    query_idx = self.para_model.query_idx

                    support = instance_embs[support_idx.flatten()].view(*(support_idx.shape + (-1,)))
                    query = instance_embs[query_idx.flatten()].view(*(query_idx.shape + (-1,)))
                    emb_dim = support.shape[-1]

                    support = support[:, :, :args.way].contiguous()
                    # get mean of the support
                    bproto = support.mean(dim=1)  # Ntask x NK x d
                    proto = self.para_model.slf_attn(bproto, bproto, bproto)
                    kquery = query[:, :, :args.way].contiguous()
                    uquery = query[:, :, args.way:].contiguous()
                    snatch_known = []
                    for j in range(klogits.shape[0]):  # known queries (hardcoded as 75 in the original)
                        pproto = bproto.clone().detach()
                        """ Algorithm 1 Line 1 """
                        c = klogits.argmax(1)[j]
                        """ Algorithm 1 Line 2 """
                        pproto[0][c] = kquery.reshape(-1, emb_dim)[j]
                        """ Algorithm 1 Line 3 """
                        pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                        pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                        """ pdiff: d_SnaTCHer in Algorithm 1 """
                        snatch_known.append(pdiff)

                    snatch_unknown = []
                    for j in range(ulogits.shape[0]):
                        pproto = bproto.clone().detach()
                        """ Algorithm 1 Line 1 """
                        c = ulogits.argmax(1)[j]
                        """ Algorithm 1 Line 2 """
                        pproto[0][c] = uquery.reshape(-1, emb_dim)[j]
                        """ Algorithm 1 Line 3 """
                        pproto = self.para_model.slf_attn(pproto, pproto, pproto)[0]
                        pdiff = (pproto - proto).pow(2).sum(-1).sum() / 64.0
                        """ pdiff: d_SnaTCHer in Algorithm 1 """
                        snatch_unknown.append(pdiff)

                    pkdiff = torch.stack(snatch_known)
                    pudiff = torch.stack(snatch_unknown)
                    pkdiff = pkdiff.cpu().detach().numpy()
                    pudiff = pudiff.cpu().detach().numpy()

                    snatch_auroc = calc_auroc(pkdiff, pudiff)
                record[i - 1, 0] = loss.item()
                record[i - 1, 1] = acc
                record[i - 1, 2] = snatch_auroc
                record[i - 1, 3] = dist_auroc

        vl, _ = compute_confidence_interval(record[:, 0])
        va, vap = compute_confidence_interval(record[:, 1])
        auc_sna, auc_sna_p = compute_confidence_interval(record[:, 2])
        auc_dist, auc_dist_p = compute_confidence_interval(record[:, 3])
        print("acc: {:.4f} + {:.4f} Dist: {:.4f} + {:.4f} SnaTCHer: {:.4f} + {:.4f}" \
              .format(va, vap, auc_dist, auc_dist_p, auc_sna, auc_sna_p))
        # train mode
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        return vl, va, vap, auc_sna, auc_sna_p
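
calc_auroc is also undefined in these examples. Since it receives "distance" scores for known and unknown queries, a plausible version computes the AUROC of separating unknown from known samples; this is an assumption based on the call sites, not the original implementation.

import numpy as np
from sklearn.metrics import roc_auc_score

def calc_auroc(known_scores, unknown_scores):
    # a higher score should indicate "more likely unknown"
    scores = np.concatenate([known_scores, unknown_scores])
    labels = np.concatenate([np.zeros(len(known_scores)),
                             np.ones(len(unknown_scores))])
    return roc_auc_score(labels, scores)
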