def test_input_flags(self):
    """Check the node-kind flags reported by Input.

    An Input without a node is falsy and reports no kind. An Input wrapping
    a Sum, a ContVars or a Weights node is truthy and reports is_op,
    is_var or is_param respectively, with the other two flags unset.
    """
    empty = spn.Input()
    self.assertFalse(empty)
    self.assertFalse(empty.is_op)
    self.assertFalse(empty.is_var)
    self.assertFalse(empty.is_param)

    yes, no = self.assertTrue, self.assertFalse
    # (node factory, check for is_op, check for is_var, check for is_param)
    cases = (
        (spn.Sum, yes, no, no),
        (spn.ContVars, no, yes, no),
        (spn.Weights, no, no, yes),
    )
    for make_node, check_op, check_var, check_param in cases:
        # Each node is built right before it is wrapped, as in a linear test.
        wrapped = spn.Input(make_node())
        self.assertTrue(wrapped)
        check_op(wrapped.is_op)
        check_var(wrapped.is_var)
        check_param(wrapped.is_param)
def test_generte_set_errors(self):
    """__generate_set must raise StructureError for this input list.

    The Sum node is built over both ContVars nodes, while ContVars1 is also
    passed in directly (outputs 1 and 2) and through a Concat. The generator
    is expected to reject this; presumably the cause is the overlap between
    the Sum's scope and the other inputs' scopes.
    """
    generator = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3,
                                      num_mixtures=2)
    ivs = spn.IVs(num_vars=2, num_vals=4)
    cont1 = spn.ContVars(num_vars=3, name="ContVars1")
    cont2 = spn.ContVars(num_vars=2, name="ContVars2")
    mixed_sum = spn.Sum(cont2, cont1)
    concat = spn.Concat(cont1)
    with self.assertRaises(spn.StructureError):
        inputs = [
            spn.Input(ivs, [0, 3, 2, 6, 7]),
            spn.Input(cont1, [1, 2]),
            spn.Input(mixed_sum, None),
            spn.Input(concat, None),
        ]
        generator._DenseSPNGenerator__generate_set(inputs)
def test_generte_set(self):
    """Generation of sets of inputs with __generate_set

    Outputs of the given inputs are grouped by scope. This test was
    previously defined twice, and the later copy silently replaced the
    earlier one; it now exists once.
    """
    gen = spn.DenseSPNGenerator(num_decomps=2, num_subsets=3, num_mixtures=2)
    v1 = spn.IVs(num_vars=2, num_vals=4)
    v2 = spn.ContVars(num_vars=3, name="ContVars1")
    v3 = spn.ContVars(num_vars=2, name="ContVars2")
    s1 = spn.Sum(v3)
    n1 = spn.Concat(v2)
    out = gen._DenseSPNGenerator__generate_set([
        spn.Input(v1, [0, 3, 2, 6, 7]),
        spn.Input(v2, [1, 2]),
        spn.Input(s1, None),
        spn.Input(n1, None)
    ])
    # Expected grouping of outputs by scope:
    #   v1 variable 0      -> v1 outputs 0, 2, 3
    #   v1 variable 1      -> v1 outputs 6, 7
    #   v2 variable 1      -> v2 output 1, n1 output 1
    #   v2 variable 2      -> v2 output 2, n1 output 2
    #   v2 variable 0      -> n1 output 0 (only reachable through the Concat)
    #   v3 variables 0, 1  -> s1 output 0
    expected_sets = [
        [(v2, 1), (n1, 1)],
        [(v2, 2), (n1, 2)],
        [(n1, 0)],
        [(v1, 0), (v1, 2), (v1, 3)],
        [(v1, 6), (v1, 7)],
        [(s1, 0)],
    ]
    # The test expects each set as a sorted tuple. The order of the sets
    # themselves is undetermined, so check membership rather than equality.
    self.assertEqual(len(out), len(expected_sets))
    for expected in expected_sets:
        self.assertIn(tuple(sorted(expected)), out)
def test_input_conversion(self):
    """Conversion and verification of input specs in Input"""
    v1 = spn.ContVars(num_vars=5)

    def check(inpt, node, indices):
        # Verify the wrapped node and the normalized indices (None, or a
        # list). An Input must be falsy exactly when it wraps no node.
        self.assertIs(inpt.node, node)
        if indices is None:
            self.assertIs(inpt.indices, None)
        else:
            self.assertListEqual(inpt.indices, indices)
        if node is None:
            self.assertFalse(inpt)
        else:
            self.assertTrue(inpt)

    # None (indices are dropped when there is no node)
    check(spn.Input(), None, None)
    check(spn.Input(None), None, None)
    check(spn.Input(None, [1, 2, 3]), None, None)
    check(spn.Input.as_input(None), None, None)
    check(spn.Input.as_input((None, [1, 2, 3])), None, None)
    # Node
    check(spn.Input(v1), v1, None)
    check(spn.Input.as_input(v1), v1, None)
    # (Node, None)
    check(spn.Input(v1, None), v1, None)
    check(spn.Input.as_input((v1, None)), v1, None)
    # (Node, index): a scalar index is normalized to a one-element list
    check(spn.Input(v1, 10), v1, [10])
    check(spn.Input.as_input((v1, 10)), v1, [10])
    # (Node, indices)
    check(spn.Input(v1, [10]), v1, [10])
    check(spn.Input.as_input((v1, [10])), v1, [10])
    check(spn.Input(v1, [10, 1, 20]), v1, [10, 1, 20])
    check(spn.Input.as_input((v1, [10, 1, 20])), v1, [10, 1, 20])

    # Checking type of input
    with self.assertRaises(TypeError):
        spn.Input(set())
    with self.assertRaises(TypeError):
        spn.Input.as_input(set())
    with self.assertRaises(TypeError):
        # A real 2-tuple whose node is not a Node; a bare `(set())` would
        # only be a parenthesized set and repeat the check above.
        spn.Input.as_input((set(), None))
    with self.assertRaises(TypeError):
        spn.Input.as_input(tuple())
    with self.assertRaises(TypeError):
        spn.Input(v1, set())
    with self.assertRaises(TypeError):
        spn.Input.as_input((v1, set()))
    with self.assertRaises(TypeError):
        spn.Input.as_input((v1,))

    # Detecting empty list, non-integer indices, negative indices and
    # duplicate indices, through both the constructor and as_input
    for bad_indices in ([], [1, set(), 2], [1, -1, 2], [0, 1, 3, 1]):
        with self.subTest(indices=bad_indices):
            with self.assertRaises(ValueError):
                spn.Input(v1, bad_indices)
            with self.assertRaises(ValueError):
                spn.Input.as_input((v1, bad_indices))