Esempio n. 1
0
def test_gather_access_footprint():
    """Print the size of every access footprint of a matmul-style kernel.

    The kernel evaluates ``c[i, j] = sum(k, a[i, k]*b[k, j]) + a[i,j]`` over
    ``0 <= i, j, k < n`` with ``a`` and ``b`` declared as float32.  Nothing is
    asserted here: the test only exercises footprint gathering and counting
    and prints each result.
    """
    from loopy.statistics import gather_access_footprints, count

    domain = "{[i,k,j]: 0<=i,j,k<n}"
    instructions = ["c[i, j] = sum(k, a[i, k]*b[k, j]) + a[i,j]"]
    matmul = lp.make_kernel(
        domain, instructions, name="matmul", assumptions="n >= 1")

    # Only the input dtypes are given; add_and_infer_dtypes is left to work
    # out the rest.
    input_dtypes = {"a": np.float32, "b": np.float32}
    matmul = lp.add_and_infer_dtypes(matmul, input_dtypes)

    footprints = gather_access_footprints(matmul)
    for access, accessed_set in six.iteritems(footprints):
        footprint_size = count(matmul, accessed_set)
        print(access, footprint_size)
Esempio n. 2
0
def test_gather_access_footprint_2():
    """Check footprint sizes for a strided write ``c[2*i] = a[i]``.

    With ``n`` set to 200, every gathered footprint is expected to count
    exactly 200 elements.  Each footprint's count is also printed.
    """
    from loopy.statistics import gather_access_footprints, count

    strided_copy = lp.make_kernel(
        "{[i]: 0<=i<n}",
        "c[2*i] = a[i]",
        name="matmul",
        assumptions="n >= 1",
    )
    strided_copy = lp.add_and_infer_dtypes(strided_copy, {"a": np.float32})

    footprints = gather_access_footprints(strided_copy)

    # Concrete value substituted for the kernel parameter when evaluating.
    params = dict(n=200)
    for access, accessed_set in six.iteritems(footprints):
        evaluated = count(strided_copy, accessed_set).eval_with_dict(params)
        assert evaluated == 200
        print(access, count(strided_copy, accessed_set))
Esempio n. 3
0
def test_gather_access_footprint_2():
    """Verify that each access footprint of ``c[2*i] = a[i]`` has 200 entries.

    The kernel iterates ``0 <= i < n``; footprint counts are evaluated at
    ``n = 200``, compared against 200, and then printed.
    """
    domain = "{[i]: 0<=i<n}"
    instruction = "c[2*i] = a[i]"
    kernel = lp.make_kernel(
        domain, instruction, name="matmul", assumptions="n >= 1")
    kernel = lp.add_and_infer_dtypes(kernel, {"a": np.float32})

    from loopy.statistics import gather_access_footprints, count
    footprint_by_access = gather_access_footprints(kernel)

    n_value = {"n": 200}
    for access_key, region in six.iteritems(footprint_by_access):
        # Evaluate the count at the concrete parameter value before comparing.
        region_size = count(kernel, region).eval_with_dict(n_value)
        assert region_size == 200
        print(access_key, count(kernel, region))
Esempio n. 4
0
def test_gather_access_footprint():
    """Gather and print access footprint counts for a matmul-like kernel.

    The single instruction is ``c[i, j] = sum(k, a[i, k]*b[k, j]) + a[i,j]``
    on the domain ``0 <= i, j, k < n``.  This test makes no assertions; it
    prints the count obtained for every gathered footprint.
    """
    instructions = [
        "c[i, j] = sum(k, a[i, k]*b[k, j]) + a[i,j]",
    ]
    kernel = lp.make_kernel(
        "{[i,k,j]: 0<=i,j,k<n}",
        instructions,
        name="matmul",
        assumptions="n >= 1",
    )

    # Both operands are declared single precision.
    operand_dtypes = {"a": np.float32, "b": np.float32}
    kernel = lp.add_and_infer_dtypes(kernel, operand_dtypes)

    from loopy.statistics import gather_access_footprints, count
    footprint_by_access = gather_access_footprints(kernel)

    for access_key, region in six.iteritems(footprint_by_access):
        region_size = count(kernel, region)
        print(access_key, region_size)