def shd(node_1: BinaryTreeNode, node_2: BinaryTreeNode,
        hd: Callable[[BinaryTreeNode, BinaryTreeNode], float]) -> float:
    """Structural Hamming distance (SHD) between two binary trees.

    The distance is defined recursively on a pair of nodes:

    * both nodes absent: 0 (two empty subtrees are identical);
    * exactly one node absent: 1;
    * different child slots occupied (different arity, or a single child
      hanging on opposite sides): 1;
    * both leaves: ``hd(node_1, node_2)``;
    * otherwise, with ``m`` occupied child slots::

          (hd(node_1, node_2) + sum of shd over the m child pairs) / (m + 1)

    If ``hd`` returns values in [0, 1], so does ``shd``.

    :param node_1: root of the first (sub)tree, or ``None``.
    :param node_2: root of the second (sub)tree, or ``None``.
    :param hd: distance between the two nodes themselves, ignoring their
        children; only called with non-``None`` nodes.
    :return: the structural Hamming distance between the two (sub)trees.
    """
    if node_1 is None and node_2 is None:
        # Two empty subtrees are identical; counting them as 1 would inflate
        # the children sum of every pair of single-child nodes.
        return 0
    if node_1 is None or node_2 is None:
        return 1

    # Which child slots (left, right) are occupied on each side.
    slots_1 = (node_1.has_left_child(), node_1.has_right_child())
    slots_2 = (node_2.has_left_child(), node_2.has_right_child())
    if slots_1 != slots_2:
        # Structurally different: different arity, or the only child sits on
        # opposite sides. Summing over mismatched slots could exceed 1.
        return 1

    ham_dist = hd(node_1, node_2)
    # Pair up only the occupied slots so the sum has exactly m terms.
    pairs = [
        (child_1, child_2)
        for child_1, child_2, occupied in zip(
            (node_1.left, node_1.right),
            (node_2.left, node_2.right),
            slots_1,
        )
        if occupied
    ]
    if not pairs:
        # Both nodes are leaves.
        return ham_dist

    children_dist_sum = sum(shd(child_1, child_2, hd) for child_1, child_2 in pairs)
    return (ham_dist + children_dist_sum) / (len(pairs) + 1)
class TestBinaryTreeNode(TestCase):
    """Unit tests for ``BinaryTreeNode``.

    ``setUp`` creates a fresh childless root for every test; tests that need
    a deeper tree build their own.
    """

    def setUp(self):
        self.root_val = 'Parent Value'
        self.root = BinaryTreeNode(self.root_val)
        self.left_child_val = 42
        self.right_child_val = 13

    def test_has_left_child(self):
        """Only a left child makes ``has_left_child`` true."""
        self.assertFalse(self.root.has_left_child())
        # A right child alone must not count as a left child.
        self.root.add_right(self.right_child_val)
        self.assertFalse(self.root.has_left_child())
        self.root.add_left(self.left_child_val)
        self.assertTrue(self.root.has_left_child())

    def test_has_right_child(self):
        """Only a right child makes ``has_right_child`` true."""
        self.assertFalse(self.root.has_right_child())
        # A left child alone must not count as a right child.
        self.root.add_left(self.left_child_val)
        self.assertFalse(self.root.has_right_child())
        self.root.add_right(self.right_child_val)
        self.assertTrue(self.root.has_right_child())

    def test_has_parent(self):
        """Children have a parent; the root never does."""
        self.assertFalse(self.root.has_parent())
        right = self.root.add_right(self.right_child_val)
        self.assertTrue(right.has_parent())
        self.assertTrue(self.root.right.has_parent())
        self.assertFalse(self.root.has_parent())
        left = self.root.add_left(self.left_child_val)
        self.assertTrue(left.has_parent())
        self.assertTrue(self.root.left.has_parent())
        self.assertFalse(self.root.has_parent())

    def test_is_left_child(self):
        """``is_left_child`` is true only for the left child."""
        # Asking the parentless root is expected to raise AttributeError.
        self.assertRaises(AttributeError, self.root.is_left_child)
        left = self.root.add_left(self.left_child_val)
        self.assertTrue(left.is_left_child())
        self.assertTrue(self.root.left.is_left_child())
        right = self.root.add_right(self.right_child_val)
        self.assertFalse(right.is_left_child())
        self.assertFalse(self.root.right.is_left_child())

    def test_is_right_child(self):
        """``is_right_child`` is true only for the right child."""
        # Asking the parentless root is expected to raise AttributeError.
        self.assertRaises(AttributeError, self.root.is_right_child)
        left = self.root.add_left(self.left_child_val)
        self.assertFalse(left.is_right_child())
        self.assertFalse(self.root.left.is_right_child())
        right = self.root.add_right(self.right_child_val)
        self.assertTrue(right.is_right_child())
        self.assertTrue(self.root.right.is_right_child())

    def test_is_root(self):
        """Only the top node is a root."""
        self.assertTrue(self.root.is_root())
        left = self.root.add_left(self.left_child_val)
        self.assertFalse(left.is_root())
        right = self.root.add_right(self.right_child_val)
        self.assertFalse(right.is_root())

    def test_is_leaf(self):
        """A node stops being a leaf once it gets any child."""
        self.assertTrue(self.root.is_leaf())
        left = self.root.add_left(self.left_child_val)
        self.assertTrue(left.is_leaf())
        self.assertFalse(self.root.is_leaf())
        right = self.root.add_right(self.right_child_val)
        self.assertTrue(right.is_leaf())
        self.assertFalse(self.root.is_leaf())

    def test_add_left(self):
        """``add_left`` returns the new node, linked both ways to its parent."""
        result = self.root.add_left(self.left_child_val)
        self.assertEqual(result.parent, self.root)
        self.assertEqual(result.parent.value, self.root_val)
        self.assertEqual(result.parent.left, result)
        self.assertEqual(result.parent.left.value, self.left_child_val)

    def test_add_right(self):
        """``add_right`` returns the new node, linked both ways to its parent."""
        result = self.root.add_right(self.right_child_val)
        self.assertEqual(result.parent, self.root)
        self.assertEqual(result.parent.value, self.root_val)
        self.assertEqual(result.parent.right, result)
        self.assertEqual(result.parent.right.value, self.right_child_val)

    def test_create_graph(self):
        """``create_graph`` returns a ``Digraph``."""
        result = self.root.create_graph()
        self.assertIsInstance(result, Digraph)

    def test_height(self):
        """Height counts nodes on the longest root-to-leaf path."""
        root = BinaryTreeNode('*')
        self.assertEqual(root.height(), 1)
        left = root.add_left(10)
        self.assertEqual(root.height(), 2)
        right = root.add_right(20)
        self.assertEqual(root.height(), 2)
        ll = left.add_left(40)
        self.assertEqual(root.height(), 3)
        left.add_right(50)
        self.assertEqual(root.height(), 3)
        right.add_left(60)
        self.assertEqual(root.height(), 3)
        right.add_right(70)
        self.assertEqual(root.height(), 3)
        ll.add_left(80)
        self.assertEqual(root.height(), 4)

    def test_contains(self):
        """``in`` finds values stored in a node or below it."""
        root = BinaryTreeNode('*')
        self.assertIn('*', root)
        self.assertNotIn(99, root)
        left = root.add_left(10)
        self.assertIn('*', root)
        self.assertIn(10, root)
        self.assertIn(10, left)
        self.assertIn(10, root.left)
        right = root.add_right(20)
        self.assertIn('*', root)
        self.assertIn(10, root)
        # The original repeated ``assertIn(20, right)``; the root check was missing.
        self.assertIn(20, root)
        self.assertIn(20, right)
        self.assertIn(20, root.right)
        self.assertNotIn(99, root)

    def test_iter(self):
        """Iterating a tree yields every stored value exactly once."""
        root = BinaryTreeNode('*')
        left = root.add_left(10)
        right = root.add_right(20)
        ll = left.add_left(40)
        left.add_right(50)
        right.add_left(60)
        right.add_right(70)
        ll.add_left(80)
        result = []
        for value in root:
            self.assertIn(value, root)
            result.append(value)
        self.assertEqual(len(result), 8)
        # Order-insensitive: pins *which* values are yielded, not just how many.
        self.assertCountEqual(result, ['*', 10, 20, 40, 50, 60, 70, 80])

    def test_len(self):
        """``len`` counts the nodes in the subtree rooted at a node."""
        root = BinaryTreeNode('*')
        self.assertEqual(len(root), 1)
        left = root.add_left(10)
        self.assertEqual(len(root), 2)
        self.assertEqual(len(left), 1)
        right = root.add_right(20)
        self.assertEqual(len(root), 3)
        self.assertEqual(len(left), 1)
        self.assertEqual(len(right), 1)
        ll = left.add_left(40)
        self.assertEqual(len(root), 4)
        self.assertEqual(len(left), 2)
        self.assertEqual(len(right), 1)
        self.assertEqual(len(ll), 1)
        lr = left.add_right(50)
        self.assertEqual(len(root), 5)
        self.assertEqual(len(left), 3)
        self.assertEqual(len(right), 1)
        self.assertEqual(len(ll), 1)
        self.assertEqual(len(lr), 1)
        rl = right.add_left(60)
        self.assertEqual(len(root), 6)
        self.assertEqual(len(left), 3)
        self.assertEqual(len(right), 2)
        self.assertEqual(len(ll), 1)
        self.assertEqual(len(lr), 1)
        self.assertEqual(len(rl), 1)
        rr = right.add_right(70)
        self.assertEqual(len(root), 7)
        self.assertEqual(len(left), 3)
        self.assertEqual(len(right), 3)
        self.assertEqual(len(ll), 1)
        self.assertEqual(len(lr), 1)
        self.assertEqual(len(rl), 1)
        self.assertEqual(len(rr), 1)