Beispiel #1
0
def _rewrite_while_loop_functions(tf_ssa, fn):
    r"""
    Rewrite tf.while_loop's sub-graphs with get_tuple, make_tuple,
    function_entry and return ops. This rewrite is required in order to convert
    functional form control flow v2 nodes 'StatelessWhile' and 'While'.

    For every 'StatelessWhile' / 'While' node found in ``fn``:

    * the node's inputs are packed into a single ``make_tuple`` node in
      ``fn``, and every consumer of the node is rewired to read through a
      ``get_tuple`` node instead;
    * the functions named by the node's ``cond`` and ``body`` attributes get
      a function-entry node, one ``get_tuple`` per original input (element
      ``i`` for input ``i``) replacing that input, and a ``make_tuple`` ->
      ``return`` pair collecting their outputs. A body input that is also
      returned directly has its entry in ``body_fn.outputs`` renamed to the
      new ``get_tuple`` node.

    ``fn`` and the cond/body functions are modified in place; nothing is
    returned.

    Parameters
    ----------
    tf_ssa: NetworkEnsemble
        An object that contains multiple functions / sub-graphs.
    fn: SSAFunction
        Function that contains graph to operate on.

    Raises
    ------
    NotImplementedError
        If the body function's ``ret`` map does not contain exactly one key
        ending with the identity name that corresponds to a loop output.

    Example
    -------

    Input:

        Before pass "main" graph:

            [placeholder/args_0] --> [while] --> [identity]

        Before pass "cond" graph:

            [const/mean] -------\
            [placeholder] --> [mean] --> [greater]
            [const/greater/y] --------------/

        Before pass "body" graph:

            [const/sub/y] ------\
            [placeholder] ---> [sub]

    Output:

        After pass "main" graph:

            [placeholder/args_0] --> [make_tuple] --> [while] --> [get_tuple] --> [identity]

        After pass "cond" graph:

                                      [const/mean] ------\
            [entry] -> [get_tuple] -> [placeholder] -> [mean] -> [greater] -> [make_tuple] -> [return]
                                      [const/greater/y] ------------/

        After pass "body" graph:

                                      [const/sub/y] ----\
            [entry] -> [get_tuple] -> [placeholder] -> [sub] -> [make_tuple] -> [return]
    """
    # Iterate over a snapshot: the loop body inserts new nodes into fn.graph.
    for while_name, while_node in fn.graph.copy().items():
        if while_node.op not in {"StatelessWhile", "While"}:
            continue

        cond_fn_name = while_node.attr.get("cond")
        body_fn_name = while_node.attr.get("body")

        msg = "Rewriting '{}' ({}) sub-graphs: cond '{}', body '{}'"
        logging.info(
            msg.format(while_node.name, while_node.op, cond_fn_name,
                       body_fn_name))

        cond_fn = tf_ssa.functions.get(cond_fn_name)
        body_fn = tf_ssa.functions.get(body_fn_name)

        # insert function entry nodes
        cond_entry = _insert_function_entry(cond_fn)
        body_entry = _insert_function_entry(body_fn)

        # pack node inputs to a single tuple
        while_input_tuple = _insert_make_tuple(
            fn, "make_tuple/{}".format(while_name))
        # Iterate over a snapshot so that edits disconnect_edge makes to
        # while_node.inputs can never cause entries to be skipped.
        for wi in list(while_node.inputs):
            disconnect_edge(fn.graph, wi, while_node.name)
            connect_edge(fn.graph, wi, while_input_tuple)
        connect_edge(fn.graph, while_input_tuple, while_node.name)

        # unpack node outputs to multiple get_tuples
        for wo in while_node.outputs:
            # utilize FunctionDef's ret to make sure function outputs and
            # node outputs order matches when multiple outputs are there.
            o_original = fn.graph[wo].original_node
            # The consumer references this node as "<while_name>" (output 0)
            # or "<while_name>:<k>". Match exactly so that a sibling node
            # such as "<while_name>_1" is not mistaken for this one.
            while_input = [
                n for n in o_original.input
                if str(n) == while_name
                or str(n).startswith(while_name + ":")
            ][0]
            _, sep, index_str = str(while_input).partition(":")
            output_index = int(index_str) if sep else 0
            if output_index == 0:  # access identity "0"
                identity_postfix = "identity"
            else:
                identity_postfix = "identity_{}".format(output_index)

            identity_keys = [
                t for t in body_fn.ret.keys() if t.endswith(identity_postfix)
            ]
            if len(identity_keys) != 1:
                raise NotImplementedError("Branch not found.")

            # ret values look like "<node>:<k>"; keep only the node name to
            # locate its position among the body's outputs.
            mapped_name = body_fn.ret[identity_keys[0]].split(":")[0]
            idx = body_fn.outputs.index(mapped_name)

            loop_output = _insert_get_tuple(
                fn, "get_tuple/{}/{}".format(idx, while_input), idx)

            # Preserve the consumer's input slot: remember where the while
            # node sat, then reconnect the get_tuple at the same position.
            edge_idx = fn.graph[wo].inputs.index(while_node.name)
            replace_dest(fn.graph, while_node, wo, loop_output)
            connect_edge_at_index(fn.graph, loop_output, wo, edge_idx)

        # fetch inputs using get_tuple for cond fn
        for i, ci in enumerate(cond_fn.inputs):
            cond_input = _insert_get_tuple(cond_fn,
                                           "get_tuple/{}/{}".format(i, ci), i)
            connect_edge(cond_fn.graph, cond_entry, cond_input)
            replace_node(cond_fn.graph, ci, cond_input)
            delete_node(cond_fn.graph, ci)

        # fetch inputs using get_tuple for body fn
        for i, bi in enumerate(body_fn.inputs):
            new_name = "get_tuple/{}/{}".format(i, bi)

            # The placeholder is deleted below, so an input that is also
            # returned unchanged must be re-pointed at its get_tuple node.
            if bi in body_fn.outputs:  # input is also an output
                body_fn.outputs[body_fn.outputs.index(bi)] = new_name

            body_input = _insert_get_tuple(body_fn, new_name, i)

            connect_edge(body_fn.graph, body_entry, body_input)
            replace_node(body_fn.graph, bi, body_input)
            delete_node(body_fn.graph, bi)

        # returns a tuple of value(s) as output for cond fn
        cond_output = _insert_make_tuple(cond_fn)
        for co in cond_fn.outputs:
            connect_edge(cond_fn.graph, co, cond_output.name)

        cond_return = _insert_return(cond_fn)
        connect_edge(cond_fn.graph, cond_output.name, cond_return.name)

        # returns a tuple of value(s) as output for body branch
        body_output = _insert_make_tuple(body_fn)

        for bo in body_fn.outputs:
            connect_edge(body_fn.graph, bo, body_output.name)

        body_return = _insert_return(body_fn)
        connect_edge(body_fn.graph, body_output.name, body_return.name)
Beispiel #2
0
def _rewrite_cond_functions(tf_ssa, fn):
    r"""
    Rewrite tf.cond's sub-graphs with get_tuple, make_tuple, function_entry and
    return ops. This rewrite is required in order to convert functional form
    control flow v2 nodes 'StatelessIf' and 'If'.

    For every 'StatelessIf' / 'If' node found in ``fn``:

    * all of the node's inputs are packed into a single ``make_tuple`` node
      in ``fn``, and every consumer of the node is rewired to read through a
      ``get_tuple`` node instead;
    * the functions named by the node's ``then_branch`` and ``else_branch``
      attributes get a function-entry node, one ``get_tuple`` per original
      input, and a ``make_tuple`` -> ``return`` pair collecting their
      outputs. A branch's ``i``-th input reads tuple element ``i + 1``,
      skipping element 0, which the example below shows is the predicate.
      A branch output whose node no longer exists in the graph (an input
      returned unchanged, whose placeholder was deleted) is wired from that
      input's ``get_tuple`` node instead.

    ``fn`` and both branch functions are modified in place; nothing is
    returned.

    Parameters
    ----------
    tf_ssa: NetworkEnsemble
        An object that contains multiple functions / sub-graphs.
    fn: SSAFunction
        Function that contains graph to operate on.

    Raises
    ------
    NotImplementedError
        If the then-branch's ``ret`` map does not contain exactly one key
        ending with the identity name that corresponds to a cond output.

    Examples
    --------

    Input:

        Before pass "main" graph:

            [const/greater/y] ---------\
            [placeholder/args_0] -> [greater] -> [if] -> [identity]
                              \------------------/  \--> [identity]
            [placeholder/args_1] ----------------/

        Before pass "then" graph:

            [const/sub/y] ---------------\
            [placeholder/sub_args_0] -> [sub]
            [placeholder/sub_args_1] -> [identity]

        Before pass "else" graph:

            [const/add/y] ---------------\
            [placeholder/add_args_0] -> [add]

            [const/mul/y] ---------------\
            [placeholder/add_args_1] -> [mul]

    Output:

        After pass "main" graph:

            [const/greater/y] ---------\
            [placeholder/args_0] -> [greater] -> [make_tuple] -> [if] -> [get_tuple] -> [identity]
                              \---------------------/               \--> [get_tuple] -> [identity]
            [placeholder/args_1] -------------------/

        After pass "then" graph:

                                      [const/sub/y] ---------------\
            [entry] -> [get_tuple] -> [placeholder/sub_args_0] -> [sub] -> [make_tuple] -> [return]
                    -> [get_tuple] -> [placeholder/sub_args_1] -----------------/

        After pass "else" graph:

                                      [const/add/y] ---------------\
            [entry] -> [get_tuple] -> [placeholder/add_args_0] -> [add] -> [make_tuple] -> [return]
                    -> [get_tuple] -> [placeholder/add_args_1] -> [mul] --------/
                                      [const/mul/y] ---------------/

    """
    # Iterate over a snapshot: the loop body inserts new nodes into fn.graph.
    for cond_name, cond_node in fn.graph.copy().items():
        if cond_node.op not in {"StatelessIf", "If"}:
            continue

        then_fn_name = cond_node.attr.get("then_branch")
        else_fn_name = cond_node.attr.get("else_branch")

        msg = "Rewriting '{}' ({}) sub-graphs: then '{}', else '{}'"
        logging.info(
            msg.format(cond_node.name, cond_node.op, then_fn_name,
                       else_fn_name))

        then_fn = tf_ssa.functions.get(then_fn_name)
        else_fn = tf_ssa.functions.get(else_fn_name)

        # insert function entry nodes
        then_entry = _insert_function_entry(then_fn)
        else_entry = _insert_function_entry(else_fn)

        # pack node inputs to a single tuple
        cond_input = _insert_make_tuple(fn, "make_tuple/{}".format(cond_name))
        # Iterate over a snapshot so that edits disconnect_edge makes to
        # cond_node.inputs can never cause entries to be skipped.
        for ci in list(cond_node.inputs):
            disconnect_edge(fn.graph, ci, cond_node.name)
            connect_edge(fn.graph, ci, cond_input)
        connect_edge(fn.graph, cond_input, cond_node.name)

        # unpack node outputs to multiple get_tuples
        for i, co in enumerate(cond_node.outputs):
            # utilize FunctionDef's ret to make sure function outputs and
            # node outputs order matches when multiple outputs are there.
            # Fallback to use original cond_node.outputs order if fails.
            o_original = fn.graph[co].original_node
            if o_original:
                # The consumer references this node as "<cond_name>"
                # (output 0) or "<cond_name>:<k>". Match exactly so that a
                # sibling node such as "<cond_name>_1" is not mistaken for
                # this one.
                c_input = [
                    n for n in o_original.input
                    if str(n) == cond_name
                    or str(n).startswith(cond_name + ":")
                ][0]
                _, sep, index_str = str(c_input).partition(":")
                output_index = int(index_str) if sep else 0
                if output_index == 0:  # access identity "0"
                    identity_postfix = "identity"
                else:
                    identity_postfix = "identity_{}".format(output_index)

                identity_keys = [
                    t for t in then_fn.ret.keys()
                    if t.endswith(identity_postfix)
                ]
                if len(identity_keys) != 1:
                    raise NotImplementedError("Branch not found.")

                # ret values look like "<node>:<k>"; keep only the node name.
                mapped_name = then_fn.ret[identity_keys[0]].split(":")[0]

                if mapped_name in then_fn.outputs:
                    idx = then_fn.outputs.index(mapped_name)
                else:  # in else_fn.outputs
                    idx = else_fn.outputs.index(mapped_name)
            else:
                idx = i

            cond_output = _insert_get_tuple(
                fn, "get_tuple/{}/{}".format(idx, cond_name), idx)
            # Preserve the consumer's input slot: remember where the cond
            # node sat, then reconnect the get_tuple at the same position.
            edge_idx = fn.graph[co].inputs.index(cond_node.name)
            replace_dest(fn.graph, cond_node, co, cond_output)
            connect_edge_at_index(fn.graph, cond_output, co, edge_idx)

        # fetch inputs using get_tuple for then branch; element 0 of the
        # packed tuple is skipped, hence the "i + 1".
        for i, ti in enumerate(then_fn.inputs):
            then_input = _insert_get_tuple(then_fn,
                                           "get_tuple/{}/{}".format(i,
                                                                    ti), i + 1)
            connect_edge(then_fn.graph, then_entry, then_input)
            replace_node(then_fn.graph, ti, then_input)
            delete_node(then_fn.graph, ti)

        # fetch inputs using get_tuple for else branch; element 0 of the
        # packed tuple is skipped, hence the "i + 1".
        for i, ei in enumerate(else_fn.inputs):
            else_input = _insert_get_tuple(else_fn,
                                           "get_tuple/{}/{}".format(i,
                                                                    ei), i + 1)
            connect_edge(else_fn.graph, else_entry, else_input)
            replace_node(else_fn.graph, ei, else_input)
            delete_node(else_fn.graph, ei)

        # returns a tuple of value(s) as output for then branch
        then_output = _insert_make_tuple(then_fn)
        for to in then_fn.outputs:
            if to not in then_fn.graph.keys():
                # from identity, map back to get_tuple node
                to = "get_tuple/{}/{}".format(then_fn.inputs.index(to), to)
            connect_edge(then_fn.graph, to, then_output.name)

        then_return = _insert_return(then_fn)
        connect_edge(then_fn.graph, then_output.name, then_return.name)

        # returns a tuple of value(s) as output for else branch
        else_output = _insert_make_tuple(else_fn)
        for eo in else_fn.outputs:
            if eo not in else_fn.graph.keys():
                # from identity, map back to get_tuple node
                eo = "get_tuple/{}/{}".format(else_fn.inputs.index(eo), eo)
            connect_edge(else_fn.graph, eo, else_output.name)

        else_return = _insert_return(else_fn)
        connect_edge(else_fn.graph, else_output.name, else_return.name)