def test_merge_init(self):
    """
    Test Merge Node initialization.

    Adds two NormalTriLNode inner nodes to a fresh builder and registers a
    MergeNormals node over them. The test passes if no exception is raised
    while the merge node is added; the graph is not built here.
    """
    builder = StaticBuilder(scope='Main')

    # Two identical inner nodes, created in sequence.
    inner_nodes = [
        builder.addInner([[3]], num_inputs=1, node_class=NormalTriLNode)
        for _ in range(2)
    ]

    builder.addMergeNode(node_list=inner_nodes, merge_class=MergeNormals)
def test_merge_build1(self):
    """
    Test Merge Node build.

    Feeds one input node into two NormalTriLNode inner nodes, merges them
    with MergeNormals, builds the graph, and checks that the built merge
    node registers the 'loc', 'main' and 'cov' output slots.
    """
    builder = StaticBuilder(scope='Main')
    source = builder.addInput([[1]])
    inner_nodes = [
        builder.addInner([[3]], node_class=NormalTriLNode)
        for _ in range(2)
    ]

    # Both inner nodes read from the same input, on their input slot 0.
    for inner in inner_nodes:
        builder.addDirectedLink(source, inner, islot=0)

    merge_handle = builder.addMergeNode(inner_nodes, merge_class=MergeNormals)
    builder.build()

    # The builder hands back a handle; the node object lives in builder.nodes.
    merge_node = builder.nodes[merge_handle]
    for oslot in ('loc', 'main', 'cov'):
        self.assertIn(oslot, merge_node._oslot_to_otensor)
def test_merge_build2(self):
    """
    Test Merge Node build and evaluation.

    Merges two NormalInputNode inputs with MergeNormals, builds the graph,
    and evaluates the merge node's 'cov' output slot in a TensorFlow
    session. The result and its shape are printed; no value is asserted,
    so this acts as a smoke test that build and evaluation run through.
    """
    builder = StaticBuilder(scope='Main')
    i1 = builder.addInput([[3]], iclass=NormalInputNode, name='N1')
    i2 = builder.addInput([[3]], iclass=NormalInputNode, name='N2')
    m1 = builder.addMergeNode(node_list=[i1, i2], merge_class=MergeNormals)
    builder.build()

    # Use the session as a context manager so it is closed (and its
    # resources released) even if initialization or evaluation raises.
    with tf.Session(graph=tf.get_default_graph()) as sess:
        sess.run(tf.global_variables_initializer())
        s3 = builder.eval_node_oslot(sess, m1, oslot='cov')
        print("merge output", s3)
        print("merge output shape", s3.shape)