Esempio n. 1
0
def test_unsupported_operations():
    """Check that popart's supported and unsupported op lists are disjoint.

    Op identifiers from both lists are reduced to ``"<type>_<version>"``
    strings before comparison, so the op domain is not part of the key.
    """
    opset_version = 10

    def opid_to_str(opid):
        # Key an op identifier by its type and version only.
        return f'{opid.type}_{opid.version}'

    unsupported_ops = {
        opid_to_str(i)
        for i in popart.getUnsupportedOperations(opset_version)
    }
    # NOTE(review): the meaning of the False argument is not visible here;
    # confirm against popart's getSupportedOperations documentation.
    supported_ops = {
        opid_to_str(i) for i in popart.getSupportedOperations(False)
    }

    # Check each set separately so a failure says which one was empty.
    assert unsupported_ops, "getUnsupportedOperations returned no ops"
    assert supported_ops, "getSupportedOperations returned no ops"

    # There should be no ops that are in both lists; name any offenders.
    overlap = unsupported_ops & supported_ops
    assert not overlap, (
        f"ops reported as both supported and unsupported: {sorted(overlap)}")
Esempio n. 2
0
def test_constant_is_supported():
    """Check that popart lists at least one supported op of type 'Constant'."""
    # NOTE(review): the meaning of the False argument is not visible here;
    # confirm against popart's getSupportedOperations documentation.
    ops = popart.getSupportedOperations(False)
    # any() stops at the first match instead of building a throwaway list.
    assert any(op.type == 'Constant' for op in ops), (
        "'Constant' not found in popart.getSupportedOperations(False)")
Esempio n. 3
0
def test_shape_is_supported():
    """Check that popart lists at least one supported op of type 'Shape'."""
    # NOTE(review): the meaning of the False argument is not visible here;
    # confirm against popart's getSupportedOperations documentation.
    ops = popart.getSupportedOperations(False)
    # any() stops at the first match instead of building a throwaway list.
    assert any(op.type == 'Shape' for op in ops), (
        "'Shape' not found in popart.getSupportedOperations(False)")
Esempio n. 4
0
#!/usr/bin/python
# Copyright (c) 2018 Graphcore Ltd. All rights reserved.

import sys

# Usage: gen_supported_ops.py <popart python path(s)> <output file>
# Writes a plain-text list of popart's supported ops, grouped by domain.
if len(sys.argv) < 3:
    sys.exit(
        "gen_supported_ops.py <path of popart python install> <output file>")

# argv[1] may contain several ':'-separated directories; each one is added
# to the module search path so that `import popart` below can find it.
print("Looking for popart module in " + sys.argv[1])
for p in sys.argv[1].split(':'):
    sys.path.append(p)
import popart  # noqa: E402 -- must come after the sys.path setup above

# NOTE(review): the meaning of the False argument is not visible here;
# confirm against popart's getSupportedOperations documentation.
supported_ops = popart.getSupportedOperations(False)

# Group ops by domain, keeping the order popart returns them in.
ops = {}
for op in supported_ops:
    ops.setdefault(op.domain, []).append(
        {"type": op.type, "version": op.version})

print("Writing supported ops to " + sys.argv[2])
with open(sys.argv[2], "w", encoding="utf-8") as f:
    # Emit domains in sorted order. Building the dict from a set made the
    # domain order depend on string-hash randomisation, so the generated
    # file could differ from run to run.
    for domain in sorted(ops):
        heading = "Domain: " + domain
        print(heading, file=f)
        # Underline exactly as wide as the heading.
        print('-' * len(heading), file=f)
        print("", file=f)

        for op in ops[domain]:
            print(f"- {op['type']}-{op['version']}", file=f)

        print("", file=f)