Exemplo n.º 1
0
def _setup_call_and_repeat(
    pb_ir: _ir.Ir, pb_top_graph: _ir.Graph, pb_bottom_graph: _ir.Graph
) -> Tuple[_ir.Graph, _ir.op.CallOp, _ir.op.LoopOp]:
    """Build the loop wrapper ("middle") graph with its call op and loop op.

    A new middle graph is created in ``pb_ir``. A CallOp that calls
    ``pb_bottom_graph`` is created inside the middle graph, and a LoopOp that
    loops the middle graph is created inside ``pb_top_graph``. The middle
    graph also receives the two mandatory loop inputs: the iterator tensor and
    the condition tensor, the latter additionally being marked as an output.

    Args:
        pb_ir (_ir.Ir): The _ir level Ir in which the middle graph is created.
        pb_top_graph (_ir.Graph): The _ir top level graph that receives the
            loop op.
        pb_bottom_graph (_ir.Graph): The _ir user defined subgraph that the
            call op calls.

    Returns:
        Tuple[_ir.Graph, _ir.op.CallOp, _ir.op.LoopOp]: The middle graph, the
            call op created in it, and the loop op created in the top graph.
    """
    bottom_name = pb_bottom_graph.id.str()

    # The middle graph is the one that gets repeated by the loop op.
    wrapper_id = _ir.GraphId(
        pb_ir.createUniqueSubgraphId(f"{bottom_name}__loop_wrapper"))
    pb_middle_graph = pb_ir.createGraph(wrapper_id)
    middle_name = pb_middle_graph.id.str()

    # Inside the middle graph: a call op targeting the bottom graph.
    call_opid = _ir.OperatorIdentifier("ai.graphcore", "Call", 1,
                                       _ir.NumInputs(), 0)
    ctx = get_current_context()
    call_settings = ctx._get_op_settings(
        f"{middle_name}__call__{bottom_name}")
    pb_callop = pb_middle_graph.createOp_CallOp(call_opid, pb_bottom_graph,
                                                call_settings)

    # Inside the top graph: a loop op whose body is the middle graph.
    loop_opid = _ir.OperatorIdentifier("ai.onnx", "Loop", 11, _ir.NumInputs(),
                                       0)
    loop_settings = ctx._get_op_settings(
        f"{pb_top_graph.id.str()}__loop__{middle_name}")
    pb_loop_op = pb_top_graph.createOp_LoopOp(loop_opid, loop_settings,
                                              pb_middle_graph)

    # Mandatory scalar loop inputs of the middle graph, added in this order:
    # the iterator (input only) and the condition (input and output).
    mandatory_inputs = (
        ("Iterator___", _ir.DataType.INT32, False),
        ("LoopCond___", _ir.DataType.BOOL, True),
    )
    for tensor_name, dtype, is_output in mandatory_inputs:
        tensor_id = _ir.addScope(pb_middle_graph, tensor_name)
        pb_middle_graph.addInput(tensor_id, _ir.TensorInfo(dtype, ()))
        if is_output:
            pb_middle_graph.markAsOutput(tensor_id)

    return pb_middle_graph, pb_callop, pb_loop_op
Exemplo n.º 2
0
def make_sub_graph(ir: _ir.Ir, ins: Dict[int, _ir.TensorInfo]) -> _ir.Graph:
    """Create the "fwd" subgraph: a chain of adds followed by a softmax.

    The graph gets one input per entry of ``ins``. The inputs are summed
    left to right by a chain of add ops, the result goes through a softmax
    with ``axis_=0``, and the softmax output is marked as the graph output.

    input0  input1  input2  ...  input n
    │       │       │            │
    │       │       │            │
    │       │       │            │
    └─►add ◄┘       │            │
        │           │            │
        └──────►add◄┘            │
                │                │
                │                │
                │                │
                └────►add ...    ▼


                               add
                                │
                                ▼
                             softmax
                                │
                                ▼
                               out

    Args:
        ir (_ir.Ir): The ir in which the subgraph is created.
        ins (Dict[int, _ir.TensorInfo]): Maps each input index to the tensor
            info of that input; the index is used in the input name
            (``in<index>``).

    Returns:
        _ir.Graph: The newly created subgraph.
    """
    graph = ir.createGraph(_ir.GraphId("fwd"))

    # One graph input per entry, named after its key.
    for index, tensor_info in ins.items():
        input_id = _ir.addScope(graph, f"in{index}")
        graph.addInput(input_id, tensor_info)

    input_ids = graph.getInputIds()

    # Running result of the add chain; starts out as the first input.
    running = graph.getTensor(input_ids[0])
    for position in range(1, len(ins)):
        add_settings = _ir.Settings(graph, f"add{position}")
        add_opid = _ir.OperatorIdentifier("ai.onnx", f"Add{position}", 1,
                                          _ir.NumInputs(2, 2), 1)
        add_inputs = {0: running.id, 1: input_ids[position]}
        add_outputs = {0: _ir.addScope(graph, f"add{position}")}
        add_op = graph.createConnectedOp_AddOp(add_inputs, add_outputs,
                                               add_opid, add_settings)
        running = add_op.outTensor(0)

    # Softmax over the accumulated sum.
    softmax_settings = _ir.Settings(graph, "softmax0")
    softmax_opid = _ir.OperatorIdentifier("ai.onnx", "SoftMax", 1,
                                          _ir.NumInputs(1, 1), 1)
    softmax_inputs = {0: running.id}
    softmax_outputs = {0: _ir.addScope(graph, "sm0")}
    softmax_op = graph.createConnectedOp_SoftmaxOp(softmax_inputs,
                                                   softmax_outputs,
                                                   opid=softmax_opid,
                                                   axis_=0,
                                                   settings=softmax_settings)

    result_id = softmax_op.outTensor(0).id
    graph.markAsOutput(result_id)

    return graph