def test_const_inp_but_non_foldable_nested_graph(self):
    """An `If` whose only direct input is constant must not be folded when
    its subgraphs read a variable from the enclosing graph.

    Folding with ``error_ok=False`` must succeed (the `If` node is skipped),
    and both branch subgraphs must keep their original nodes.
    """
    condition = gs.Constant("cond", values=np.array(True))
    outer_var = gs.Variable("X", dtype=np.float32, shape=(1, ))
    graph = Graph(inputs=[outer_var])

    # Then-branch: a single Add that consumes the outer-scope variable.
    then_branch = Graph(name="Then")
    then_branch.outputs = [then_branch.add(outer_var, outer_var)]

    # Else-branch: two chained Adds, both consuming the outer-scope variable.
    # The inner Add is built first, matching the original evaluation order.
    else_branch = Graph(name="Else")
    doubled = else_branch.add(outer_var, outer_var)
    else_branch.outputs = [else_branch.add(outer_var, doubled)]

    # All *direct* inputs of the If node are constant, which makes it look
    # foldable, but its subgraphs depend on `outer_var` from the outer scope.
    graph.outputs = [graph.if_op(condition, then_branch, else_branch)]

    # Must not raise: the If node has to be excluded from constant folding.
    graph.fold_constants(error_ok=False).cleanup()

    assert graph.nodes[0].op == "If"
    assert len(then_branch.nodes) == 1
    assert len(else_branch.nodes) == 2
def test_with_nested_graph(self):
    """Constant folding must also fold constant sub-expressions inside the
    subgraphs of an `If` node.

    - Then-branch: ``Y + Y`` is fully constant, so it folds to a Constant
      output and the branch is left with no nodes.
    - Else-branch: ``X + (Y + Y)`` depends on the graph input ``X``, so only
      the inner ``Y + Y`` folds. One Add remains, and its second input is the
      folded Constant.
    """
    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24, where it
    # raises AttributeError. `np.bool_` gives the same `np.dtype('bool')`.
    cond = gs.Variable("cond", dtype=np.bool_, shape=(1, ))
    X = gs.Variable("X", dtype=np.float32, shape=(1, ))
    Y = gs.Constant("Y", values=np.ones((1, ), dtype=np.float32))
    graph = Graph(inputs=[X, cond])

    then_graph = Graph(name="Then")
    then_graph.outputs = [then_graph.add(Y, Y)]

    else_graph = Graph(name="Else")
    else_graph.outputs = [else_graph.add(X, else_graph.add(Y, Y))]

    graph.outputs = [graph.if_op(cond, then_graph, else_graph)]

    graph.fold_constants()
    graph.cleanup()

    # Then-branch collapsed entirely to the constant Y * 2.
    assert len(then_graph.nodes) == 0
    assert np.all(then_graph.outputs[0].values == (Y.values * 2))

    # Else-branch keeps the outer Add; its folded operand is a Constant.
    assert len(else_graph.nodes) == 1
    assert isinstance(else_graph.nodes[0].inputs[1], Constant)
    assert np.all(else_graph.nodes[0].inputs[1].values == (Y.values * 2))