예제 #1
0
 def test_o_multiple_outputs(self):
     """Check that ``o()`` can select among several consumers of one tensor.

     A producer node writes a single intermediate tensor that two other
     nodes take as input. The test expects ``o()`` with no argument to
     return the consumer constructed first and ``o(1)`` to return the
     consumer constructed second.
     """
     shared = Variable(name="intermediate")
     second_result = Variable(name="intermediate2")

     producer = Node(
         op="Add",
         name="Input",
         inputs=[self.input_tensor],
         outputs=[shared],
     )
     # The consumers are built in this order on purpose: the assertions
     # below expect index 0 -> first_consumer and index 1 -> second_consumer.
     first_consumer = Node(
         op="Add",
         name="Out",
         inputs=[shared],
         outputs=[self.output_tensor],
     )
     second_consumer = Node(
         op="Add",
         name="Input2",
         inputs=[shared],
         outputs=[second_result],
     )

     assert producer.o() == first_consumer
     assert producer.o(1) == second_consumer
예제 #2
0
 def test_o(self):
     """Check that ``o()`` returns the single node consuming a node's output.

     A producer node writes an intermediate tensor that exactly one other
     node takes as input; ``o()`` called on the producer is expected to
     compare equal to that consumer.
     """
     link = Variable(name="intermediate")
     producer = Node(op="Add", name="Input", inputs=[self.input_tensor], outputs=[link])
     consumer = Node(op="Add", name="Out", inputs=[link], outputs=[self.output_tensor])

     downstream = producer.o()
     assert downstream == consumer