def test_compute_graph_up_noconst(self):
    """Evaluate a small SPN bottom-up without any constant-node shortcut.

    Leaves evaluate to 1, weight nodes to 0.1, and every other node to
    ``weights + sum(children) + 1``. Several nodes are reachable through
    more than one path; the call-count check expects 15 evaluations in
    total: the 9 leaf/sum nodes built here plus (presumably) one weight
    node per sum created by ``spn.generate_weights``.
    """
    # Tally of val_fun invocations. A one-element list lets the nested
    # function mutate it without a ``nonlocal`` declaration.
    calls = [0]

    def val_fun(node, *inputs):
        calls[0] += 1
        if isinstance(node, spn.graph.node.VarNode):
            return 1
        if isinstance(node, spn.graph.node.ParamNode):
            return 0.1
        # Remaining nodes: inputs unpack as (weights, IVs, *children);
        # the IV input is not used in this computation.
        weights, _ivs, *children = inputs
        return weights + sum(children) + 1

    # Build the graph; some children are deliberately repeated
    v1 = spn.RawLeaf(num_vars=1)
    v2 = spn.RawLeaf(num_vars=1)
    v3 = spn.RawLeaf(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)            # v1 appears twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)            # v3 appears twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)    # s4 and s5 appear twice
    spn.generate_weights(s6)

    result = spn.compute_graph_up(s6, val_fun)

    # s1 = 4.1, s2 = 3.1, s3 = 4.1, s4 = 6.2, s5 = 9.3, and
    # s6 = 0.1 + (6.2 + 3.1 + 9.3 + 6.2 + 9.3) + 1 = 35.2
    self.assertAlmostEqual(result, 35.2)
    self.assertEqual(calls[0], 15)
def test_traversing_on_dense(self):
    """Check that three traversal routines visit the same number of nodes.

    A dense SPN is generated. ``compute_graph_up_down``,
    ``compute_graph_up`` and ``traverse_graph`` (with parameters included)
    are each run with a counting callback, and the resulting counts must
    agree.
    """
    # Shared visit tally, reset to zero before each traversal
    visits = [0]

    def count_up(node, *args):
        visits[0] += 1

    def count_down(node, *args):
        visits[0] += 1
        # For op nodes, return one placeholder value per input;
        # presumably what compute_graph_up_down propagates downward.
        if node.is_op:
            return [None] * len(node.inputs)

    def run_and_count(traversal, *args, **kwargs):
        # Run one traversal from a zeroed tally and report how many calls it made
        visits[0] = 0
        traversal(*args, **kwargs)
        return visits[0]

    # Build a dense SPN over two indicator leaves
    ind1 = spn.IndicatorLeaf(num_vars=3, num_vals=2, name="IndicatorLeaf1")
    ind2 = spn.IndicatorLeaf(num_vars=3, num_vals=2, name="IndicatorLeaf2")
    generator = spn.DenseSPNGenerator(
        num_decomps=2,
        num_subsets=3,
        num_mixtures=2,
        input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE,
        num_input_mixtures=None,
    )
    root = generator.generate(ind1, ind2)
    spn.generate_weights(root)

    up_down_count = run_and_count(
        spn.compute_graph_up_down, root, down_fun=count_down, graph_input=1)
    up_count = run_and_count(
        spn.compute_graph_up, root, val_fun=count_up)
    traverse_count = run_and_count(
        spn.traverse_graph, root, fun=count_up, skip_params=False)

    self.assertEqual(up_down_count, traverse_count)
    self.assertEqual(up_count, traverse_count)
def test_compute_graph_up_const(self):
    """Evaluate bottom-up while ``const_fun`` marks s5 as a constant node.

    ``val_fun`` returns a fixed 16 for s5. s3's only parent is s5, so the
    test asserts s3 is never evaluated. 8 calls are expected: v1, v2, v3,
    s1, s2, s4, s5 and s6. No weights are generated in this test, and
    the weight/IV inputs are ignored by ``val_fun``.
    """
    # Tally of val_fun invocations. A one-element list lets the nested
    # function mutate it without a ``nonlocal`` declaration.
    calls = [0]

    # Build the graph; some children are deliberately repeated
    v1 = spn.RawLeaf(num_vars=1)
    v2 = spn.RawLeaf(num_vars=1)
    v3 = spn.RawLeaf(num_vars=1)
    s1 = spn.Sum(v1, v1, v2)            # v1 appears twice
    s2 = spn.Sum(v1, v3)
    s3 = spn.Sum(v2, v3, v3)            # v3 appears twice
    s4 = spn.Sum(s1, v1)
    s5 = spn.Sum(s2, v3, s3)
    s6 = spn.Sum(s4, s2, s5, s4, s5)    # s4 and s5 appear twice

    def val_fun(node, *inputs):
        calls[0] += 1
        # s3 is only reachable through s5, which is constant, so it must
        # never be evaluated.
        self.assertIsNot(node, s3)
        # Compare nodes by identity, consistent with assertIsNot above
        if node is s5:
            return 16
        if isinstance(node, spn.graph.node.VarNode):
            return 1
        # Inputs unpack as (weights, IVs, *children); only children are used
        _weights, _ivs, *children = inputs
        return sum(children) + 1

    def const_fun(node):
        # Only s5 is treated as constant
        return node is s5

    result = spn.compute_graph_up(s6, val_fun, const_fun)

    # s1 = 4, s2 = 3, s4 = 6, s5 = 16 (fixed), and
    # s6 = 6 + 3 + 16 + 6 + 16 + 1 = 48
    self.assertEqual(result, 48)
    self.assertEqual(calls[0], 8)