コード例 #1
0
 def _uncertainty_calculate(self, data_loader):
     """Return the ``common_get_maxpos`` score for ``data_loader``.

     ``self`` (not ``self.model``) is passed to ``common_predict`` as the
     model, together with ``self.device`` and ``self.module_id``. Only the
     first element of the returned triple is used.
     """
     first_output, _, _ = common_predict(
         data_loader, self, self.device, module_id=self.module_id
     )
     return common_get_maxpos(first_output)
コード例 #2
0
 def extract_metric(self, data_loader, orig_pred_y):
     """Measure how often train-mode predictions agree with ``orig_pred_y``.

     The model is switched to ``train()`` mode, ``common_predict`` is run
     ``self.iter_time`` times, and for each run the element-wise equality
     of the prediction with ``orig_pred_y`` is accumulated. The model is
     put back into ``eval()`` mode afterwards, even when a prediction pass
     raises.

     Args:
         data_loader: Loader handed unchanged to ``common_predict``.
         orig_pred_y: Reference predictions compared via ``pred.eq``.

     Returns:
         ``common_ten2numpy`` of the accumulated matches divided by
         ``self.iter_time``.
     """
     res = 0
     # presumably train mode keeps stochastic layers (e.g. dropout) active
     # so that repeated passes differ -- TODO confirm against the model.
     self.model.train()
     try:
         for _ in range(self.iter_time):
             _, pred, _ = common_predict(data_loader, self.model, self.device)
             res = res + pred.eq(orig_pred_y)
     finally:
         # Never leave the model in training mode if a pass fails.
         self.model.eval()
     res = common_ten2numpy(res.float() / self.iter_time)
     return res
コード例 #3
0
 def run(self, data_loader, iter_time, module_id):
     """Collect predictions from successfully mutated models.

     Models are requested from ``self.get_mutate_model()`` until more than
     ``iter_time`` predictions have been gathered; failed mutations are
     skipped. Each prediction is converted with ``common_ten2numpy``,
     reshaped to a single column, and the columns are joined with
     ``np.concatenate(..., axis=1)``.

     NOTE(review): the ``<=`` condition yields ``iter_time + 1`` columns;
     confirm that is intended. The loop never ends if every mutation fails.
     """
     columns = []
     while len(columns) <= iter_time:
         print('this is the %d model' % (len(columns)))
         mutate_model, is_fail = self.get_mutate_model()
         if is_fail:
             # Failed mutation: try again without recording anything.
             continue
         _, pred_y, _ = common_predict(
             data_loader, mutate_model, self.device, module_id=module_id
         )
         column = common_ten2numpy(pred_y).reshape([-1, 1])
         columns.append(column)
     return np.concatenate(columns, axis=1)
コード例 #4
0
 def _uncertainty_calculate(self, data_loader):
     """Return one uncertainty array per weight in ``(0, 1, 2)``.

     Predictions from ``common_predict`` feed ``self.get_svscore``; for
     every weight the result of ``self.get_pvscore`` is detached, moved to
     the CPU, converted with ``common_ten2numpy`` and turned into an
     uncertainty as ``1 - score``.
     """
     print('Dissactor uncertainty evaluation ...')
     _, pred_y, _ = common_predict(data_loader, self.model, self.device)
     svscore_list, sub_num = self.get_svscore(data_loader, pred_y)

     uncertainties = []
     for weight in (0, 1, 2):
         pv_score = self.get_pvscore(svscore_list, sub_num, weight)
         pv_score = pv_score.detach().cpu()
         # The complement of the score is reported as the uncertainty.
         uncertainties.append(1 - common_ten2numpy(pv_score))
     return uncertainties
コード例 #5
0
 def get_submodel_prediction(self, data_loader):
     """Run every saved sub-model on its hidden states.

     Hidden states, sub-model identifiers and labels come from
     ``self.instance.get_hiddenstate``. For each identifier the saved
     model is loaded from ``self.get_submodel_path``, put in eval mode and
     evaluated through ``common_predict`` on a loader built with
     ``build_loader``. The accuracy against the labels is printed.

     Returns:
         A tuple ``(predictions, sub_num)`` where ``predictions`` holds the
         first ``common_predict`` output of each sub-model, in order.
     """
     sub_res_list, sub_num, y = self.instance.get_hiddenstate(
         data_loader, self.device)

     predictions = []
     for idx, num in enumerate(sub_num):
         # NOTE: torch.load unpickles the file; only load trusted paths.
         linear_model = torch.load(self.get_submodel_path(num),
                                   map_location=self.device)
         linear_model.eval()

         sub_loader = build_loader(sub_res_list[idx], y,
                                   self.test_batch_size)
         pred_pos, pred_y, _ = common_predict(
             sub_loader, linear_model, self.device, train_sub=True
         )
         predictions.append(pred_pos)

         correct = torch.sum(y.eq(pred_y), dtype=torch.float).item()
         print('test accuracy for', self.__class__.__name__, 'submodel ',
               num, 'is', correct / len(y))
     return predictions, sub_num
コード例 #6
0
    def _uncertainty_calculate(self, data_loader):
        """Return the ``common_get_maxpos`` score using this wrapper as model.

        ``self`` rather than ``self.model`` is passed to ``common_predict``.
        Example from the original notes: ``self`` printed as a
        ``ModelWithTemperature`` holding ``(model): Code2Vec`` and
        ``(softmax): Softmax(dim=1)``, while ``self.model`` was the bare
        ``Code2Vec`` with::

            (node_embedding): Embedding(125344, 100, padding_idx=1)
            (path_embedding): Embedding(557961, 100)
            (out): Linear(in_features=100, out_features=48557, bias=True)
            (drop): Dropout(p=0.5, inplace=False)
        """
        first_output, _, _ = common_predict(data_loader, self, self.device)
        return common_get_maxpos(first_output)
コード例 #7
0
    def train_sub_model(self, lr, epoch):
        """Train, evaluate and save one linear classifier per hidden state.

        Hidden states, sub-model identifiers and labels are obtained from
        ``self.instance.get_hiddenstate`` on ``self.train_loader``. For each
        hidden-state group an ``nn.Linear`` layer is trained with SGD and
        cross-entropy loss, its accuracy against ``self.train_y`` is
        printed, and the whole module is saved with ``torch.save`` at
        ``self.get_submodel_path(sub_num[i])``.

        Args:
            lr: Learning rate passed to ``optim.SGD``.
            epoch: Number of passes over each group's loader.
        """
        print("train sub models ...")
        sub_res_list, sub_num, label = self.instance.get_hiddenstate(
            self.train_loader, self.device)
        for i, sub_res in enumerate(sub_res_list):
            # Read the feature width from the first sample so a group with
            # a single sample works and the value matches the log below.
            feature_num = len(sub_res[0])
            linear = nn.Linear(feature_num, self.class_num).to(self.device)
            my_loss = nn.CrossEntropyLoss()
            optimizer = optim.SGD(linear.parameters(), lr=lr)
            data_loader = build_loader(sub_res, label, self.train_batch_size)
            linear.train()
            for _ in range(epoch):
                for x, y in data_loader:
                    x = x.to(self.device)
                    # CrossEntropyLoss is fed a flat target vector.
                    y = y.to(self.device).view([-1])
                    linear.zero_grad()
                    loss = my_loss(linear(x), y)
                    loss.backward()
                    optimizer.step()

            linear.eval()

            # NOTE(review): accuracy is computed against self.train_y, so
            # this assumes build_loader keeps the sample order -- confirm.
            _, pred_y, _ = common_predict(data_loader,
                                          linear,
                                          device=self.device,
                                          train_sub=True,
                                          module_id=self.module_id)
            acc = common_cal_accuracy(pred_y, self.train_y)
            print('feature number for sub-model is', feature_num,
                  'finish training the sub-model', sub_num[i], 'for ',
                  self.instance.__class__.__name__, 'accuracy is', acc)

            save_path = self.get_submodel_path(sub_num[i])
            torch.save(linear, save_path)
            print('save sub model in ', save_path)
コード例 #8
0
 def _uncertainty_calculate(self, data_loader):
     """Return ``common_get_entropy`` of the model's first predict output.

     ``common_predict`` is called with ``self.model`` and ``self.device``;
     the second and third elements of its result are ignored.
     """
     first_output, _, _ = common_predict(
         data_loader, self.model, self.device
     )
     entropy = common_get_entropy(first_output)
     return entropy
コード例 #9
0
 def _uncertainty_calculate(self, data_loader):
     """Return ``common_get_maxpos`` of the model's first predict output.

     ``common_predict`` is called with ``self.model``, ``self.device`` and
     ``module_id=self.module_id``; only the first element of its result is
     used.
     """
     first_output, _, _ = common_predict(
         data_loader, self.model, self.device, module_id=self.module_id
     )
     max_pos = common_get_maxpos(first_output)
     return max_pos
コード例 #10
0
 def _uncertainty_calculate(self, data_loader):
     """Return the unmodified result of ``common_predict``.

     The call uses ``self.model``, ``self.device`` and
     ``module_id=self.module_id``; no post-processing is applied.
     """
     prediction = common_predict(
         data_loader, self.model, self.device, module_id=self.module_id
     )
     return prediction