def remove_leaf_nodes(root: TreeNode, target: int) -> TreeNode:
    """Delete leaves whose value equals ``target``, working bottom-up.

    Both children are processed before the node itself. A node left
    childless by that pruning is therefore checked again here, so removals
    can cascade upward.

    Args:
        root: Root of the subtree to prune (may be ``None``).
        target: Value that marks a leaf for deletion.

    Returns:
        ``None`` if this node is removed, otherwise ``root``.
    """
    if not root:
        return root
    root.left = remove_leaf_nodes(root.left, target)
    root.right = remove_leaf_nodes(root.right, target)
    # is_leaf is defined elsewhere; presumably it tests for no children.
    drop_self = root.val == target and is_leaf(root)
    return None if drop_self else root
def prune_tree(root: TreeNode) -> TreeNode:
    """Detach child subtrees rejected by ``has_one``, then recurse into the rest.

    ``has_one`` is defined elsewhere. Judging by the name, it presumably
    reports whether a subtree contains a node with value 1; confirm against
    its definition. The keep/drop decision for both children is made before
    recursing into them.

    Args:
        root: Root of the subtree (may be ``None``).

    Returns:
        ``root`` after pruning, or ``None`` when ``root`` is falsy.
    """
    if not root:
        return None
    for side in ("left", "right"):
        child = getattr(root, side)
        setattr(root, side, child if has_one(child) else None)
    for side in ("left", "right"):
        prune_tree(getattr(root, side))
    return root
def insert_into_bst(root: TreeNode, val: int) -> TreeNode:
    """Insert ``val`` into the BST rooted at ``root`` and return the root.

    The walk goes right when the current value is smaller than ``val`` and
    left when it is larger. The new node is attached at the first empty
    slot. If a node with a value equal to ``val`` is reached, nothing is
    inserted.

    Args:
        root: BST root, or ``None`` for an empty tree.
        val: Value to insert.

    Returns:
        The root of the tree. This is a new single node when ``root`` was empty.
    """
    if not root:
        return TreeNode(val)
    cursor = root
    while True:
        if cursor.val < val:
            if not cursor.right:
                cursor.right = TreeNode(val)
                break
            cursor = cursor.right
        elif cursor.val > val:
            if not cursor.left:
                cursor.left = TreeNode(val)
                break
            cursor = cursor.left
        else:
            # Neither smaller nor larger: no insertion.
            break
    return root
def prune_tree(root: TreeNode) -> TreeNode:
    """Remove every subtree that consists solely of 0-valued nodes.

    Children are pruned first (post-order). A node whose children are then
    both gone and whose own value is 0 is removed as well.

    Args:
        root: Root of the subtree (may be ``None``).

    Returns:
        The pruned subtree root, or ``None`` if the whole subtree was removed.
    """
    if not root:
        return None
    root.left = prune_tree(root.left)
    root.right = prune_tree(root.right)
    keep = root.left or root.right or root.val != 0
    return root if keep else None
def helper(inorder, postorder, in_st, in_end, post_end):
    """Rebuild the subtree spanning ``inorder[in_st:in_end + 1]``.

    ``postorder[post_end]`` is the root of that subtree. Its position in
    ``inorder`` splits the span into left and right parts. The right part
    has ``in_end - split`` nodes and immediately precedes the root in
    postorder. The left part's root therefore sits that many slots further
    back.

    Returns:
        The subtree root, or ``None`` when the span is empty or
        ``post_end`` is negative.
    """
    if in_st > in_end or post_end < 0:
        return None
    root_val = postorder[post_end]
    node = TreeNode(root_val)
    # list.index scans the whole list; assumes values are unique.
    split = inorder.index(root_val)
    right_size = in_end - split
    node.left = helper(inorder, postorder, in_st, split - 1,
                       post_end - 1 - right_size)
    node.right = helper(inorder, postorder, split + 1, in_end, post_end - 1)
    return node
def construct(pre: List[int], post: List[int]) -> TreeNode:
    """Rebuild a binary tree from its preorder and postorder traversals.

    ``pre[0]`` is the root. ``pre[1]`` is the root of the left subtree, and
    its position in ``post`` marks where the left subtree ends there. That
    fixes the size of the left subtree in both sequences. The rest
    (excluding the root, which is ``post[-1]``) is the right subtree.

    The recursion calls ``construct`` itself. This is a free function with
    no ``self``, so it cannot recurse via ``self.constructFromPrePost``.

    Args:
        pre: Preorder traversal (values assumed unique).
        post: Postorder traversal of the same tree.

    Returns:
        The root of the rebuilt tree, or ``None`` for empty input.
    """
    if not pre:
        return None
    root = TreeNode(pre[0])
    if len(pre) == 1:
        return root
    left_size = post.index(pre[1]) + 1
    root.left = construct(pre[1:left_size + 1], post[:left_size])
    root.right = construct(pre[left_size + 1:], post[left_size:-1])
    return root
def helper(
    preorder: List[int],
    inorder: List[int],
    pre_st: int,
    in_st: int,
    in_end: int,
) -> "TreeNode":
    """Rebuild the subtree spanning ``inorder[in_st:in_end + 1]``.

    ``preorder[pre_st]`` is the root of that subtree. Its position ``ind``
    in ``inorder`` splits the span. The left part holds ``ind - in_st``
    nodes, which occupy the preorder slots right after the root. The right
    part's root follows them.

    Args:
        preorder: Preorder traversal (values assumed unique).
        inorder: Inorder traversal of the same tree.
        pre_st: Index in ``preorder`` of this subtree's root.
        in_st: First inorder index of the span (inclusive).
        in_end: Last inorder index of the span (inclusive).

    Returns:
        The subtree root, or ``None`` for an empty span or an exhausted
        ``preorder``.
    """
    # ``>=``: index len(preorder) is already past the end, so it must stop
    # here rather than reach preorder[pre_st].
    if pre_st >= len(preorder) or in_st > in_end:
        return None
    root = TreeNode(preorder[pre_st])
    ind = inorder.index(preorder[pre_st])
    root.left = helper(preorder, inorder, pre_st + 1, in_st, ind - 1)
    # Skip the root plus the (ind - in_st) left-subtree nodes in preorder.
    root.right = helper(
        preorder, inorder, pre_st + 1 + (ind - in_st), ind + 1, in_end
    )
    return root
def helper(i0, i1, N):
    """Rebuild the N-node subtree found at ``pre[i0:i0 + N]`` / ``post[i1:i1 + N]``.

    ``pre`` and ``post`` are not parameters. They are read from an
    enclosing scope not visible here (presumably the preorder and postorder
    traversals; confirm at the definition site).

    Returns:
        The subtree root, or ``None`` when ``N`` is 0.
    """
    if N == 0:
        return None
    subtree_root = TreeNode(pre[i0])
    if N == 1:
        return subtree_root
    # pre[i0 + 1] roots the left subtree and is that subtree's last entry in
    # postorder. The first offset k with post[i1 + k - 1] == pre[i0 + 1]
    # is the left subtree's size. If no match is found, the loop leaves
    # left_size at N - 1.
    # NOTE(review): k == 0 probes post[i1 - 1], i.e. post[-1] when i1 == 0.
    for left_size in range(N):
        if post[i1 + left_size - 1] == pre[i0 + 1]:
            break
    subtree_root.left = helper(i0 + 1, i1, left_size)
    right_size = N - (left_size + 1)
    subtree_root.right = helper(i0 + left_size + 1, i1 + left_size, right_size)
    return subtree_root
def allPossibleFBT(self, N: int) -> List[TreeNode]:
    """Return every tree with N nodes in which each node has 0 or 2 children.

    Results are memoised in ``self.memo``, keyed by node count. One node is
    the root; each split of the remaining ``N - 1`` nodes into a left and a
    right count is combined with every memoised left/right subtree pair.
    All nodes hold value 0, and subtrees are shared between the returned
    trees.

    NOTE(review): with an empty memo this yields [] for every N (0 and 1
    both produce no trees). Presumably ``self.memo`` is seeded with a
    one-node entry elsewhere; confirm in the class initialiser.
    """
    if N in self.memo:
        return self.memo[N]
    trees = []
    for left_count in range(N):
        right_count = N - 1 - left_count
        for left_sub in self.allPossibleFBT(left_count):
            for right_sub in self.allPossibleFBT(right_count):
                joined = TreeNode(0)
                joined.left = left_sub
                joined.right = right_sub
                trees.append(joined)
    self.memo[N] = trees
    return self.memo[N]
def bst_from_preorder(preorder: List[int]) -> TreeNode:
    """Rebuild a BST from its preorder traversal.

    ``get_index`` is defined elsewhere. Presumably it returns the position
    of the first value greater than the root, or a falsy value when there
    is none; confirm at its definition. A truthy split divides the
    remainder into left (``preorder[1:split]``) and right
    (``preorder[split:]``). Otherwise everything after the root goes left.

    Returns:
        The root node, or ``None`` for an empty sequence.
    """
    if not preorder:
        return None
    head = preorder[0]
    node = TreeNode(head)
    if len(preorder) == 1:
        return node
    split = get_index(preorder, head)
    if not split:
        node.left = bst_from_preorder(preorder[1:])
    else:
        node.left = bst_from_preorder(preorder[1:split])
        node.right = bst_from_preorder(preorder[split:])
    return node
def tree_15():
    """Build a 5-node test tree and return a dict mapping value -> node."""
    #         1
    #        / \
    #       2   3
    #      / \
    #     4   5
    nodes = {value: TreeNode(value) for value in range(1, 6)}
    links = (
        (1, "left", 2),
        (1, "right", 3),
        (2, "left", 4),
        (2, "right", 5),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a 5-node test tree and return a dict mapping value -> node."""
    #         1
    #        / \
    #       2   3
    #      /   /
    #     5   4
    nodes = {value: TreeNode(value) for value in range(1, 6)}
    links = (
        (1, "left", 2),
        (1, "right", 3),
        (2, "left", 5),
        (3, "left", 4),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree_input():
    """Build a 5-node BST fixture and return a dict mapping value -> node."""
    #         4
    #        / \
    #       2   7
    #      / \
    #     1   3
    nodes = {value: TreeNode(value) for value in [1, 2, 3, 4, 7]}
    links = (
        (4, "left", 2),
        (4, "right", 7),
        (2, "left", 1),
        (2, "right", 3),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def recover_tree(root: "TreeNode") -> None:
    """
    Do not return anything, modify root in-place instead.

    Repairs a BST in which two node values were swapped. An in-order walk
    of a valid BST is increasing. The first misplaced node is the
    predecessor at the first descending pair, and the second is the node
    at the last descending pair. Their values are exchanged.

    The walk uses an explicit stack, so list-shaped trees deeper than the
    interpreter's recursion limit do not raise ``RecursionError``. When no
    descending pair exists (empty or already-valid tree), the tree is left
    untouched instead of failing on a missing node.

    Args:
        root: Root of the tree; may be ``None``.
    """
    first = second = prev = None
    stack = []
    node = root
    while stack or node is not None:
        # Descend as far left as possible, remembering the path.
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        if prev is not None and node.val < prev.val:
            if first is None:
                first = prev
            second = node
        prev = node
        node = node.right
    if first is None:
        return  # no inversion found: nothing to repair
    first.val, second.val = second.val, first.val
def tree():
    """Build a 5-node test tree and return a dict mapping value -> node."""
    #       3
    #      / \
    #     9   20
    #        /  \
    #       15   7
    nodes = {value: TreeNode(value) for value in [3, 9, 20, 15, 7]}
    links = (
        (3, "left", 9),
        (3, "right", 20),
        (20, "left", 15),
        (20, "right", 7),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree_ans():
    """Build the expected-answer tree; keys are labels, since values repeat.

    Node "02" (value 0) is created but not linked to any other node.
    """
    #     1   ("11")
    #    /
    #   0     ("01")
    #    \
    #     1   ("12")
    labelled = [("11", 1), ("12", 1), ("01", 0), ("02", 0)]
    nodes = dict((label, TreeNode(value)) for label, value in labelled)
    links = (
        ("11", "left", "01"),
        ("01", "right", "12"),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a 6-node BST fixture (values 2..7) keyed by value."""
    #         5
    #        / \
    #       3   6
    #      / \   \
    #     2   4   7
    nodes = {value: TreeNode(value) for value in range(2, 8)}
    links = (
        (5, "left", 3),
        (5, "right", 6),
        (3, "left", 2),
        (3, "right", 4),
        (6, "right", 7),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a 5-node test tree and return a dict mapping value -> node."""
    #         4
    #        / \
    #       9   0
    #      / \
    #     5   1
    nodes = {value: TreeNode(value) for value in [4, 9, 0, 5, 1]}
    links = (
        (4, "left", 9),
        (4, "right", 0),
        (9, "left", 5),
        (9, "right", 1),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def helper(lo, hi):
    """Return every distinct BST whose values are exactly ``lo..hi``.

    An empty range yields ``[None]`` so the caller's nested loops still
    produce one combination. Subtree objects are shared between the
    returned trees rather than copied.
    """
    if lo > hi:
        return [None]
    trees = []
    for value in range(lo, hi + 1):
        lefts = helper(lo, value - 1)
        rights = helper(value + 1, hi)
        for right_sub in rights:
            for left_sub in lefts:
                node = TreeNode(value)
                node.right = right_sub
                node.left = left_sub
                trees.append(node)
    return trees
def tree_ans():
    """Build the expected-answer BST and return a dict mapping value -> node."""
    #         4
    #        / \
    #       2   7
    #      / \  /
    #     1  3 5
    nodes = {value: TreeNode(value) for value in [1, 2, 3, 4, 5, 7]}
    links = (
        (4, "left", 2),
        (4, "right", 7),
        (2, "left", 1),
        (2, "right", 3),
        (7, "left", 5),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a 6-node BST fixture and return a dict mapping value -> node."""
    #         8
    #        / \
    #       5   10
    #      / \    \
    #     1   7    12
    nodes = {value: TreeNode(value) for value in [8, 5, 1, 7, 10, 12]}
    links = (
        (8, "left", 5),
        (8, "right", 10),
        (5, "left", 1),
        (5, "right", 7),
        (10, "right", 12),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a complete 7-node BST fixture keyed by value."""
    #         4
    #        / \
    #       2   7
    #      / \ / \
    #     1  3 6  9
    nodes = {value: TreeNode(value) for value in [1, 2, 3, 4, 6, 7, 9]}
    links = (
        (4, "left", 2),
        (4, "right", 7),
        (2, "left", 1),
        (2, "right", 3),
        (7, "left", 6),
        (7, "right", 9),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree_17():
    """Build a 7-node test tree (values 1..7) keyed by value."""
    #           1
    #          / \
    #         2   3
    #        /     \
    #       4       5
    #      /         \
    #     6           7
    nodes = {value: TreeNode(value) for value in range(1, 8)}
    links = (
        (1, "left", 2),
        (1, "right", 3),
        (2, "left", 4),
        (3, "right", 5),
        (4, "left", 6),
        (5, "right", 7),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree_input():
    """Build a test tree with duplicate values; keys are ints or labels.

    Keys "21" and "22" both hold value 2, alongside key 2.
    """
    #         1
    #        / \
    #       2   3
    #      /   / \
    #     2   2   4
    labelled = [(1, 1), (2, 2), (3, 3), (4, 4), ("21", 2), ("22", 2)]
    nodes = dict((key, TreeNode(value)) for key, value in labelled)
    links = (
        (1, "left", 2),
        (1, "right", 3),
        (2, "left", "21"),
        (3, "left", "22"),
        (3, "right", 4),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a test tree with repeated 4s and 5s; keys are labels."""
    #         1
    #        / \
    #       4   5
    #      / \  /
    #     4  4 5
    labelled = [(1, 1), ("41", 4), ("42", 4), ("43", 4), ("51", 5), ("52", 5)]
    nodes = dict((key, TreeNode(value)) for key, value in labelled)
    links = (
        (1, "left", "41"),
        (1, "right", "51"),
        ("41", "left", "42"),
        ("41", "right", "43"),
        ("51", "left", "52"),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a 9-node test tree (values 0..8) keyed by value."""
    #           3
    #         /   \
    #        5     1
    #       / \   / \
    #      6   2 0   8
    #         / \
    #        7   4
    nodes = {value: TreeNode(value) for value in range(9)}
    links = (
        (3, "left", 5),
        (3, "right", 1),
        (5, "left", 6),
        (5, "right", 2),
        (2, "left", 7),
        (2, "right", 4),
        (1, "left", 0),
        (1, "right", 8),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree_input():
    """Build a 9-node BST fixture and return a dict mapping value -> node."""
    #           8
    #         /   \
    #        3     10
    #       / \      \
    #      1   6      14
    #         / \     /
    #        4   7   13
    nodes = {value: TreeNode(value) for value in [8, 3, 10, 1, 6, 14, 4, 7, 13]}
    links = (
        (8, "left", 3),
        (8, "right", 10),
        (3, "left", 1),
        (3, "right", 6),
        (6, "left", 4),
        (6, "right", 7),
        (10, "right", 14),
        (14, "left", 13),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree():
    """Build a 12-node test tree (values 0..11) keyed by value.

    Root is 0. The parent/side/child links below define the full shape.
    """
    nodes = {value: TreeNode(value) for value in range(12)}
    links = (
        (0, "left", 2),
        (0, "right", 1),
        (2, "left", 3),
        (3, "left", 4),
        (3, "right", 5),
        (4, "right", 7),
        (7, "left", 10),
        (7, "right", 8),
        (5, "left", 6),
        (6, "left", 11),
        (6, "right", 9),
    )
    for parent, side, child in links:
        setattr(nodes[parent], side, nodes[child])
    return nodes
def tree_ans():
    """Build a right-leaning 3-node chain 1 -> 3 -> 4 keyed by value."""
    #   1
    #    \
    #     3
    #      \
    #       4
    nodes = {value: TreeNode(value) for value in [1, 3, 4]}
    for parent, child in ((1, 3), (3, 4)):
        setattr(nodes[parent], "right", nodes[child])
    return nodes
def test_find_modes():
    """Check FindModes.find_modes on two small trees, each with a fresh instance."""
    # 1 -> right 2 -> left 2: value 2 appears twice, 1 once.
    first_root = TreeNode(1)
    first_twin_a = TreeNode(2)
    first_twin_b = TreeNode(2)
    first_root.right = first_twin_a
    first_twin_a.left = first_twin_b
    assert FindModes().find_modes(first_root) == [2]

    # 2 -> left 1: each value appears once.
    second_root = TreeNode(2)
    second_leaf = TreeNode(1)
    second_root.left = second_leaf
    assert FindModes().find_modes(second_root) == [1, 2]