コード例 #1
0
def main():
    """Report the mean patch distance of a nearest-neighbor field (NNF).

    Usage: check_dist.py a.png b.png out.pfm

    The third argument is an NNF file: ``.pfm`` files are read with
    ``pfm.readpfm``, anything else with ``read_mat``.  For every patch
    position (ay, ax) of image ``a``, channel 0 of the NNF holds the matched
    x coordinate in image ``b`` and channel 1 the matched y coordinate.

    When the module-level flag ``use_c`` is set, the work is delegated to the
    ``./patchtable check_dist`` executable and its output is echoed verbatim
    (a ``.mat`` NNF is first converted to a temporary ``.pfm`` file, which is
    removed afterwards).  Otherwise the Euclidean (L2) distance between each
    ``patch_w`` x ``patch_w`` patch of ``a`` and its matched patch of ``b``
    is recomputed here and the mean over all patches is printed.

    Exits with status 1 when fewer than three arguments are given.
    Raises AssertionError (not checked under ``python -O``) when a match
    lies outside the valid patch range of ``b``, and
    subprocess.CalledProcessError when ``./patchtable`` exits nonzero.
    """
    args = sys.argv[1:]
    if len(args) < 3:
        sys.stderr.write('check_dist.py a.png b.png out.pfm\n')
        sys.exit(1)

    if use_c:
        import tempfile

        pfm_arg = args[2]
        temp_path = None
        try:
            if os.path.splitext(pfm_arg)[1] == '.mat':
                nnf = read_mat(pfm_arg)
                if nnf.shape[2] != 3:
                    # Presumably a 2-channel (x, y) field: append a zero
                    # third channel so the written .pfm has three channels.
                    nnf = numpy.dstack(
                        (nnf, numpy.zeros((nnf.shape[0], nnf.shape[1]))))
                # Unique temp file instead of a fixed name in the cwd, so
                # concurrent runs cannot clobber each other's input.
                fd, temp_path = tempfile.mkstemp(suffix='.pfm')
                os.close(fd)
                pfm.writepfm(nnf, temp_path)
                pfm_arg = temp_path
            # Argument list (no shell) so paths with spaces or shell
            # metacharacters are passed through untouched.
            ans = subprocess.check_output(
                ['./patchtable', 'check_dist', args[0], args[1], pfm_arg],
                universal_newlines=True)
        finally:
            if temp_path is not None:
                os.remove(temp_path)
        sys.stdout.write(ans)
        return

    # atleast_3d gives grayscale (2-D) images a single channel axis so the
    # [..., :] patch slicing below works for them too.
    a = numpy.atleast_3d(skimage.img_as_float(skimage.io.imread(args[0])))
    b = numpy.atleast_3d(skimage.img_as_float(skimage.io.imread(args[1])))
    if os.path.splitext(args[2])[1] == '.pfm':
        nnf = pfm.readpfm(args[2])
    else:
        nnf = read_mat(args[2])

    # Number of valid patch origins along each axis of b.
    beh = b.shape[0] - patch_w + 1
    bew = b.shape[1] - patch_w + 1

    # The NNF has one entry per valid patch origin of a.
    nnf_h = a.shape[0] - patch_w + 1
    nnf_w = a.shape[1] - patch_w + 1
    dimg = numpy.zeros((nnf_h, nnf_w))
    for ay in range(nnf_h):
        for ax in range(nnf_w):
            bx = nnf[ay, ax, 0]
            by = nnf[ay, ax, 1]
            # Range-check the raw stored values before truncating them.
            assert 0 <= bx < bew, (bx, bew)
            assert 0 <= by < beh, (by, beh)
            # Stored coordinates may be floats; NumPy slices need ints.
            bx = int(bx)
            by = int(by)
            apatch = a[ay:ay + patch_w, ax:ax + patch_w, :].flatten()
            bpatch = b[by:by + patch_w, bx:bx + patch_w, :].flatten()
            d = apatch - bpatch
            dimg[ay, ax] = numpy.sqrt(numpy.sum(d * d))

    print(numpy.mean(dimg))
コード例 #2
0
ファイル: check_dist.py プロジェクト: caomw/patchtable
def main():
    """Report the mean patch distance of a nearest-neighbor field (NNF).

    Usage: check_dist.py a.png b.png out.pfm

    The third argument is an NNF file: ``.pfm`` files are read with
    ``pfm.readpfm``, anything else with ``read_mat``.  For every patch
    position (ay, ax) of image ``a``, channel 0 of the NNF holds the matched
    x coordinate in image ``b`` and channel 1 the matched y coordinate.

    When the module-level flag ``use_c`` is set, the work is delegated to the
    ``./patchtable check_dist`` executable and its output is echoed verbatim
    (a ``.mat`` NNF is first converted to a temporary ``.pfm`` file, which is
    removed afterwards).  Otherwise the Euclidean (L2) distance between each
    ``patch_w`` x ``patch_w`` patch of ``a`` and its matched patch of ``b``
    is recomputed here and the mean over all patches is printed.

    Exits with status 1 when fewer than three arguments are given.
    Raises AssertionError (not checked under ``python -O``) when a match
    lies outside the valid patch range of ``b``, and
    subprocess.CalledProcessError when ``./patchtable`` exits nonzero.
    """
    args = sys.argv[1:]
    if len(args) < 3:
        sys.stderr.write('check_dist.py a.png b.png out.pfm\n')
        sys.exit(1)

    if use_c:
        import tempfile

        pfm_arg = args[2]
        temp_path = None
        try:
            if os.path.splitext(pfm_arg)[1] == '.mat':
                nnf = read_mat(pfm_arg)
                if nnf.shape[2] != 3:
                    # Presumably a 2-channel (x, y) field: append a zero
                    # third channel so the written .pfm has three channels.
                    nnf = numpy.dstack(
                        (nnf, numpy.zeros((nnf.shape[0], nnf.shape[1]))))
                # Unique temp file instead of a fixed name in the cwd, so
                # concurrent runs cannot clobber each other's input.
                fd, temp_path = tempfile.mkstemp(suffix='.pfm')
                os.close(fd)
                pfm.writepfm(nnf, temp_path)
                pfm_arg = temp_path
            # Argument list (no shell) so paths with spaces or shell
            # metacharacters are passed through untouched.
            ans = subprocess.check_output(
                ['./patchtable', 'check_dist', args[0], args[1], pfm_arg],
                universal_newlines=True)
        finally:
            if temp_path is not None:
                os.remove(temp_path)
        sys.stdout.write(ans)
        return

    # atleast_3d gives grayscale (2-D) images a single channel axis so the
    # [..., :] patch slicing below works for them too.
    a = numpy.atleast_3d(skimage.img_as_float(skimage.io.imread(args[0])))
    b = numpy.atleast_3d(skimage.img_as_float(skimage.io.imread(args[1])))
    if os.path.splitext(args[2])[1] == '.pfm':
        nnf = pfm.readpfm(args[2])
    else:
        nnf = read_mat(args[2])

    # Number of valid patch origins along each axis of b.
    beh = b.shape[0] - patch_w + 1
    bew = b.shape[1] - patch_w + 1

    # The NNF has one entry per valid patch origin of a.
    nnf_h = a.shape[0] - patch_w + 1
    nnf_w = a.shape[1] - patch_w + 1
    dimg = numpy.zeros((nnf_h, nnf_w))
    for ay in range(nnf_h):
        for ax in range(nnf_w):
            bx = nnf[ay, ax, 0]
            by = nnf[ay, ax, 1]
            # Range-check the raw stored values before truncating them.
            assert 0 <= bx < bew, (bx, bew)
            assert 0 <= by < beh, (by, beh)
            # Stored coordinates may be floats; NumPy slices need ints.
            bx = int(bx)
            by = int(by)
            apatch = a[ay:ay + patch_w, ax:ax + patch_w, :].flatten()
            bpatch = b[by:by + patch_w, bx:bx + patch_w, :].flatten()
            d = apatch - bpatch
            dimg[ay, ax] = numpy.sqrt(numpy.sum(d * d))

    print(numpy.mean(dimg))
コード例 #3
0
ファイル: stress.py プロジェクト: ymchen7/patchtable
def main():
    """Stress-test the ./patchtable executable with randomized inputs.

    Unless the module-level flag ``imgcollection_bug`` is set, first runs
    ``./patchtable test_descriptor`` and requires the last line of its
    output to be ``test_descriptor: OK``.

    Then, for each of ``ntest`` iterations (seeded with the iteration index
    for both ``random`` and ``numpy.random`` so failures are reproducible),
    it does the following:

    * generates random images ``a``/``b`` and an allowed-patch mask;
    * writes them as atest.png, btest.png and allowed.png in the cwd;
    * runs ``./patchtable match`` with a randomly chosen option set;
    * checks that the reported ``mean_dist`` and ``mean_dist_recomputed``
      agree to within 1e-4;
    * checks that every match in out.pfm is a valid patch origin of ``b``;
    * when ``use_allowed`` was chosen for the iteration, checks that no
      disallowed patch was matched.

    Raises ValueError on a failed check, AssertionError on an out-of-range
    match, and subprocess.CalledProcessError if patchtable exits nonzero.
    """
    ntest = 100
    patch_w = 8
    # NOTE(review): the checks below use ``i <= simple_test_count``, so
    # iterations 0..10 (eleven of them) use the fixed "simple" setup;
    # confirm that ``<=`` rather than ``<`` is intended.
    simple_test_count = 10

    def banner(title):
        """Print *title* between two 80-column rules."""
        print('=' * 80)
        print(title)
        print('=' * 80)

    def make_prev_nnf():
        """Write a random previous NNF to prev_nnf.pfm; return the filename.

        Reads aeh/aew/beh/bew, ``allowed`` and ``use_allowed`` from the
        current loop iteration of the enclosing function.  When
        ``use_allowed`` is set, entries landing on a disallowed patch are
        redrawn until they hit an allowed one.  The distance channel is 0.
        """
        # Integer arrays so the coordinates can index ``allowed`` directly;
        # NumPy rejects float indices.
        prev_nnf_x = numpy.random.randint(bew, size=(aeh, aew))
        prev_nnf_y = numpy.random.randint(beh, size=(aeh, aew))
        if use_allowed:
            for y in range(aeh):
                for x in range(aew):
                    while not allowed[prev_nnf_y[y, x], prev_nnf_x[y, x], 0]:
                        prev_nnf_y[y, x] = random.randrange(beh)
                        prev_nnf_x[y, x] = random.randrange(bew)
        prev_nnf_dist = numpy.zeros((aeh, aew))
        prev_nnf = numpy.asarray(
            numpy.dstack((prev_nnf_x, prev_nnf_y, prev_nnf_dist)), float)
        prev_nnf_filename = 'prev_nnf.pfm'
        pfm.writepfm(prev_nnf, prev_nnf_filename)
        return prev_nnf_filename

    if not imgcollection_bug:
        banner('Stress Test Descriptor')
        s = subprocess.check_output('./patchtable test_descriptor',
                                    shell=True, universal_newlines=True)
        last_line = s.strip().split('\n')[-1].strip()
        if last_line != 'test_descriptor: OK':
            raise ValueError('descriptor test failed')
        print('Passed')

    for i in range(ntest):
        random.seed(i)
        numpy.random.seed(i)
        banner('Stress Test %d/%d' % (i, ntest))
        print('')
        a = gen_image()
        b = gen_image()
        allowed = gen_allowed(b)

        # Number of valid patch origins along each axis of a and b.
        aeh = a.shape[0] - patch_w + 1
        aew = a.shape[1] - patch_w + 1
        beh = b.shape[0] - patch_w + 1
        bew = b.shape[1] - patch_w + 1

        imwrite(a, 'atest.png')
        imwrite(b, 'btest.png')
        imwrite(allowed, 'allowed.png')
        if os.path.exists('out.pfm'):
            os.remove('out.pfm')

        # The random draws below happen in a fixed order (even when their
        # result is overridden) so a given seed always reproduces the same
        # test.
        use_allowed = random.randrange(2) == 0
        if i <= simple_test_count:
            use_allowed = True

        r = random.randrange(5)
        if i <= simple_test_count:
            r = 3
        speed = random.randrange(min_speed, max_speed)

        suffix = ''
        if r == 0:
            suffix = ' -speed %d' % speed
        elif r == 1:
            suffix = ' -speed %d -coherence_spatial %f' % (speed, random.uniform(0, 100))
        elif r == 2:
            suffix = ' -speed %d -is_descriptor 1' % speed
        elif r == 3:
            prev_nnf_filename = make_prev_nnf()
            suffix = ' -speed %d -prev_nnf %s -recalc_dist_temporal 1 -coherence_spatial %f -coherence_temporal %f' % (speed, prev_nnf_filename, random.uniform(0, 100), random.uniform(0, 100))
        # r == 4: run with patchtable's default options.

        if use_allowed:
            suffix += ' -allowed_patches allowed.png'

        if imgcollection_bug:
            # NOTE(review): this always passes -allowed_patches, but the
            # disallowed-match check below only runs when use_allowed is
            # True; confirm that is intended for this reproduction mode.
            prev_nnf_filename = make_prev_nnf()
            coherence_spatial = coherence_temporal = 10.0
            suffix = ' -prev_nnf %s -coherence_spatial %f -coherence_temporal %f -limit 1000000 -allowed_patches allowed.png' % (prev_nnf_filename, coherence_spatial, coherence_temporal)

        # Kept as a shell string so the printed command can be pasted
        # straight into a terminal to reproduce a failure.
        cmd = './patchtable match atest.png btest.png out.pfm%s' % suffix
        print(cmd)
        s = subprocess.check_output(cmd, shell=True, universal_newlines=True)
        print(s)
        mean_dist = float_last_line(s, 'mean_dist:')
        mean_dist_recomputed = float_last_line(s, 'mean_dist_recomputed:')
        if abs(mean_dist - mean_dist_recomputed) > 1e-4:
            raise ValueError('mean_dist (%f) and mean_dist_recomputed (%f) differ' % (mean_dist, mean_dist_recomputed))

        ann = pfm.readpfm('out.pfm')
        xmin = numpy.min(ann[:, :, 0])
        xmax = numpy.max(ann[:, :, 0])
        ymin = numpy.min(ann[:, :, 1])
        ymax = numpy.max(ann[:, :, 1])
        assert xmin >= 0 and ymin >= 0, (xmin, ymin)
        assert xmax < bew and ymax < beh, (xmax, ymax, bew, beh)

        if use_allowed:
            # NNF coordinates are stored as floats; NumPy indexing needs
            # ints.  The range asserts above guarantee valid indices.
            bxs = ann[:, :, 0].astype(int)
            bys = ann[:, :, 1].astype(int)
            allowed_vals = allowed[bys, bxs, 0]
            # argwhere is row-major, so this reports the same first offender
            # as a y-then-x scan.
            bad = numpy.argwhere(numpy.logical_not(allowed_vals))
            if len(bad):
                y, x = bad[0]
                raise ValueError('accessed disallowed patch: %d, %d => %d, %d, %d' % (x, y, bxs[y, x], bys[y, x], allowed_vals[y, x]))
    print('')
    print('Unit tests passed.')