Example #1
0
    def test_eval_logic(self, dataloader, param: FixMatchParams):
        """Evaluate on ``dataloader``, dump per-sample argmax predictions, return top-k counts.

        Side effects:
            * ``param.topk`` is reassigned to ``param.default([1, 5])``.
            * The argmax prediction of every sample is stored at its id and the
              resulting array is written with ``np.save`` to
              ``noisy_<top1 count>.npy`` (relative path, i.e. the working dir).
            * ``self.logger.info()`` is called with no arguments.

        Args:
            dataloader: iterable yielding ``(ids, xs, labels)`` batches.
            param: run parameters; its ``topk`` attribute is (re)assigned here.

        Returns:
            The ``Meter`` with a ``"total"`` entry and one ``"top{k}"`` entry
            per k in ``param.topk``, each accumulated with ``+=``.
        """
        from thexp.calculate import accuracy as acc
        import numpy as np

        # NOTE(review): presumably keeps a user-provided topk and falls back
        # to [1, 5] otherwise — confirm against thexp's Params.default.
        param.topk = param.default([1, 5])

        with torch.no_grad():
            # NOTE(review): 50000 is a hard-coded capacity (presumably the
            # training-set size). Ids >= 50000 would fail to index, and slots
            # never written stay 0, which is indistinguishable from class 0.
            pred_by_id = torch.zeros(50000,
                                     device=self.device,
                                     dtype=torch.long)
            counter = Meter()
            for ids, xs, labels in dataloader:
                scores = self.predict(xs)
                pred_by_id[ids] = scores.argmax(dim=1)
                total, hits_per_k = acc.classify(scores, labels, topk=param.topk)
                counter["total"] += total
                for k, hits in zip(param.topk, hits_per_k):
                    counter[f"top{k}"] += hits

        as_array = pred_by_id.detach().cpu().numpy()
        np.save(f"noisy_{counter['top1']}.npy", as_array)
        self.logger.info()
        return counter
Example #2
0
def test_classify():
    """Check that ``acc.classify`` counts top-k hits correctly.

    Every row scores the classes in the order 0 > 1 > 2 > 3 and the labels are
    0, 1, 2, 3. Exactly k rows are therefore correct at top-k for k = 1..4.
    """
    ranking_row = [5, 4, 3, 2]
    preds = torch.tensor([ranking_row] * 4, dtype=torch.float)
    labels = torch.tensor([0, 1, 2, 3])
    ks = (1, 2, 3, 4)

    total, res = acc.classify(preds, labels, topk=ks)

    assert total == 4
    # res is ordered like ks, so the entry for k sits at index k - 1.
    assert all(res[k - 1] == k for k in ks), str(res)
Example #3
0
 def test_eval_logic(self, dataloader, param: Params):
     """Accumulate top-k classification counts over ``dataloader``.

     Each ``(xs, labels)`` batch is moved to ``self.device`` before prediction.
     ``param.topk`` is read as-is; this variant does not default it.

     Returns:
         The ``Meter`` holding a ``"total"`` entry plus one entry per k in
         ``param.topk``, keyed by the raw k value itself (not ``"top{k}"``).
         Each entry is accumulated with ``+=``.
     """
     from thexp.calculate import accuracy as acc

     meter = Meter()
     with torch.no_grad():
         for batch_x, batch_y in dataloader:
             batch_x = batch_x.to(self.device)
             batch_y = batch_y.to(self.device)
             scores = self.predict(batch_x)
             n_seen, correct_at_k = acc.classify(scores, batch_y, topk=param.topk)
             meter["total"] += n_seen
             for k, correct in zip(param.topk, correct_at_k):
                 meter[k] += correct
     return meter
Example #4
0
File: acc.py — Project: sqiangcao99/thexp
 def test_eval_logic(self, dataloader, param: Params):
     """Accumulate top-k classification counts over ``dataloader``.

     ``param.topk`` is first reassigned to ``param.default([1, 5])``. The
     ``(xs, labels)`` batches are passed to ``self.predict`` without an
     explicit device move.

     Returns:
         The ``Meter`` holding a ``"total"`` entry and one ``"top{k}"`` entry
         per k in ``param.topk``, each accumulated with ``+=``.
     """
     from thexp.calculate import accuracy as acc

     # NOTE(review): presumably falls back to [1, 5] only when topk is unset;
     # confirm against thexp's Params.default.
     param.topk = param.default([1, 5])

     meter = Meter()
     with torch.no_grad():
         for inputs, targets in dataloader:
             outputs = self.predict(inputs)
             batch_total, per_k = acc.classify(outputs, targets, topk=param.topk)
             meter["total"] += batch_total
             for k, hits in zip(param.topk, per_k):
                 meter[f"top{k}"] += hits
     return meter