示例#1
0
def partition(comp_arrow: CompositeArrow) -> List[Set[Arrow]]:
    """Partition the sub_arrows of comp_arrow into sequential layers.

    Each sub_arrow starts with a count equal to its number of in_ports. The
    count is decremented once for every edge reaching it from an in_port of
    comp_arrow or from an out_port of a sub_arrow in an earlier layer. A
    sub_arrow joins the next layer once its count reaches 0, so the layers can
    be evaluated in list order.

    Args:
        comp_arrow: Composite arrow whose sub_arrows are partitioned
    Returns:
        List of sets of sub_arrows, one set per layer, in evaluation order
    Raises:
        ValueError: if some sub_arrows can never become ready (their count
            never reaches 0, e.g. an unfed in_port)
    """
    partition_arrows = []
    # Priority = number of inputs of each sub_arrow not yet produced.
    # pqdict is a min-priority queue by default, so topitem() yields the
    # sub_arrow with the fewest missing inputs.
    arrow_colors = pqdict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        arrow_colors[sub_arrow] = sub_arrow.num_in_ports()

    # The inputs of the composition are available from the start.
    for port in comp_arrow.in_ports():
        for in_port in comp_arrow.neigh_in_ports(port):
            # An in_port wired straight to an out_port of comp_arrow feeds no
            # sub_arrow; skip it, exactly like the layer loop below does.
            if in_port.arrow != comp_arrow:
                assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
                arrow_colors[in_port.arrow] -= 1

    while arrow_colors:
        # Pop every sub_arrow whose inputs are all available.
        arrow_layer = set()
        while arrow_colors and arrow_colors.topitem()[1] == 0:
            sub_arrow, _ = arrow_colors.popitem()
            arrow_layer.add(sub_arrow)

        if not arrow_layer:
            # Nothing is ready and an empty layer decrements nothing, so the
            # state would never change; fail instead of looping forever.
            raise ValueError(
                "Cannot partition %s: %d sub_arrows have inputs that are "
                "never produced" % (comp_arrow.name, len(arrow_colors)))

        partition_arrows.append(arrow_layer)

        # Outputs of this layer are now available to downstream sub_arrows.
        for arrow in arrow_layer:
            for out_port in arrow.out_ports():
                for in_port in comp_arrow.neigh_in_ports(out_port):
                    # Edges to comp_arrow's own out_ports feed no sub_arrow.
                    if in_port.arrow != comp_arrow:
                        assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
                        arrow_colors[in_port.arrow] -= 1

    return partition_arrows
示例#2
0
def filter_arrows(fil, arrow: CompositeArrow, deep=True):
    """Collect the sub_arrows of `arrow` accepted by the predicate `fil`.

    Every direct sub_arrow is tested with `fil`. When `deep` is true, each
    sub_arrow that is itself a CompositeArrow is also searched recursively,
    after it has been tested.

    Args:
        fil: Predicate called on each sub_arrow; truthy means keep
        arrow: Composite arrow to search
        deep: Whether to descend into nested CompositeArrows
    Returns:
        Set of all accepted sub_arrows
    """
    matches = set()
    for child in arrow.get_sub_arrows():
        if fil(child):
            matches.add(child)
        descend = deep and isinstance(child, CompositeArrow)
        if descend:
            matches.update(filter_arrows(fil, child, deep=deep))
    return matches
示例#3
0
def gen_arrow_colors(comp_arrow: CompositeArrow):
    """Build the initial scheduling priority queue for comp_arrow.

    Each sub_arrow of comp_arrow is keyed by its number of in_ports, i.e. the
    number of its inputs that have not been 'seen' yet. comp_arrow itself is
    also added, keyed by its number of out_ports.

    Args:
        comp_arrow: Composite Arrow
    Returns:
        arrow_colors: pqdict mapping each sub_arrow (and comp_arrow) to its
            count of not-yet-seen inputs
    """
    # Priority is the number of inputs each arrow has which have NOT yet been
    # 'seen'. Presumably callers decrement it as inputs become available
    # (inputs to the composition, or outputs of arrows already processed),
    # and an arrow is ready once its priority reaches 0.
    arrow_colors = pqdict()  # type: MutableMapping[Arrow, int]
    for sub_arrow in comp_arrow.get_sub_arrows():
        arrow_colors[sub_arrow] = sub_arrow.num_in_ports()

    # TODO: Unify
    # comp_arrow is counted by its out_ports; presumably its outputs are
    # tracked like the inputs of a final sink arrow -- confirm with callers.
    arrow_colors[comp_arrow] = comp_arrow.num_out_ports()
    return arrow_colors
示例#4
0
def inner_invert(comp_arrow: CompositeArrow, port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of a composite arrow.

    The inverse first gets one port per port of comp_arrow, created in the
    same order:
    - Ports marked constant in port_attr keep their direction.
    - Non-constant in_ports become out_ports, and non-constant out_ports
      become in_ports.

    Each sub_arrow is then inverted with invert_sub_arrow, and every edge of
    comp_arrow is rewired between the inverted pieces. After `transforms` are
    applied, every param port and error port on the inverted sub_arrows is
    exposed as a new in_port or out_port, respectively, of the inverse.

    Args:
        comp_arrow: Composite arrow to invert
        port_attr: Port attributes, passed to is_constant, get_port_shape and
            invert_sub_arrow; its 'shape' entry (if any) is copied to the
            corresponding inverse port
        dispatch: Dict mapping arrow class to invert function; passed through
            to invert_sub_arrow
    Returns:
        (inv_comp_arrow, comp_port_map)
        - inv_comp_arrow: the (approximate) parametric inverse of comp_arrow.
        - comp_port_map: the identity map {i: i} over comp_arrow's port
          indices, i.e. port i of comp_arrow corresponds to port i of
          inv_comp_arrow.
        Error ports and param ports are added after those ports."""
    # Empty composition for the inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)

    # Add a port on the inverse arrow for every port on comp_arrow, in the
    # same order. NOTE(review): comp_port_map below assumes ports() yields
    # ports in index order and add_port appends with the next index -- confirm.
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        if is_in_port(port):
            # Constant inputs stay inputs; other inputs become outputs
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            # Constant outputs stay outputs; other outputs become inputs
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information
        # FIXME: Which port_attr should be transferred from port to inv_port?
        # Only 'shape' is copied; a missing shape is reported, not fatal.
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))

    # Invert each sub_arrow, remembering its inverse and the port_map that
    # get_inv_port uses to find the inverse counterpart of each port
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow, port_attr,
                                                   dispatch)
        # NOTE(review): sub_arrow was already passed to invert_sub_arrow
        # above, so this check comes after its first use.
        assert sub_arrow is not None
        # The inverse must not belong to any composition yet
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow

    # Add comp_arrow itself, so edges touching comp_arrow's own ports resolve
    # to the matching ports of inv_comp_arrow (identity map, see above)
    assert comp_arrow is not None
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map

    # Then, rewire all the edges. Inversion may flip an edge's direction,
    # so determine which end projects and which receives in the inverse
    # instead of assuming out_port -> in_port.
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        both = [left_inv_port, right_inv_port]
        projecting = list(
            filter(lambda x: would_project(x, inv_comp_arrow), both))
        receiving = list(
            filter(lambda x: would_receive(x, inv_comp_arrow), both))
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])

    # NOTE(review): `transforms` is neither a parameter nor a local here;
    # presumably a module-level sequence of callables defined elsewhere in
    # this file -- confirm, otherwise this line raises NameError.
    for transform in transforms:
        transform(inv_comp_arrow)

    # Create new ports on the inverse composition for parametric and error
    # ports. Both kinds must still be unconnected at this point.
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                # A param port is fed by a new in_port of the inverse
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                # An error port drives a new out_port of the inverse
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)

    return inv_comp_arrow, comp_port_map