コード例 #1
0
ファイル: eval.py プロジェクト: yuanwei0908/hfsoftmax
def main():
    """Evaluate verification accuracy of extracted features on a .bin benchmark.

    Parses the command line into the module-level ``args``. It stays a
    global because other functions in this module read ``args``.

    If ``args.output_path`` does not exist yet, ``extract_feat.py`` is run in
    a subprocess to produce it. The saved features are then loaded with
    ``np.load`` and reshaped to ``(-1, args.feature_dim)``. Their row count
    must equal ``2 * len(labels)``, where the labels are the second item
    returned by ``bin_loader(args.bin_file)``; presumably this means two
    images per labelled pair (TODO confirm). The features are passed through
    ``normalize`` and ``evaluate``, and the mean/std of the per-fold
    accuracy is printed.

    Raises:
        subprocess.CalledProcessError: if the feature-extraction subprocess
            exits with a non-zero status.
        ValueError: if the number of feature rows does not match
            ``2 * len(labels)``.
    """
    import shlex
    import subprocess
    import sys

    global args
    args = parser.parse_args()

    if not os.path.exists(args.output_path):
        # Pass an argument list (no shell) so paths with spaces or shell
        # metacharacters reach extract_feat.py verbatim. Use the interpreter
        # running this script so the child sees the same environment.
        cmd = [
            sys.executable, 'extract_feat.py',
            '--arch', str(args.arch),
            '--batch-size', str(args.batch_size),
            '--input-size', str(args.input_size),
            '--feature-dim', str(args.feature_dim),
            '--load-path', str(args.load_path),
            '--bin-file', str(args.bin_file),
            '--output-path', str(args.output_path),
        ]
        print(' '.join(shlex.quote(part) for part in cmd))
        # check=True: fail here with the child's exit status instead of
        # failing later at np.load when no output file was written.
        subprocess.run(cmd, check=True)

    features = np.load(args.output_path).reshape(-1, args.feature_dim)
    _, lbs = bin_loader(args.bin_file)
    print('feature shape: {}'.format(features.shape))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if features.shape[0] != 2 * len(lbs):
        raise ValueError("{} vs {}".format(features.shape[0], 2 * len(lbs)))

    features = normalize(features)
    # evaluate() returns a 6-tuple; only the per-fold accuracies are used here.
    _, _, acc, _, _, _ = evaluate(features,
                                  lbs,
                                  nrof_folds=args.nfolds,
                                  distance_metric=0)
    print("accuracy: {:.4f}({:.4f})".format(acc.mean(), acc.std()))
コード例 #2
0
def evaluation(test_loader, model, num, outfeat_fn, benchmark, *, load_feat=True):
    """Score features on one benchmark and return its headline metric.

    Features come from one of two sources:

    - If ``outfeat_fn`` exists and ``load_feat`` is true, they are read from
      ``outfeat_fn`` as raw float32 and reshaped to
      ``(-1, args.model.feature_dim)``.
    - Otherwise they are produced by
      ``extract(test_loader, model, num, outfeat_fn, silent=True)``.

    Args:
        test_loader: forwarded unchanged to ``extract``.
        model: forwarded unchanged to ``extract``.
        num: forwarded unchanged to ``extract``.
        outfeat_fn: path of the cached feature file. It is also forwarded to
            ``extract``, presumably so that ``extract`` writes the cache
            (TODO confirm).
        benchmark: ``"megaface"`` or the stem of a ``.bin`` file located
            under ``args.test.test_root``.
        load_feat: when false, always re-extract even if ``outfeat_fn``
            exists. Defaults to True, matching the previous hard-coded
            behaviour.

    Returns:
        For ``"megaface"``: ``r[-1]`` from ``test_megaface`` (the value
        logged as the 1e-6 entry). Otherwise: the mean fold accuracy
        reported by ``evaluate``.
    """
    if not os.path.isfile(outfeat_fn) or not load_feat:
        features = extract(test_loader, model, num, outfeat_fn, silent=True)
    else:
        print("loading from: {}".format(outfeat_fn))
        features = np.fromfile(outfeat_fn, dtype=np.float32).reshape(
            -1, args.model.feature_dim)

    if benchmark == "megaface":
        # Features are passed to test_megaface as-is (no normalize() here).
        r = test_megaface(features)
        log(' * Megaface: 1e-6 [{}], 1e-5 [{}], 1e-4 [{}]'.format(
            r[-1], r[-2], r[-3]))
        return r[-1]

    features = normalize(features)
    _, lbs = bin_loader("{}/{}.bin".format(args.test.test_root, benchmark))
    # evaluate() returns a 6-tuple; only the per-fold accuracies are used here.
    _, _, acc, _, _, _ = evaluate(features,
                                  lbs,
                                  nrof_folds=args.test.nfolds,
                                  distance_metric=0)

    log(" * {}: accuracy: {:.4f}({:.4f})".format(benchmark, acc.mean(),
                                                 acc.std()))
    return acc.mean()
コード例 #3
0
 def __init__(self, bin_file, transform=None):
     """Index the images contained in a benchmark ``.bin`` file.

     Args:
         bin_file: path handed to ``bin_loader``. Of its two-item result,
             only the first element (the image list) is kept; the second
             element is discarded.
         transform: stored unchanged on ``self.transform``. Presumably it is
             applied per sample elsewhere in the class (not visible here).
     """
     images, _discarded = bin_loader(bin_file)
     self.img_lst = images
     self.num = len(self.img_lst)
     self.transform = transform