Example #1
0
def test_input_elementwise_sum_node():
    """InputElementwiseSumNode outputs the elementwise sum of its inputs.

    Three InputNodes each forward into the shared ``"ies"`` node under
    distinct keys (``in1``..``in3``); the compiled function's output must
    match ``i1 + i2 + i3``. Checked for a scalar shape ``()`` and ``(3, 4, 5)``.
    """
    # (sequential name, input name, send-to name, key under "ies")
    wiring = (
        ("seq1", "i1", "st1", "in1"),
        ("seq2", "i2", "st2", "in2"),
        ("seq3", "i3", "st3", "in3"),
    )
    for shape in [(), (3, 4, 5)]:
        children = [tn.InputElementwiseSumNode("ies")]
        for seq_name, input_name, send_name, key in wiring:
            children.append(tn.SequentialNode(
                seq_name,
                [tn.InputNode(input_name, shape=shape),
                 tn.SendToNode(send_name, reference="ies", to_key=key)]))
        network = tn.ContainerNode("all", children).network()
        fn = network.function(["i1", "i2", "i3"], ["ies"])
        # Drawn in i1, i2, i3 order; np.random.rand(*()) yields a scalar,
        # which np.array turns into a 0-d array.
        inputs = [np.array(np.random.rand(*shape), dtype=fX)
                  for _ in range(3)]
        expected = inputs[0] + inputs[1] + inputs[2]
        np.testing.assert_allclose(expected, fn(*inputs)[0], rtol=1e-5)
Example #2
0
 def fcn_network(combine_fn):
     """Compile a two-input network that feeds an InputFunctionCombineNode.

     ``in1`` is sent to ``"fcn"`` under key ``"b"`` and ``in2`` under key
     ``"a"``. The swap is presumably deliberate, to check how the combine
     node orders its inputs (confirm against the calling test).

     Args:
         combine_fn: forwarded unchanged as ``combine_fn`` to the
             ``InputFunctionCombineNode``.

     Returns:
         The result of ``network.function(["in1", "in2"], ["fcn"])``.
     """
     def branch(seq_name, input_name, send_name, key):
         # InputNode of shape (3, 4, 5) forwarded to "fcn" under ``key``.
         return tn.SequentialNode(
             seq_name,
             [tn.InputNode(input_name, shape=(3, 4, 5)),
              tn.SendToNode(send_name, reference="fcn", to_key=key)])

     children = [
         branch("s1", "in1", "stn1", "b"),
         branch("s2", "in2", "stn2", "a"),
         tn.InputFunctionCombineNode("fcn", combine_fn=combine_fn),
     ]
     network = tn.ContainerNode("c", children).network()
     return network.function(["in1", "in2"], ["fcn"])
Example #3
0
def test_node_with_generated_children_can_serialize():
    """A built node tree survives a node_to_data / node_from_data round trip.

    The tree routes ``in`` through SendToNodes: stn1 -> s2, stn2 -> stn3,
    stn3 -> i. The network is built before serializing; presumably this
    exercises any children generated at build time (per the test name),
    which should be confirmed.
    """
    first = tn.SequentialNode("s1", [
        tn.InputNode("in", shape=(3, 4, 5)),
        tn.SendToNode("stn1", reference="s2"),
    ])
    second = tn.SequentialNode("s2", [tn.SendToNode("stn2", reference="stn3")])
    third = tn.SequentialNode("s3", [tn.SendToNode("stn3", reference="i")])
    root_node = tn.ContainerNode(
        "c", [first, second, third, tn.IdentityNode("i")])
    root_node.network().build()
    serialized = treeano.core.node_to_data(root_node)
    restored = treeano.core.node_from_data(serialized)
    nt.assert_equal(root_node, restored)
Example #4
0
def test_network_doesnt_mutate():
    """Building a network must leave the root node's attribute dict unchanged."""
    children = [
        tn.SequentialNode("s1", [
            tn.InputNode("in", shape=(3, 4, 5)),
            tn.SendToNode("stn1", reference="s2"),
        ]),
        tn.SequentialNode("s2", [tn.SendToNode("stn2", reference="stn3")]),
        tn.SequentialNode("s3", [tn.SendToNode("stn3", reference="i")]),
        tn.IdentityNode("i"),
    ]
    root_node = tn.ContainerNode("c", children)
    # Deep copy, so in-place changes to nested values would show up as a
    # difference instead of being mirrored in the snapshot.
    snapshot = copy.deepcopy(root_node.__dict__)
    root_node.network().build()
    nt.assert_equal(snapshot, root_node.__dict__)
Example #5
0
def test_send_to_node():
    """A value routed through a chain of SendToNodes arrives unchanged.

    ``in`` is forwarded stn1 -> s2, stn2 -> stn3, stn3 -> i. The output read
    at ``i`` must equal the input.
    """
    root = tn.ContainerNode("c", [
        tn.SequentialNode("s1", [
            tn.InputNode("in", shape=(3, 4, 5)),
            tn.SendToNode("stn1", reference="s2"),
        ]),
        tn.SequentialNode("s2", [tn.SendToNode("stn2", reference="stn3")]),
        tn.SequentialNode("s3", [tn.SendToNode("stn3", reference="i")]),
        tn.IdentityNode("i"),
    ])
    network = root.network()
    fn = network.function(["in"], ["i"])
    sample = np.random.randn(3, 4, 5).astype(fX)
    result = fn(sample)[0]
    np.testing.assert_allclose(result, sample)
 def architecture_children(self):
     """Return a one-element list holding an AuxiliaryNode penalty pipeline.

     The auxiliary branch is a SequentialNode that runs these steps in order:
     ElementwiseKLSparsityPenaltyNode, AggregatorNode, MultiplyConstantNode,
     then a SendToNode with ``to_key=self.name``. Every child name is
     prefixed with ``self.name``.
     """
     prefix = self.name
     # NOTE(review): the SendToNode gets no ``reference`` here; presumably
     # the destination is supplied via hyperparameters elsewhere. Confirm.
     penalty_chain = tn.SequentialNode(
         prefix + "_sequential",
         [
             ElementwiseKLSparsityPenaltyNode(prefix + "_sparsitypenalty"),
             tn.AggregatorNode(prefix + "_aggregator"),
             tn.MultiplyConstantNode(prefix + "_multiplyweight"),
             tn.SendToNode(prefix + "_sendto", to_key=prefix),
         ],
     )
     return [tn.AuxiliaryNode(prefix + "_auxiliary", penalty_chain)]
Example #7
0
 def architecture_children(self):
     """Wrap the raw child between an input tap and an auxiliary penalty branch.

     Returns a one-element list holding a SequentialNode that runs:
     an IdentityNode ``<name>_input``, the value of ``self.raw_children()``,
     then an AuxiliaryNode. The AuxiliaryNode's inner sequence is
     ElementwiseContractionPenaltyNode (given the identity node's name as
     ``input_reference``), AggregatorNode, MultiplyConstantNode, and a
     SendToNode with ``to_key=self.name``.
     """
     prefix = self.name
     wrapped = self.raw_children()
     # Sits before the wrapped child in the sequence, so it carries the same
     # input; the penalty node refers to it by name.
     input_tap = tn.IdentityNode(prefix + "_input")
     penalty_steps = [
         ElementwiseContractionPenaltyNode(
             prefix + "_contractionpenalty",
             input_reference=input_tap.name),
         tn.AggregatorNode(prefix + "_aggregator"),
         tn.MultiplyConstantNode(prefix + "_multiplyweight"),
         tn.SendToNode(prefix + "_sendto", to_key=prefix),
     ]
     auxiliary = tn.AuxiliaryNode(
         prefix + "_auxiliary",
         tn.SequentialNode(prefix + "_innerseq", penalty_steps))
     outer = tn.SequentialNode(
         prefix + "_sequential", [input_tap, wrapped, auxiliary])
     return [outer]
Example #8
0
def test_send_to_node_serialization():
    """SendToNode passes tn.check_serialization with and without a reference."""
    # Each variant is constructed and checked before the next one is built,
    # matching the original statement order.
    for extra_kwargs in ({}, {"reference": "bar"}):
        tn.check_serialization(tn.SendToNode("a", **extra_kwargs))