コード例 #1
0
    def model_eval(self, preds, gts):
        """Score ``preds`` against ``gts`` with the configured metric.

        The metric is selected by ``self.opt.EVAL_METRIC``: ``'mean_accu'``
        dispatches to ``mean_accuracy`` and ``'accuracy'`` to ``accuracy``.

        Args:
            preds: Predictions, passed through unchanged to the metric.
            gts: Ground-truth targets, passed through unchanged to the metric.

        Returns:
            Whatever the selected metric function returns.

        Raises:
            ValueError: If ``self.opt.EVAL_METRIC`` names an unsupported metric.
        """
        metric = self.opt.EVAL_METRIC
        if metric == 'mean_accu':
            return mean_accuracy(preds, gts)
        if metric == 'accuracy':
            return accuracy(preds, gts)
        # Explicit raise rather than `assert`: `python -O` strips asserts,
        # which would leave no result to return for an unknown metric.
        raise ValueError(
            "Currently don't support the evaluation metric you specified: %r"
            % (metric,))
コード例 #2
0
def test(args):
    """Evaluate a network on the test data returned by ``prepare_data()``.

    Builds ``model.danet`` (optionally initialised from the checkpoint at
    ``cfg.WEIGHTS``), runs it over every batch of the dataloader, saves the
    argmax predictions to ``cfg.SAVE_DIR`` via ``save_preds`` and, when the
    batches carry a ``'Label'`` entry, prints the metric selected by
    ``cfg.EVAL_METRIC``.

    Args:
        args: Parsed command-line arguments; only ``args.adapted_model`` is
            read. When truthy the network is built with ``num_domains_bn=2``
            (presumably one BN set per domain), otherwise with 1.

    Raises:
        ValueError: If labels are present and ``cfg.EVAL_METRIC`` is neither
            ``'mean_accu'`` nor ``'accuracy'``.
    """
    # prepare data
    dataloader = prepare_data()

    # initialize model: use the pretrained feature extractor unless a
    # checkpoint supplies the full set of weights.
    model_state_dict = None
    fx_pretrained = True
    bn_domain_map = {}
    if cfg.WEIGHTS != '':
        # NOTE(review): torch.load unpickles the file; only point cfg.WEIGHTS
        # at trusted checkpoints.
        weights_dict = torch.load(cfg.WEIGHTS)
        model_state_dict = weights_dict['weights']
        bn_domain_map = weights_dict['bn_domain_map']
        fx_pretrained = False

    num_domains_bn = 2 if args.adapted_model else 1

    net = model.danet(num_classes=cfg.DATASET.NUM_CLASSES,
                      state_dict=model_state_dict,
                      feature_extractor=cfg.MODEL.FEATURE_EXTRACTOR,
                      fx_pretrained=fx_pretrained,
                      dropout_ratio=cfg.TRAIN.DROPOUT_RATIO,
                      fc_hidden_dims=cfg.MODEL.FC_HIDDEN_DIMS,
                      num_domains_bn=num_domains_bn)

    net = torch.nn.DataParallel(net)

    if torch.cuda.is_available():
        net.cuda()

    # test
    res = {'path': [], 'preds': [], 'gt': [], 'probs': []}
    net.eval()

    # Fall back to BN domain 0 when the checkpoint has no entry for the
    # test domain (or no checkpoint was loaded at all).
    domain_id = bn_domain_map.get(cfg.TEST.DOMAIN, 0)

    with torch.no_grad():
        net.module.set_bn_domain(domain_id)
        for sample in dataloader:
            res['path'] += sample['Path']

            if cfg.DATA_TRANSFORM.WITH_FIVE_CROP:
                # Fold the crop axis into the batch axis, run the network,
                # then average the per-crop probabilities back into one row
                # per image.
                n, ncrop, c, h, w = sample['Img'].size()
                sample['Img'] = sample['Img'].view(-1, c, h, w)
                img = to_cuda(sample['Img'])
                probs = net(img)['probs']
                probs = probs.view(n, ncrop, -1).mean(dim=1)
            else:
                img = to_cuda(sample['Img'])
                probs = net(img)['probs']

            preds = torch.max(probs, dim=1)[1]
            res['preds'].append(preds)
            res['probs'].append(probs)

            if 'Label' in sample:
                res['gt'].append(to_cuda(sample['Label']))
            print('Processed %d samples.' % len(res['path']))

        if not res['preds']:
            # torch.cat rejects an empty list, so an empty dataloader would
            # otherwise crash here; report it instead.
            print('No samples were processed; nothing to save or evaluate.')
        else:
            preds = torch.cat(res['preds'], dim=0)
            save_preds(res['path'], preds, cfg.SAVE_DIR)

            if res['gt']:
                gts = torch.cat(res['gt'], dim=0)
                probs = torch.cat(res['probs'], dim=0)

                # Explicit check instead of `assert`, which `python -O` strips.
                if cfg.EVAL_METRIC == 'mean_accu':
                    eval_res = mean_accuracy(probs, gts)
                elif cfg.EVAL_METRIC == 'accuracy':
                    eval_res = accuracy(probs, gts)
                else:
                    raise ValueError(
                        'Unsupported EVAL_METRIC %r; expected "mean_accu" or '
                        '"accuracy".' % (cfg.EVAL_METRIC,))
                print('Test %s: %.4f' % (cfg.EVAL_METRIC, eval_res))

    print('Finished!')