def _setup_call_and_repeat(
        pb_ir: _ir.Ir, pb_top_graph: _ir.Graph, pb_bottom_graph: _ir.Graph
) -> Tuple[_ir.Graph, _ir.op.CallOp, _ir.op.LoopOp]:
    """Set up the call and repeat (loop) ops, as well as the middle graph that
    the loop op will repeat.

    Args:
        pb_ir (_ir.Ir): The _ir level Ir.
        pb_top_graph (_ir.Graph): The _ir top level graph that will contain the loop op.
        pb_bottom_graph (_ir.Graph): The _ir user defined subgraph that will be called.

    Returns:
        Tuple[_ir.Graph, _ir.op.CallOp, _ir.op.LoopOp]: The created _ir-level
            middle graph, call op and loop op.
    """
    # This is the graph we will repeat.
    pb_middle_graph = pb_ir.createGraph(
        _ir.GraphId(
            pb_ir.createUniqueSubgraphId(
                f"{pb_bottom_graph.id.str()}__loop_wrapper")))

    opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1, _ir.NumInputs(),
                                  0)
    op_name = pb_middle_graph.id.str() + '__call__' + pb_bottom_graph.id.str()

    ctx = get_current_context()
    # Call the bottom_graph.
    pb_callop = pb_middle_graph.createOp_CallOp(opid, pb_bottom_graph,
                                                ctx._get_op_settings(op_name))

    opid = _ir.OperatorIdentifier("ai.onnx", "Loop", 11, _ir.NumInputs(), 0)
    op_name = pb_top_graph.id.str() + '__loop__' + pb_middle_graph.id.str()

    # Loop the middle_graph.
    pb_loop_op = pb_top_graph.createOp_LoopOp(opid,
                                              ctx._get_op_settings(op_name),
                                              pb_middle_graph)

    # Add the mandatory loop iterator tensor to the subgraph (it is not an output).
    repeatIterId = _ir.addScope(pb_middle_graph, "Iterator___")
    pb_middle_graph.addInput(repeatIterId,
                             _ir.TensorInfo(_ir.DataType.INT32, ()))

    # Add the mandatory loop condition tensor to the subgraph (it is also an output).
    repeatCondId = _ir.addScope(pb_middle_graph, "LoopCond___")
    pb_middle_graph.addInput(repeatCondId,
                             _ir.TensorInfo(_ir.DataType.BOOL, ()))
    pb_middle_graph.markAsOutput(repeatCondId)

    return pb_middle_graph, pb_callop, pb_loop_op
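
# Illustrative only (not part of the original file): a hedged sketch of what
# _setup_call_and_repeat leaves behind on the middle graph. It uses only APIs
# already exercised above (_ir.addScope and Graph.getInputIds); the helper
# name is hypothetical.
def _middle_graph_has_mandatory_loop_tensors(
        pb_middle_graph: _ir.Graph) -> bool:
    # The loop body graph must carry the INT32 iterator and BOOL condition
    # inputs added by _setup_call_and_repeat.
    expected = {
        _ir.addScope(pb_middle_graph, "Iterator___"),
        _ir.addScope(pb_middle_graph, "LoopCond___"),
    }
    return expected.issubset(set(pb_middle_graph.getInputIds()))
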
def make_sub_graph(ir: _ir.Ir, ins: Dict[int, _ir.TensorInfo]) -> _ir.Graph:
    """Make the following subgraph, with len(ins) inputs.

        input0  input1  input2  ...  input n
           │       │       │            │
           └──►add◄┘       │            │
                │          │            │
                └───►add◄──┘            │
                      │                 │
                      └───►add ...      │
                                 │      │
                                 ▼      │
                                add◄────┘
                                 │
                                 ▼
                              softmax
                                 │
                                 ▼
                                out

    Args:
        ir (_ir.Ir): The ir to add the subgraph to.
        ins (Dict[int, _ir.TensorInfo]): The map of input indices to tensor infos.

    Returns:
        _ir.Graph: The subgraph in question.
    """
    g = ir.createGraph(_ir.GraphId("fwd"))

    for i, tinfo in ins.items():
        g.addInput(_ir.addScope(g, f"in{i}"), tinfo)

    inputs = g.getInputIds()

    # Chain the inputs through a sequence of adds.
    t = g.getTensor(inputs[0])
    for i in range(1, len(ins)):
        settings = _ir.Settings(g, f"add{i}")
        opid = _ir.OperatorIdentifier("ai.onnx", f"Add{i}", 1,
                                      _ir.NumInputs(2, 2), 1)
        add = g.createConnectedOp_AddOp({
            0: t.id,
            1: inputs[i]
        }, {0: _ir.addScope(g, f"add{i}")}, opid, settings)
        t = add.outTensor(0)

    # Finish with a softmax over the last add's output.
    settings = _ir.Settings(g, "softmax0")
    opid = _ir.OperatorIdentifier("ai.onnx", "SoftMax", 1,
                                  _ir.NumInputs(1, 1), 1)
    sm = g.createConnectedOp_SoftmaxOp({0: t.id}, {0: _ir.addScope(g, "sm0")},
                                       opid=opid,
                                       axis_=0,
                                       settings=settings)

    g.markAsOutput(sm.outTensor(0).id)

    return g
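
# Illustrative only (not part of the original file): a minimal usage sketch for
# make_sub_graph. It assumes the popart._internal.ir bindings expose _ir.Ir()
# and _ir.DataType.FLOAT as used elsewhere in the popart tests; the shape and
# the three-input choice are arbitrary.
def _example_make_sub_graph() -> _ir.Graph:
    ir = _ir.Ir()
    tinfo = _ir.TensorInfo(_ir.DataType.FLOAT, [4])
    # Three inputs produce two chained adds followed by a softmax.
    g = make_sub_graph(ir, {0: tinfo, 1: tinfo, 2: tinfo})
    assert len(g.getInputIds()) == 3
    return g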