Exemple #1
0
def test_input_elementwise_sum_node():
    """Check that InputElementwiseSumNode sums every input sent to it.

    Three InputNodes are each wired to the "ies" node by a SendToNode with a
    distinct ``to_key``; the compiled "ies" output is compared against
    ``i1 + i2 + i3`` for a scalar shape ``()`` and for shape ``(3, 4, 5)``.
    """
    # (sequential name, input name, sender name, key used at "ies")
    wiring = [
        ("seq1", "i1", "st1", "in1"),
        ("seq2", "i2", "st2", "in2"),
        ("seq3", "i3", "st3", "in3"),
    ]
    for shape in [(), (3, 4, 5)]:
        summer = tn.InputElementwiseSumNode("ies")
        senders = [
            tn.SequentialNode(
                seq_name,
                [tn.InputNode(in_name, shape=shape),
                 tn.SendToNode(send_name, reference="ies", to_key=key)])
            for seq_name, in_name, send_name, key in wiring
        ]
        network = tn.ContainerNode("all", [summer] + senders).network()
        fn = network.function(["i1", "i2", "i3"], ["ies"])
        # np.random.rand(*()) returns a plain float; np.array makes it 0-d
        i1, i2, i3 = [np.array(np.random.rand(*shape), dtype=fX)
                      for _ in range(3)]
        np.testing.assert_allclose(i1 + i2 + i3,
                                   fn(i1, i2, i3)[0],
                                   rtol=1e-5)
Exemple #2
0
def test_container_node_raises():
    """Consuming the output of a ContainerNode should fail.

    An empty ContainerNode "c" is placed before IdentityNode "i" inside
    SequentialNode "s", so "i" presumably receives the container's output.
    Compiling and calling a function for "i" is expected to raise.
    """
    network = tn.SequentialNode(
        "s",
        [tn.ContainerNode("c", []),
         tn.IdentityNode("i")
         ]).network()
    # The test name promises a failure, but the body previously asserted
    # nothing, so it had no expectation to check.
    # NOTE(review): AssertionError is assumed to be what evaluating a
    # ContainerNode's output raises -- confirm against ContainerNode.
    with nt.assert_raises(AssertionError):
        fn = network.function([], ["i"])
        fn()
Exemple #3
0
 def fcn_network(combine_fn):
     """Compile a network feeding two (3, 4, 5) inputs into node "fcn".

     "in1" is sent to "fcn" under key "b" and "in2" under key "a"; "fcn" is
     an InputFunctionCombineNode constructed with ``combine_fn``.  Returns
     the function compiled from inputs ["in1", "in2"] to output ["fcn"].
     """
     shape = (3, 4, 5)
     first = tn.SequentialNode(
         "s1",
         [tn.InputNode("in1", shape=shape),
          tn.SendToNode("stn1", reference="fcn", to_key="b")])
     second = tn.SequentialNode(
         "s2",
         [tn.InputNode("in2", shape=shape),
          tn.SendToNode("stn2", reference="fcn", to_key="a")])
     combiner = tn.InputFunctionCombineNode("fcn", combine_fn=combine_fn)
     container = tn.ContainerNode("c", [first, second, combiner])
     return container.network().function(["in1", "in2"], ["fcn"])
Exemple #4
0
def test_node_with_generated_children_can_serialize():
    """Check that a built node tree still round-trips through serialization.

    The root ContainerNode is built via ``network().build()``, converted with
    ``treeano.core.node_to_data`` and restored with ``node_from_data``; the
    restored node must compare equal to the original.
    """
    seq1 = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    seq2 = tn.SequentialNode("s2", [tn.SendToNode("stn2", reference="stn3")])
    seq3 = tn.SequentialNode("s3", [tn.SendToNode("stn3", reference="i")])
    root_node = tn.ContainerNode(
        "c", [seq1, seq2, seq3, tn.IdentityNode("i")])
    root_node.network().build()
    serialized = treeano.core.node_to_data(root_node)
    root2 = treeano.core.node_from_data(serialized)
    nt.assert_equal(root_node, root2)
Exemple #5
0
def test_network_doesnt_mutate():
    """Check that building a network leaves the root node's attributes intact.

    A deep copy of the root ContainerNode's attribute dict is taken before
    ``network().build()`` and compared with the live dict afterwards.
    """
    seq1 = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    seq2 = tn.SequentialNode("s2", [tn.SendToNode("stn2", reference="stn3")])
    seq3 = tn.SequentialNode("s3", [tn.SendToNode("stn3", reference="i")])
    root_node = tn.ContainerNode(
        "c", [seq1, seq2, seq3, tn.IdentityNode("i")])
    # vars(obj) is obj.__dict__; snapshot it before the build
    snapshot = copy.deepcopy(vars(root_node))
    root_node.network().build()
    nt.assert_equal(snapshot, vars(root_node))
Exemple #6
0
def test_send_to_node():
    """Check that data forwarded through a chain of SendToNodes is unchanged.

    Input "in" is forwarded by SendToNode references s1 -> s2 -> stn3 -> i,
    so the output of IdentityNode "i" must match the input (within
    ``assert_allclose``'s default tolerance).
    """
    seq1 = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    seq2 = tn.SequentialNode("s2", [tn.SendToNode("stn2", reference="stn3")])
    seq3 = tn.SequentialNode("s3", [tn.SendToNode("stn3", reference="i")])
    network = tn.ContainerNode(
        "c", [seq1, seq2, seq3, tn.IdentityNode("i")]).network()

    fn = network.function(["in"], ["i"])
    x = np.random.randn(3, 4, 5).astype(fX)
    output = fn(x)[0]
    np.testing.assert_allclose(output, x)