Пример #1
0
 def extract_metric(self, data_loader, orig_pred_y):
     """Return how often repeated train-mode predictions match ``orig_pred_y``.

     The model is switched to training mode, ``common_predict`` is run
     ``self.iter_time`` times, and each run's predictions are compared
     element-wise (``Tensor.eq``) with ``orig_pred_y``. The matches are
     summed and divided by the number of runs.

     Args:
         data_loader: Loader forwarded unchanged to ``common_predict``.
         orig_pred_y: Reference predictions to compare against; must be
             broadcast-compatible with the predictions returned by
             ``common_predict``.

     Returns:
         ``common_ten2numpy`` of the per-element agreement ratio
         (match count / ``self.iter_time``).

     Note:
         ``self.iter_time`` must be at least 1: with zero iterations the
         accumulator stays the plain int ``0`` and ``.float()`` fails.
     """
     res = 0
     # NOTE(review): train mode is presumably used to keep stochastic layers
     # (e.g. dropout) active during the repeated passes -- confirm.
     self.model.train()
     try:
         for _ in range(self.iter_time):
             _, pred, _ = common_predict(data_loader, self.model, self.device)
             # Boolean matches accumulate into an integer count per element.
             res = res + pred.eq(orig_pred_y)
     finally:
         # Restore eval mode even when a prediction pass raises, so a failure
         # here cannot leave the shared model in training mode.
         self.model.eval()
     res = common_ten2numpy(res.float() / self.iter_time)
     return res
Пример #2
0
 def run(self, data_loader, iter_time, module_id):
     """Gather predictions from successfully mutated models.

     Mutants are requested from ``self.get_mutate_model()`` until the number
     of collected prediction columns exceeds ``iter_time``. Failed mutation
     attempts are skipped and do not count towards that number.

     Args:
         data_loader: Loader forwarded unchanged to ``common_predict``.
         iter_time: Loop bound; because the loop runs while
             ``len(columns) <= iter_time``, ``iter_time + 1`` columns are
             collected.
         module_id: Forwarded to ``common_predict`` as ``module_id``.

     Returns:
         A 2-D numpy array built by concatenating, along axis 1, one
         ``(-1, 1)``-shaped prediction column per successful mutant.

     Note:
         NOTE(review): the ``<=`` bound looks like an off-by-one but is kept
         because callers may rely on the resulting shape. The loop never
         terminates if ``get_mutate_model`` keeps reporting failure.
     """
     columns = []
     while len(columns) <= iter_time:
         # The printed index is the number of columns gathered so far, so it
         # repeats after a failed mutation attempt.
         print('this is the %d model' % (len(columns)))
         mutant, failed = self.get_mutate_model()
         if failed:
             continue
         _, labels, _ = common_predict(
             data_loader, mutant, self.device, module_id=module_id
         )
         column = common_ten2numpy(labels).reshape([-1, 1])
         columns.append(column)
     return np.concatenate(columns, axis=1)
Пример #3
0
 def extract_metric(self, data_loader):
     """Score each sample by its largest negated per-class quadratic form.

     For every class index ``target`` in ``range(self.class_num)`` the
     row-wise quadratic form ``d @ self.std_value @ d.T`` is evaluated, where
     ``d`` is a row of ``fx - self.u_list[target]``. The forms are negated
     and the maximum over classes is returned for each row.

     NOTE(review): this looks like a Mahalanobis-style confidence score with
     ``u_list`` holding class means and ``std_value`` a precision matrix --
     confirm against where those attributes are fitted.

     Args:
         data_loader: Loader forwarded unchanged to ``self.get_penultimate``.

     Returns:
         ``common_ten2numpy`` of a 1-D tensor holding, per row of ``fx``,
         the maximum over classes of the negated quadratic form.
     """
     fx, _ = self.get_penultimate(data_loader)
     score = []
     for target in range(self.class_num):
         diff = fx - self.u_list[target]
         # Row-wise quadratic form. This equals the diagonal of
         # diff @ std_value @ diff.T but never materialises the N x N
         # product, so memory stays O(N * D) instead of O(N^2).
         tmp = (diff.mm(self.std_value) * diff).sum(dim=1).reshape([-1, 1])
         score.append(-tmp)
     score = torch.cat(score, dim=1)
     # torch.max over dim=1 returns (values, indices); only values are kept.
     score = common_ten2numpy(torch.max(score, dim=1)[0])
     return score
Пример #4
0
 def _uncertainty_calculate(self, data_loader):
     """Return one minus the label change rate under repeated train-mode runs.

     A reference prediction is taken with the model in eval mode. The model
     is then put in training mode and ``self._predict_result`` is run
     ``self.iter_time`` times; every run contributes one prediction column.
     ``self.label_chgrate`` compares the reference with those columns.

     Args:
         data_loader: Loader forwarded unchanged to ``self._predict_result``.

     Returns:
         ``1 - self.label_chgrate(orig_pred, mc_result)``, where
         ``mc_result`` is the axis-1 concatenation of the per-run columns.

     Raises:
         ValueError: from ``np.concatenate`` when ``self.iter_time`` is 0
             and no columns were collected.
     """
     self.model.eval()
     _, orig_pred, _ = self._predict_result(data_loader, self.model)
     mc_result = []
     print('calculating uncertainty ...')
     # NOTE(review): train mode is presumably used to keep stochastic layers
     # (e.g. dropout) active while sampling -- confirm.
     self.model.train()
     try:
         for _ in tqdm(range(self.iter_time)):
             _, res, _ = self._predict_result(data_loader, self.model)
             mc_result.append(common_ten2numpy(res).reshape([-1, 1]))
     finally:
         # Put the model back in the eval mode this method established at
         # entry; previously it was left in training mode after the loop.
         self.model.eval()
     mc_result = np.concatenate(mc_result, axis=1)
     score = self.label_chgrate(orig_pred, mc_result)
     return 1 - score
Пример #5
0
 def _uncertainty_calculate(self, data_loader):
     """Compute uncertainty scores for the weights 0, 1 and 2.

     The model's predictions are fed to ``self.get_svscore``; its two return
     values are then passed, together with each weight, to
     ``self.get_pvscore``. Each resulting score is detached, moved to the
     CPU, converted with ``common_ten2numpy`` and turned into an
     uncertainty as ``1 - score``.

     Args:
         data_loader: Loader forwarded to ``common_predict`` and
             ``self.get_svscore``.

     Returns:
         A list with three entries, one per weight in the order 0, 1, 2.
     """
     print('Dissactor uncertainty evaluation ...')
     _, predicted, _ = common_predict(data_loader, self.model, self.device)
     sv_scores, sub_count = self.get_svscore(data_loader, predicted)
     # The uncertainty is the complement of the score returned by get_pvscore.
     return [
         1 - common_ten2numpy(
             self.get_pvscore(sv_scores, sub_count, w).detach().cpu()
         )
         for w in (0, 1, 2)
     ]
    def __init__(self, instance: BasicModule, device):
        """Copy model, data and predictions from ``instance`` and build oracles.

        For every data split the labels, prediction outputs and loader are
        copied from ``instance``, the split size is stored as ``<split>_num``
        and an "oracle" array is stored as ``<split>_oracle``: an ``int32``
        vector holding 1 where the flattened prediction equals the flattened
        label and 0 elsewhere.

        Splits handled:
            * ``train`` and ``val`` -- always.
            * ``ood`` -- only when ``instance.ood_path`` is not ``None``.
            * ``test`` -- when ``instance.test_path`` is not ``None``;
              otherwise the three numbered splits ``test1``, ``test2`` and
              ``test3`` are copied instead.

        Args:
            instance: Source object providing the model, paths, loaders,
                labels and predictions read below.
            device: Device the model is moved to via ``.to(device)``; also
                stored as ``self.device``.
        """
        super(BasicUncertainty, self).__init__()

        def correctness(pred_y, true_y):
            # Element-wise "prediction equals label" flag as an int32 vector.
            flat_pred = common_ten2numpy(pred_y).reshape([-1])
            flat_true = common_ten2numpy(true_y).reshape([-1])
            return np.int32(flat_pred == flat_true)

        self.instance = instance
        self.device = device
        self.train_batch_size = instance.train_batch_size
        self.test_batch_size = instance.test_batch_size
        self.model = instance.model.to(device)
        self.class_num = instance.class_num
        self.save_dir = instance.save_dir
        self.module_id = instance.module_id
        self.softmax = nn.Softmax(dim=1)
        self.test_path = instance.test_path

        # Training split.
        self.train_y = instance.train_y
        self.train_pred_pos = instance.train_pred_pos
        self.train_pred_y = instance.train_pred_y
        self.train_loader = instance.train_loader
        self.train_num = len(self.train_y)
        self.train_oracle = correctness(self.train_pred_y, self.train_y)

        # Validation split.
        self.val_y = instance.val_y
        self.val_pred_pos = instance.val_pred_pos
        self.val_pred_y = instance.val_pred_y
        self.val_loader = instance.val_loader
        self.val_num = len(self.val_y)
        self.val_oracle = correctness(self.val_pred_y, self.val_y)

        # Out-of-distribution split, present only when a path was configured.
        if instance.ood_path is not None:
            self.ood_y = instance.ood_y
            self.ood_pred_pos = instance.ood_pred_pos
            self.ood_pred_y = instance.ood_pred_y
            self.ood_loader = instance.ood_loader
            self.ood_num = len(self.ood_y)
            self.ood_oracle = correctness(self.ood_pred_y, self.ood_y)

        if self.test_path is None:
            # No single test path: use the three numbered test splits.
            self.test_y1 = instance.test_y1
            self.test_y2 = instance.test_y2
            self.test_y3 = instance.test_y3

            self.test_pred_pos1 = instance.test_pred_pos1
            self.test_pred_y1 = instance.test_pred_y1
            self.test_loader1 = instance.test_loader1
            self.test_num1 = len(self.test_y1)
            self.test_oracle1 = correctness(self.test_pred_y1, self.test_y1)

            self.test_pred_pos2 = instance.test_pred_pos2
            self.test_pred_y2 = instance.test_pred_y2
            self.test_loader2 = instance.test_loader2
            self.test_num2 = len(self.test_y2)
            self.test_oracle2 = correctness(self.test_pred_y2, self.test_y2)

            self.test_pred_pos3 = instance.test_pred_pos3
            self.test_pred_y3 = instance.test_pred_y3
            self.test_loader3 = instance.test_loader3
            self.test_num3 = len(self.test_y3)
            self.test_oracle3 = correctness(self.test_pred_y3, self.test_y3)
        else:
            # Single test split.
            self.test_y = instance.test_y
            self.test_pred_pos = instance.test_pred_pos
            self.test_pred_y = instance.test_pred_y
            self.test_loader = instance.test_loader
            self.test_num = len(self.test_y)
            self.test_oracle = correctness(self.test_pred_y, self.test_y)