def propagate_downwards(G: NNGraph):
    """Propagate dimension-name hints from node inputs to outputs through ``G``.

    Nodes are visited in ``G.dfs()`` order. For each node:

    1. If it has input hints (``in_dims_hint``), they are turned into output
       hints (``out_dims_hint``) in a node-type specific way (see
       ``_in_hints_to_out_hints``).
    2. If it then has output hints, each one is copied onto the matching input
       slot of every successor that does not already have a hint there.

    A node that does not want its input hints copied to its outputs should set
    its own ``out_dims_hint`` beforehand.

    Args:
        G: Graph whose nodes' ``in_dims_hint`` / ``out_dims_hint`` attributes
            are updated in place.
    """
    for node in G.dfs():
        if node.in_dims_hint is not None:
            _in_hints_to_out_hints(G, node)
        # Runs after step 1 so hints derived just above are pushed on immediately.
        if node.out_dims_hint is not None:
            _push_out_hints_downstream(G, node)


def _in_hints_to_out_hints(G, node):
    """Derive ``node.out_dims_hint`` (or reshape naming) from ``node.in_dims_hint``."""
    if isinstance(node, ReshapeParameters):
        in_hint = node.in_dims_hint[0]
        # The first input slot may be unhinted; len(None) would raise TypeError.
        # Only apply names when the hint's rank matches the pre-reshape shape.
        if in_hint is not None and len(node.old_shape) == len(in_hint):
            LOG.debug("set reshape %s in dims hint %s", node.name, in_hint)
            node.old_shape.apply_naming_hints(in_hint)
    elif isinstance(node, GlobalPoolParameters):
        # Only a keep_dims pool preserves the input rank, so only then is the
        # input hint valid for the output.
        if node.keep_dims:
            node.out_dims_hint = deepcopy(node.in_dims_hint)
    elif isinstance(node, MatrixBroadcastedLinearOpParameters):
        # Use the longest input hint (first one wins on ties) for the output.
        max_hint = max(
            (hint for hint in node.in_dims_hint if hint is not None),
            key=len,
            default=None)
        if max_hint is not None:
            node.out_dims_hint = [max_hint]
    elif isinstance(node, ConcatParameters):
        # If any incoming edge of the concat has no hint, give it the same
        # hint as the first hinted input.
        any_in_hint = next(
            (hint for hint in node.in_dims_hint if hint is not None), None)
        if any_in_hint:
            LOG.debug("set concat %s in dims hint %s", node.name, any_in_hint)
            for edge in G.in_edges(node.name):
                if not node.in_dims_hint[edge.to_idx]:
                    node.in_dims_hint[edge.to_idx] = any_in_hint
            node.out_dims_hint = [any_in_hint]
    elif node.out_dims_hint is None:
        # Default: outputs inherit the input hints unless already set.
        node.out_dims_hint = deepcopy(node.in_dims_hint)


def _push_out_hints_downstream(G, node):
    """Copy each output hint of ``node`` onto successor input slots lacking one."""
    LOG.debug("propagate down hint from %s", node.name)
    for edge in G.out_edges(node.name):
        hint = node.out_dims_hint[edge.from_idx]
        if hint is None:
            # Nothing to propagate; don't create an empty hint list downstream.
            continue
        to_node = edge.to_node
        if to_node.in_dims_hint is None:
            to_node.in_dims_hint = SparseList()
        # Never overwrite a hint the successor already has for this slot.
        if to_node.in_dims_hint[edge.to_idx] is None:
            to_node.in_dims_hint[edge.to_idx] = hint
def propagate_downwards(G: NNGraph):
    """Propagate dimension-name hints from node inputs to outputs through ``G``.

    Nodes are visited in ``G.dfs()`` order. For each node with input hints
    (``in_dims_hint``):

    * a reshape applies its first input hint as names on ``old_shape``;
    * any other node copies its input hints to ``out_dims_hint`` unless it
      already set its own output hints.

    Each non-missing output hint is then copied onto the matching input slot
    of every successor that does not already have a hint for that slot.

    Args:
        G: Graph whose nodes' ``in_dims_hint`` / ``out_dims_hint`` attributes
            are updated in place.

    Raises:
        AssertionError: if a reshape's input hint rank differs from the length
            of its ``old_shape``.
    """
    for node in G.dfs():
        if node.in_dims_hint is not None:
            if isinstance(node, ReshapeParameters):
                in_hint = node.in_dims_hint[0]
                if in_hint is not None:
                    # Explicit check instead of ``assert`` so it is not stripped
                    # under ``python -O``; AssertionError is kept so callers see
                    # the same exception type as before.
                    if len(node.old_shape) != len(in_hint):
                        raise AssertionError("reshape doesn't match input")
                    node.old_shape.apply_naming_hints(in_hint)
            elif node.out_dims_hint is None:
                # Nodes that want different output hints set them beforehand.
                node.out_dims_hint = deepcopy(node.in_dims_hint)

        # If we have output hints, propagate them to downstream nodes.
        if node.out_dims_hint is not None:
            for edge in G.out_edges(node.name):
                hint = node.out_dims_hint[edge.from_idx]
                # Skip missing hints. Otherwise an empty SparseList would be
                # attached to the successor, making it look hinted (and a
                # downstream reshape would then fail on len(None)).
                if hint is None:
                    continue
                to_node = edge.to_node
                if to_node.in_dims_hint is None:
                    to_node.in_dims_hint = SparseList()
                if to_node.in_dims_hint[edge.to_idx] is None:
                    to_node.in_dims_hint[edge.to_idx] = hint