def test_op_attributes(attribute: str, shorthand: str, input_id: int):
    """Exercise the get/set/has accessors of an optional op attribute.

    The accessor names are built from ``attribute`` (``get<attribute>``,
    ``set<attribute>``, ``has<attribute>``) and ``shorthand``
    (``getOptional<shorthand>`` on the op, ``Optional<shorthand>`` on
    ``_ir``).

    Args:
        attribute (str): Name of the attribute, e.g. VirtualGraphId.
        shorthand (str): Shorthand of the attribute, e.g. VGraphId.
        input_id (int): Value to store in the attribute.
    """
    _, graphs = create_ir(["A"])
    graph = graphs[0]
    op_settings = _ir.Settings(graph, "new_settings")
    op_identifier = _ir.OperatorIdentifier("ai.onnx", "Identity", 1,
                                           _ir.NumInputs(1, 1), 1)
    op = _ir.Op(op_identifier, op_settings)

    # Look up the bound accessors by name for this attribute.
    get_attr, set_attr, has_attr = (getattr(op, prefix + attribute)
                                    for prefix in ("get", "set", "has"))
    get_optional = getattr(op, "getOptional" + shorthand)

    # A freshly built op has no value for the attribute.
    assert not has_attr()

    optional_type = getattr(_ir, "Optional" + shorthand)

    # Setting an empty optional leaves the attribute unset, so the plain
    # getter must raise.
    set_attr(optional_type())
    assert get_optional() == optional_type()
    with pytest.raises(popart.popart_exception) as exc_info:
        get_attr()
        # NOTE(review): get_attr() is expected to raise here, so this message
        # check does not run. Confirm the exact error text before moving it
        # outside the `with` block.
        assert (exc_info.value.args[0] == f"Cannot return {attribute} for Op")
    assert not has_attr()

    # Setting a populated optional stores input_id.
    set_attr(optional_type(input_id))
    assert get_attr() == input_id
    assert has_attr()
    assert get_optional() == optional_type(input_id)
def test_shapes(shape1: List[int], shape2: List[int], expected: List[int],
                dtype: str):
    """Check the numpy-style broadcast shape computed by ``Op.prettyNpOut``.

    Broadcasting itself is covered by the C++ tests. A few cases here confirm
    that both Python overloads are bound correctly: the one taking two shape
    lists and the one taking two ``TensorInfo`` objects.

    Args:
        shape1 (List[int]): First tensor shape.
        shape2 (List[int]): Second tensor shape.
        expected (List[int]): Expected broadcast shape.
        dtype (str): PopART data type name passed to ``_ir.TensorInfo``.
    """
    _, graphs = create_ir(["A"])
    graph = graphs[0]
    op_settings = _ir.Settings(graph, "new_settings")
    op_identifier = _ir.OperatorIdentifier("ai.onnx", "Identity", 1,
                                           _ir.NumInputs(1, 1), 1)
    op = _ir.Op(op_identifier, op_settings)

    # Overload 1: plain shape lists in, shape list out.
    assert op.prettyNpOut(shape1, shape2) == list(expected)

    # Overload 2: TensorInfo in, TensorInfo out.
    lhs_info = _ir.TensorInfo(dtype, shape1)
    rhs_info = _ir.TensorInfo(dtype, shape2)
    broadcast_info = op.prettyNpOut(lhs_info, rhs_info)
    assert broadcast_info == _ir.TensorInfo(dtype, expected)
def test_multi_graph():
    """Create one op in each of two graphs and check ids and graph ownership.
    """
    _, graphs = create_ir(["A", "B"])
    graph_a, graph_b = graphs[0], graphs[1]
    settings_a = _ir.Settings(graph_a, "settings_g")
    settings_b = _ir.Settings(graph_b, "settings_h")
    op_identifier = _ir.OperatorIdentifier("ai.onnx", "Identity", 1,
                                           _ir.NumInputs(1, 1), 1)

    first_op = _ir.Op(op_identifier, settings_a)
    second_op = _ir.Op(op_identifier, settings_b)

    # Op ids are allocated sequentially, starting from 100, across graphs.
    assert first_op.id == 100
    assert second_op.id == 101

    # Both ops share the same operator identifier.
    assert first_op.opid == op_identifier
    assert op_identifier == second_op.opid

    # Each op belongs to the graph its settings were built from.
    assert first_op.getGraph() == graph_a
    assert second_op.getGraph() == graph_b
def test_graph_in_outs():
    """Test default behaviour for an op with no inputs or outputs.

    A freshly created op has no tensor at input/output index 0, reports no
    optional inputs, and uses batch axis 0 for both input and output index 0.
    """
    _, graphs = create_ir(["A"])
    g = graphs[0]
    settings = _ir.Settings(g, "new_settings")
    num_inputs = _ir.NumInputs(1, 1)
    opid = _ir.OperatorIdentifier("ai.onnx", "Identity", 1, num_inputs, 1)
    op = _ir.Op(opid, settings)
    # Truthiness checks instead of `== False` (PEP 8 / flake8 E712).
    assert not op.hasInput(0)
    assert not op.hasOutput(0)
    assert op.optionalInputs() == set()
    assert op.getInBatchAxis(0) == 0
    assert op.getOutBatchAxis(0) == 0
def test_op_creation():
    """Build a single Identity op and check its id and operator identifier.
    """
    _, graphs = create_ir(["A"])
    graph = graphs[0]
    op_settings = _ir.Settings(graph, "new_settings")
    op_identifier = _ir.OperatorIdentifier("ai.onnx", "Identity", 1,
                                           _ir.NumInputs(1, 1), 1)
    op = _ir.Op(op_identifier, op_settings)

    # The first op created in an IR receives the default id 100.
    assert op.id == 100
    assert op.opid == op_identifier

    # The stored identifier keeps the values it was constructed with.
    stored = op.opid
    assert stored.domain == "ai.onnx"
    assert stored.type == "Identity"
    assert stored.version == 1
    assert stored.numOutputs == 1
def test_op_clone():
    """Calling the pure-virtual ``Op::clone`` on a bare ``_ir.Op`` must raise.

    Derived op classes implement ``clone`` and can call it without issue;
    the base ``Op`` cannot.
    """
    _, graphs = create_ir(["A"])
    g = graphs[0]
    settings = _ir.Settings(g, "new_settings")
    num_inputs = _ir.NumInputs(1, 1)
    opid = _ir.OperatorIdentifier("ai.onnx", "Identity", 1, num_inputs, 1)
    op = _ir.Op(opid, settings)
    # The message is checked through `match`, which pytest applies (via
    # re.search on str(exception)) after the exception is caught. An assert
    # placed after op.clone() inside the `with` body would never execute.
    # The pattern contains no regex metacharacters. Because it is a search,
    # it also passes whether or not the message carries a "RuntimeError: "
    # prefix.
    with pytest.raises(
            RuntimeError,
            match='Tried to call pure virtual function "Op::clone"'):
        op.clone()
def create_dummy_op(op_domain: str, op_type: str, op_version: int,
                    num_inputs: int, num_outputs: int) -> tuple:
    """Create an op with the provided properties in a fresh single-graph IR.

    Args:
        op_domain (str): Op domain.
        op_type (str): Op type name.
        op_version (int): Op version.
        num_inputs (int): Number of inputs; used as both the min and max of
            the op's ``_ir.NumInputs``.
        num_outputs (int): Number of outputs.

    Returns:
        tuple: ``(op, ir, graph)``. ``op`` is the created ``_ir.Op``, ``ir``
        is the IR returned by ``create_ir``, and ``graph`` is the graph named
        "graph_123" that owns the op. The previous annotation (``_ir.Op``)
        did not match this 3-tuple.
    """
    ir, graphs = create_ir(["graph_123"])
    graph = graphs[0]
    settings = _ir.Settings(graph, "new_settings")
    num_inputs_obj = _ir.NumInputs(num_inputs, num_inputs)
    opid = _ir.OperatorIdentifier(op_domain, op_type, op_version,
                                  num_inputs_obj, num_outputs)
    return _ir.Op(opid, settings), ir, graph
def test_bools():
    """Test the default values of an op's boolean predicates.

    For a bare Identity ``_ir.Op``, ``isOutlineable`` is expected to be true
    and every other predicate checked here is expected to be false.
    """
    _, graphs = create_ir(["A"])
    g = graphs[0]
    settings = _ir.Settings(g, "new_settings")
    num_inputs = _ir.NumInputs(1, 1)
    opid = _ir.OperatorIdentifier("ai.onnx", "Identity", 1, num_inputs, 1)
    op = _ir.Op(opid, settings)
    # Truthiness checks instead of `== False` (PEP 8 / flake8 E712).
    assert not op.isInplaceViewChange()
    assert not op.isOutplaceViewChange()
    assert not op.isLossOp()
    assert not op.isIpuCopyOp()
    assert not op.isOptimizerOp()
    assert not op.requiresRandomSeed()
    assert op.isOutlineable()
    assert not op.hasSideEffect()
    assert not op.isNorm()
    assert not op.canBeReplacedByIdentity()
    assert not op.copiesOptimizerTensors()
    assert not op.inputsUnmodifiable()
    assert not op.isElementWiseUnary()
def test_string_methods(op_name: str, domain: str, op_type: str, op_num: int,
                        op_version: int):
    """Check the op's name and string-formatting helpers.

    Args:
        op_name (str): Name assigned to the op via ``setName``.
        domain (str): Operator domain, e.g. ai.onnx.
        op_type (str): Operator type.
        op_num (int): Expected op id (the first op gets 100).
        op_version (int): Operator version.
    """
    _, graphs = create_ir(["A"])
    graph = graphs[0]
    op_settings = _ir.Settings(graph, "new_settings")
    op_identifier = _ir.OperatorIdentifier(domain, op_type, op_version,
                                           _ir.NumInputs(1, 1), 1)
    op = _ir.Op(op_identifier, op_settings)

    op.setName(op_name)
    assert op.getName() == op_name
    assert op.name() == op_name

    # Both str() and debugName() embed "<domain>.<type>:<version>".
    qualified_type = f"{domain}.{op_type}:{op_version}"
    assert op.str() == f"{op_num} ({qualified_type})"
    expected_debug = f"Op({op_name} ({qualified_type}), inputs=[], outputs=[])"
    assert op.debugName() == expected_debug