def test_input_elementwise_sum_node():
    """Check that InputElementwiseSumNode adds up everything sent to it.

    Three input branches each forward their value to the "ies" node under
    a distinct key; the node's output is compared against the plain
    elementwise sum of the three inputs. Runs once with a scalar shape
    ``()`` and once with the rank-3 shape ``(3, 4, 5)``.
    """
    for s in [(), (3, 4, 5)]:
        summing_node = tn.InputElementwiseSumNode("ies")
        first_branch = tn.SequentialNode(
            "seq1",
            [tn.InputNode("i1", shape=s),
             tn.SendToNode("st1", reference="ies", to_key="in1")])
        second_branch = tn.SequentialNode(
            "seq2",
            [tn.InputNode("i2", shape=s),
             tn.SendToNode("st2", reference="ies", to_key="in2")])
        third_branch = tn.SequentialNode(
            "seq3",
            [tn.InputNode("i3", shape=s),
             tn.SendToNode("st3", reference="ies", to_key="in3")])
        root = tn.ContainerNode(
            "all",
            [summing_node, first_branch, second_branch, third_branch])
        network = root.network()
        fn = network.function(["i1", "i2", "i3"], ["ies"])

        # Draw the three random inputs in the same order as the input names.
        i1, i2, i3 = [np.array(np.random.rand(*s), dtype=fX)
                      for _ in range(3)]
        expected = i1 + i2 + i3
        actual = fn(i1, i2, i3)[0]
        np.testing.assert_allclose(expected, actual, rtol=1e-5)
def test_container_node_raises():
    """Build and call a network that reads the output of a ContainerNode.

    A SequentialNode chains an empty ContainerNode "c" into an IdentityNode
    "i"; a function with no inputs and output "i" is then compiled and
    called.

    NOTE(review): the test name implies an exception is expected, but no
    expectation (decorator or assert_raises) is visible in this block --
    confirm how the expected failure is declared.
    """
    empty_container = tn.ContainerNode("c", [])
    identity = tn.IdentityNode("i")
    root = tn.SequentialNode("s", [empty_container, identity])
    network = root.network()
    compiled = network.function([], ["i"])
    compiled()
def fcn_network(combine_fn):
    """Return a compiled function for a two-input InputFunctionCombineNode.

    Two (3, 4, 5)-shaped inputs, "in1" and "in2", are each sent to the
    "fcn" node, which is constructed with ``combine_fn``.

    Args:
        combine_fn: passed through unchanged as the ``combine_fn`` argument
            of ``tn.InputFunctionCombineNode``.

    Returns:
        The result of ``network.function(["in1", "in2"], ["fcn"])``.
    """
    input_shape = (3, 4, 5)
    # The keys are deliberately crossed: "in1" arrives under key "b" and
    # "in2" under key "a".
    branch_to_b = tn.SequentialNode(
        "s1",
        [tn.InputNode("in1", shape=input_shape),
         tn.SendToNode("stn1", reference="fcn", to_key="b")])
    branch_to_a = tn.SequentialNode(
        "s2",
        [tn.InputNode("in2", shape=input_shape),
         tn.SendToNode("stn2", reference="fcn", to_key="a")])
    combiner = tn.InputFunctionCombineNode("fcn", combine_fn=combine_fn)
    root = tn.ContainerNode("c", [branch_to_b, branch_to_a, combiner])
    network = root.network()
    return network.function(["in1", "in2"], ["fcn"])
def test_node_with_generated_children_can_serialize():
    """Check that a built node tree survives a to-data / from-data round trip.

    The tree is a chain of SendToNodes (stn1 -> s2, stn2 -> stn3, stn3 -> i).
    After the network has been built, the root node is converted to data
    and back, and the result must compare equal to the original root.
    """
    source_branch = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    relay_branch = tn.SequentialNode(
        "s2", [tn.SendToNode("stn2", reference="stn3")])
    final_branch = tn.SequentialNode(
        "s3", [tn.SendToNode("stn3", reference="i")])
    sink = tn.IdentityNode("i")
    root_node = tn.ContainerNode(
        "c", [source_branch, relay_branch, final_branch, sink])

    # Build first, so the round trip is exercised on an already-built tree.
    root_node.network().build()

    serialized = treeano.core.node_to_data(root_node)
    restored = treeano.core.node_from_data(serialized)
    nt.assert_equal(root_node, restored)
def test_network_doesnt_mutate():
    """Check that building a network leaves the root node's state untouched.

    A deep copy of the root node's ``__dict__`` is taken before the network
    is built and compared against the live ``__dict__`` afterwards.
    """
    source_branch = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    relay_branch = tn.SequentialNode(
        "s2", [tn.SendToNode("stn2", reference="stn3")])
    final_branch = tn.SequentialNode(
        "s3", [tn.SendToNode("stn3", reference="i")])
    sink = tn.IdentityNode("i")
    root_node = tn.ContainerNode(
        "c", [source_branch, relay_branch, final_branch, sink])

    # Snapshot must be a deep copy, otherwise in-place mutation of nested
    # attributes would go unnoticed.
    state_before = copy.deepcopy(root_node.__dict__)
    built_network = root_node.network()
    built_network.build()
    nt.assert_equal(state_before, root_node.__dict__)
def test_send_to_node():
    """Check that a value forwarded through chained SendToNodes is unchanged.

    The input "in" is sent via stn1 -> s2, stn2 -> stn3, stn3 -> i; the
    output read from the IdentityNode "i" must match the input array.
    """
    source_branch = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    relay_branch = tn.SequentialNode(
        "s2", [tn.SendToNode("stn2", reference="stn3")])
    final_branch = tn.SequentialNode(
        "s3", [tn.SendToNode("stn3", reference="i")])
    sink = tn.IdentityNode("i")
    root = tn.ContainerNode(
        "c", [source_branch, relay_branch, final_branch, sink])
    network = root.network()

    fn = network.function(["in"], ["i"])
    x = np.random.randn(3, 4, 5).astype(fX)
    forwarded = fn(x)[0]
    np.testing.assert_allclose(forwarded, x)