def _release_neighbours(comp_arrow, ports, arrow_colors) -> None:
    """Decrement the pending-input count of every sub_arrow fed by ``ports``.

    In-ports that belong to comp_arrow itself are skipped, because comp_arrow
    is not tracked in arrow_colors. (For example, one of comp_arrow's own
    outputs wired straight from one of its inputs; presumably this is how
    neigh_in_ports reports them. Confirm.)
    """
    for port in ports:
        for in_port in comp_arrow.neigh_in_ports(port):
            if in_port.arrow != comp_arrow:
                assert in_port.arrow in arrow_colors, "sub_arrow not in arrow_colors"
                arrow_colors[in_port.arrow] -= 1


def partition(comp_arrow: CompositeArrow) -> List[Set[Arrow]]:
    """Partition the sub_arrows of comp_arrow into sequential layers.

    Layer 0 holds every sub_arrow whose in-ports are all fed directly by
    comp_arrow's own in-ports (or that has no in-ports). Layer k+1 holds every
    remaining sub_arrow whose in-ports are all fed by arrows in layers 0..k.

    Args:
        comp_arrow: the composite arrow whose sub_arrows are layered.

    Returns:
        The layers, as sets of sub_arrows, in dependency order. The list is
        empty if comp_arrow has no sub_arrows.

    Raises:
        ValueError: if some sub_arrows can never become ready. This happens
            when the wiring contains a cycle, or a sub_arrow has an in-port
            with no incoming wire.
    """
    partition_arrows = []
    # Priority = number of in-ports of each sub_arrow still waiting for a value.
    arrow_colors = pqdict()
    for sub_arrow in comp_arrow.get_sub_arrows():
        arrow_colors[sub_arrow] = sub_arrow.num_in_ports()

    # Values supplied from outside comp_arrow are available immediately.
    _release_neighbours(comp_arrow, comp_arrow.in_ports(), arrow_colors)

    while arrow_colors:
        # pqdict pops the lowest priority first, so every ready (zero-count)
        # arrow sits at the top of the queue; drain them all into one layer.
        arrow_layer = set()
        while arrow_colors and arrow_colors.topitem()[1] == 0:
            ready_arrow, _ = arrow_colors.popitem()
            arrow_layer.add(ready_arrow)

        if not arrow_layer:
            # Arrows remain but none is ready. Nothing would ever change,
            # so without this check the outer loop would never terminate.
            stuck_arrow, pending = arrow_colors.topitem()
            raise ValueError(
                "Cannot partition {}: {} still has {} pending input(s); the "
                "sub_arrow wiring has a cycle or an unconnected in-port".format(
                    comp_arrow, stuck_arrow, pending))

        partition_arrows.append(arrow_layer)
        # Outputs of this layer are now available to the arrows they feed.
        # The decrements happen only after the whole layer is drained, so
        # arrows within one layer never depend on each other.
        _release_neighbours(
            comp_arrow,
            [out_port for arrow in arrow_layer for out_port in arrow.out_ports()],
            arrow_colors)

    return partition_arrows
def inner_interpret(conv: Callable,
                    comp_arrow: CompositeArrow,
                    inputs: List,
                    arrow_colors: MutableMapping[Arrow, int],
                    arrow_inputs: MutableMapping[Arrow, Dict[int, Any]],
                    state: Dict,
                    port_grab: Dict[Port, Any]) -> List:
    """Interpret comp_arrow by applying ``conv`` to each sub_arrow in dependency order.

    Args:
        conv: called as ``conv(sub_arrow, sub_inputs, state)``. It must return
            one value per out-port of sub_arrow.
        comp_arrow: the composite arrow being interpreted.
        inputs: only used to check arity against comp_arrow.num_in_ports().
            The input values themselves are read from arrow_inputs. Presumably
            the caller has already seeded them there; confirm.
        arrow_colors: maps each arrow to the number of in-ports still waiting
            for a value. popitem() must yield an arrow with the fewest pending
            inputs (assumes a min-priority queue such as pqdict; confirm).
            Values are also routed to comp_arrow's entry, so comp_arrow is
            expected to be a key. Emptied by this call.
        arrow_inputs: maps arrow -> {in-port index: value}. It is filled as
            outputs are propagated, and arrow_inputs[comp_arrow] collects
            comp_arrow's outputs.
        state: passed unchanged to every conv call.
        port_grab: for each port key whose value has been produced, the entry
            is overwritten with that value.

    Returns:
        The values that reached comp_arrow's out-ports, ordered by port index.
    """
    assert len(inputs) == comp_arrow.num_in_ports(), "wrong # inputs"

    while arrow_colors:
        sub_arrow, priority = arrow_colors.popitem()
        if sub_arrow is comp_arrow:
            # comp_arrow's entry only collects final outputs; nothing to convert.
            continue
        assert priority == 0, "Must resolve {} more inputs to {} first".format(
            priority, sub_arrow)

        # In-port values ordered by port index.
        sub_inputs = [arrow_inputs[sub_arrow][i]
                      for i in sorted(arrow_inputs[sub_arrow])]
        outputs = conv(sub_arrow, sub_inputs, state)
        out_ports = sub_arrow.out_ports()
        assert len(outputs) == len(out_ports), "diff num outputs"

        # Route each output to every in-port it feeds, and mark one fewer
        # pending input on the receiving arrow. When the receiver is
        # comp_arrow itself, the value lands in arrow_inputs[comp_arrow],
        # which is where the final outputs are read from below.
        for out_port, output in zip(out_ports, outputs):
            for neigh_in_port in comp_arrow.neigh_in_ports(out_port):
                neigh_arrow = neigh_in_port.arrow
                arrow_colors[neigh_arrow] -= 1
                arrow_inputs[neigh_arrow][neigh_in_port.index] = output

    # Copy out any intermediate values the caller asked for via port_grab
    # (a hack for inspecting internal ports).
    for port in port_grab:
        if port.arrow in arrow_inputs and port.index in arrow_inputs[port.arrow]:
            port_grab[port] = arrow_inputs[port.arrow][port.index]

    outputs_dict = arrow_inputs[comp_arrow]
    return [outputs_dict[i] for i in sorted(outputs_dict)]