def test_unsupported_operations():
    """Check that popart's supported and unsupported op lists are sane.

    Both lists (for opset 10 on the unsupported side) must be non-empty,
    and no op may be reported as both supported and unsupported. Ops are
    compared by a "<type>_<version>" key.
    """
    opset_version = 10

    def op_key(opid):
        # Identify an op by its type and version, e.g. "Conv_1".
        return f'{opid.type}_{opid.version}'

    unsupported = {
        op_key(opid)
        for opid in popart.getUnsupportedOperations(opset_version)
    }
    supported = {
        op_key(opid) for opid in popart.getSupportedOperations(False)
    }

    # Make sure both sets contain elements.
    assert unsupported and supported
    # There should be no ops that are in both lists.
    assert not (unsupported & supported)
def test_constant_is_supported():
    """'Constant' must appear among popart's supported operations."""
    supported = popart.getSupportedOperations(False)
    matches = [opid.type for opid in supported if opid.type == 'Constant']
    assert matches
    print(matches)
def test_shape_is_supported():
    """'Shape' must appear among popart's supported operations."""
    supported = popart.getSupportedOperations(False)
    matches = [opid.type for opid in supported if opid.type == 'Shape']
    assert matches
    print(matches)
#!/usr/bin/python
# Copyright (c) 2018 Graphcore Ltd. All rights reserved.
"""Write a text listing of the operations popart reports as supported.

Usage: gen_supported_ops.py <path of popart python install> <output file>

The first argument is a ':'-separated list of directories that are appended
to ``sys.path`` before ``popart`` is imported. The output file gets one
section per op domain, each listing "- <type>-<version>" lines.
"""
import sys


def _group_ops_by_domain(supported_ops):
    """Group op identifiers by their ``domain`` attribute.

    Returns a dict mapping each domain to a list of
    ``{"type": ..., "version": ...}`` dicts, kept in the order the ops appear
    in *supported_ops*. Domains are inserted in sorted order so the generated
    file is reproducible: iterating a ``set`` of strings depends on the
    per-process string hash seed and would reorder sections between runs.
    """
    ops = {domain: [] for domain in sorted({op.domain for op in supported_ops})}
    for op in supported_ops:
        ops[op.domain].append({"type": op.type, "version": op.version})
    return ops


def _write_ops(ops, f):
    """Write the grouped *ops* to the text stream *f*, one section per domain."""
    for domain, domain_ops in ops.items():
        print("Domain: " + domain, file=f)
        # Underline the heading; the "Domain: " prefix is 8 characters long.
        print('-' * (8 + len(domain)), file=f)
        print("", file=f)
        for op in domain_ops:
            print(f"- {op['type']}-{op['version']}", file=f)
        print("", file=f)


def main(argv=None):
    """Script entry point; *argv* defaults to ``sys.argv``.

    Exits with the usage message if fewer than two arguments are given.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        sys.exit(
            "gen_supported_ops.py <path of popart python install> <output file>")

    print("Looking for popart module in " + argv[1])
    for p in argv[1].split(':'):
        sys.path.append(p)
    # popart is only importable once the search paths above have been added.
    import popart

    supported_ops = popart.getSupportedOperations(False)
    ops = _group_ops_by_domain(supported_ops)

    print("Writing supported ops to " + argv[2])
    with open(argv[2], "w", encoding="utf-8") as f:
        _write_ops(ops, f)


if __name__ == "__main__":
    main()