def partition(comp_arrow: CompositeArrow) -> List[Set[Arrow]]:
    """Partitions the comp_arrow into sequential layers of its sub_arrows.

    Every sub-arrow starts with a priority equal to its number of in_ports.
    That priority is decremented once for each neighbouring in_port reached:
    first from the in_ports of `comp_arrow`, then from the out_ports of each
    completed layer. A layer is the set of sub-arrows whose priority has
    reached 0.

    Args:
        comp_arrow: Composite arrow whose sub-arrows are partitioned

    Returns:
        List of layers (sets of sub-arrows), in the order they became ready

    Raises:
        ValueError: if sub-arrows remain but none of them is ready (e.g. a
            cycle, or an in_port that nothing feeds). Without this check the
            loop would never terminate.
    """
    colors = pqdict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        colors[sub_arrow] = sub_arrow.num_in_ports()

    def _mark_seen(in_ports):
        # Ports of comp_arrow itself (its out_ports) are not in `colors`;
        # skip them both for the composition's inputs and for layer outputs.
        for in_port in in_ports:
            if in_port.arrow != comp_arrow:
                assert in_port.arrow in colors, "sub_arrow not in arrow_colors"
                colors[in_port.arrow] -= 1

    for port in comp_arrow.in_ports():
        _mark_seen(comp_arrow.neigh_in_ports(port))

    partition_arrows = []
    while len(colors) > 0:
        arrow_layer = set()
        # pqdict is a min-priority queue: pop every arrow whose priority is 0
        while len(colors) > 0 and colors.topitem()[1] == 0:
            ready_arrow, _ = colors.popitem()
            arrow_layer.add(ready_arrow)
        if not arrow_layer:
            raise ValueError(
                "Cannot partition %s: %d sub_arrows remain but none is ready "
                "(cycle or unconnected in_port?)" % (comp_arrow, len(colors)))
        partition_arrows.append(arrow_layer)
        for arrow in arrow_layer:
            for out_port in arrow.out_ports():
                _mark_seen(comp_arrow.neigh_in_ports(out_port))
    return partition_arrows
def filter_arrows(fil, arrow: CompositeArrow, deep=True):
    """Collect the sub-arrows of `arrow` for which `fil` is truthy.

    Args:
        fil: Predicate called on each sub-arrow
        arrow: Composite arrow whose sub-arrows are examined
        deep: When true, also descend into sub-arrows that are themselves
            CompositeArrows (whether or not they matched `fil`)

    Returns:
        Set of matching arrows
    """
    matches = set()
    for child in arrow.get_sub_arrows():
        if fil(child):
            matches.add(child)
        if deep and isinstance(child, CompositeArrow):
            matches.update(filter_arrows(fil, child, deep=deep))
    return matches
def gen_arrow_colors(comp_arrow: CompositeArrow):
    """Build the initial priority queue ("colors") for `comp_arrow`.

    Each sub-arrow is given a priority equal to its number of in_ports, and
    `comp_arrow` itself is given a priority equal to its number of out_ports.

    Args:
        comp_arrow: Composite Arrow

    Returns:
        arrow_colors: pqdict mapping each arrow to its initial priority
    """
    # The priority starts at the arrow's input count; callers presumably
    # decrement it as each input is 'seen' (an input to the composition, or
    # an output of an arrow already processed) -- confirm against callers.
    colors = pqdict()  # type: MutableMapping[Arrow, int]
    for child in comp_arrow.get_sub_arrows():
        colors[child] = child.num_in_ports()
    # TODO: Unify
    colors[comp_arrow] = comp_arrow.num_out_ports()
    return colors
def inner_invert(comp_arrow: CompositeArrow,
                 port_attr: PortAttributes,
                 dispatch: Dict[Arrow, Callable]):
    """Construct a parametric inverse of `comp_arrow`.

    Args:
        comp_arrow: Composite arrow to invert
        port_attr: Port attributes; consulted for constness (`is_constant`)
            and shape of comp_arrow's ports, and passed to `invert_sub_arrow`
        dispatch: Passed to `invert_sub_arrow` for every sub-arrow (documented
            previously as mapping arrow class to invert function)

    Returns:
        (inv_comp_arrow, comp_port_map):
        - inv_comp_arrow: an (approximate) parametric inverse. Its ith port
          corresponds to the ith port of comp_arrow. Constant ports keep
          their direction; other in_ports become out_ports and vice versa.
          Param ports (in_ports) and error ports (out_ports) of its
          sub-arrows are exposed as additional ports appended after those.
        - comp_port_map: identity map from port index of comp_arrow to port
          index of inv_comp_arrow
    """
    # Empty composition for inverse
    inv_comp_arrow = CompositeArrow(name="%s_inv" % comp_arrow.name)
    _add_mirrored_ports(comp_arrow, inv_comp_arrow, port_attr)

    arrow_to_inv, arrow_to_port_map = _invert_sub_arrows(
        comp_arrow, port_attr, dispatch)
    # comp_arrow maps onto inv_comp_arrow port-for-port (same index order,
    # because _add_mirrored_ports added one inverse port per original port)
    arrow_to_inv[comp_arrow] = inv_comp_arrow
    comp_port_map = {i: i for i in range(comp_arrow.num_ports())}
    arrow_to_port_map[comp_arrow] = comp_port_map

    _rewire_inverse_edges(comp_arrow, inv_comp_arrow, arrow_to_inv,
                          arrow_to_port_map)

    # NOTE(review): `transforms` is not a parameter; presumably a module-level
    # sequence of callables defined elsewhere in this file -- confirm.
    for transform in transforms:
        transform(inv_comp_arrow)

    _expose_param_and_error_ports(inv_comp_arrow)
    return inv_comp_arrow, comp_port_map


def _add_mirrored_ports(comp_arrow, inv_comp_arrow, port_attr):
    """Add one port to inv_comp_arrow per port of comp_arrow, flipping direction of non-constant ports."""
    for port in comp_arrow.ports():
        inv_port = inv_comp_arrow.add_port()
        if is_in_port(port):
            if is_constant(port, port_attr):
                make_in_port(inv_port)
            else:
                make_out_port(inv_port)
        elif is_out_port(port):
            if is_constant(port, port_attr):
                make_out_port(inv_port)
            else:
                make_in_port(inv_port)
        # Transfer port information
        # FIXME: What port_attr go transfered from port to inv_port
        if 'shape' not in port_attr[port]:
            print('WARNING: shape unknown for %s' % port)
        else:
            set_port_shape(inv_port, get_port_shape(port, port_attr))


def _invert_sub_arrows(comp_arrow, port_attr, dispatch):
    """Invert every sub-arrow; return (arrow -> inverse, arrow -> port map)."""
    arrow_to_inv = dict()
    arrow_to_port_map = dict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        assert sub_arrow is not None
        inv_sub_arrow, port_map = invert_sub_arrow(sub_arrow, port_attr,
                                                   dispatch)
        # Each inverse must be free-standing so it can be wired into
        # inv_comp_arrow
        assert inv_sub_arrow.parent is None
        arrow_to_port_map[sub_arrow] = port_map
        arrow_to_inv[sub_arrow] = inv_sub_arrow
    return arrow_to_inv, arrow_to_port_map


def _rewire_inverse_edges(comp_arrow, inv_comp_arrow, arrow_to_inv,
                          arrow_to_port_map):
    """Recreate each edge of comp_arrow between the corresponding inverse ports."""
    for out_port, in_port in comp_arrow.edges.items():
        left_inv_port = get_inv_port(out_port, arrow_to_port_map, arrow_to_inv)
        right_inv_port = get_inv_port(in_port, arrow_to_port_map, arrow_to_inv)
        both = [left_inv_port, right_inv_port]
        # Exactly one end must project and the other receive; the edge is
        # drawn from the projecting end, whichever original side that is.
        projecting = [p for p in both if would_project(p, inv_comp_arrow)]
        receiving = [p for p in both if would_receive(p, inv_comp_arrow)]
        assert len(projecting) == 1, "Should be only 1 projecting"
        assert len(receiving) == 1, "Should be only 1 receiving"
        inv_comp_arrow.add_edge(projecting[0], receiving[0])


def _expose_param_and_error_ports(inv_comp_arrow):
    """Create new ports on the inverse composition for parametric and error ports."""
    for sub_arrow in inv_comp_arrow.get_sub_arrows():
        for port in sub_arrow.ports():
            if is_param_port(port):
                # A param port must still be unconnected at this point
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                param_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(param_port, port)
                make_in_port(param_port)
                make_param_port(param_port)
            elif is_error_port(port):
                # An error port must still be unconnected at this point
                assert port not in inv_comp_arrow.edges.keys()
                assert port not in inv_comp_arrow.edges.values()
                error_port = inv_comp_arrow.add_port()
                inv_comp_arrow.add_edge(port, error_port)
                make_out_port(error_port)
                make_error_port(error_port)