class TreeNode(object):
    """Node of an AVL tree: a stored value, two child links and a cached height."""

    def __init__(self, val):
        self.val = val
        # No children yet: a freshly created node is a leaf.
        self.left, self.right = None, None
        # Cached height of the subtree rooted here; a lone leaf has height 1.
        self.height = 1
class AVLTree(object):
    """Operations on an AVL tree (a height-balanced binary search tree).

    The tree itself is just a root ``TreeNode``; this class holds no state.
    Every mutating method takes the root of a subtree and returns the
    (possibly different) root after rebalancing, so callers must reassign,
    e.g. ``root = tree.insert(root, key)``. A key equal to an existing key is
    inserted into that node's right subtree, so duplicates are kept.
    """

    def insert(self, root, key):
        """Insert *key* below *root* and return the rebalanced subtree root.

        Args:
            root: Root node of the subtree, or ``None`` for an empty subtree.
            key: Value to insert; it is compared with ``<`` against node values.

        Returns:
            The root of the subtree after insertion and any rotations.
        """
        if not root:
            return TreeNode(key)
        if key < root.val:
            root.left = self.insert(root.left, key)
        else:
            root.right = self.insert(root.right, key)
        # The rotation case is chosen from the children's balance factors, not
        # by comparing *key* with the child's value: that comparison matches no
        # case when key == child.val (a duplicate), leaving the node unbalanced.
        return self._rebalance(root)

    def delete(self, root, key):
        """Remove one node whose value equals *key*; return the new subtree root.

        Args:
            root: Root node of the subtree, or ``None``.
            key: Value to remove. If it is absent, no node is removed.

        Returns:
            The root of the subtree after removal and any rotations
            (``None`` if the subtree becomes empty).
        """
        if not root:
            return root
        if key < root.val:
            root.left = self.delete(root.left, key)
        elif key > root.val:
            root.right = self.delete(root.right, key)
        else:
            # Zero or one child: splice the node out by returning the child.
            if root.left is None:
                return root.right
            if root.right is None:
                return root.left
            # Two children: copy the in-order successor's value into this node,
            # then remove that value from the right subtree.
            successor = self.get_min_value_node(root.right)
            root.val = successor.val
            root.right = self.delete(root.right, successor.val)
        return self._rebalance(root)

    def _rebalance(self, root):
        """Refresh *root*'s height, rotate if it is out of balance, return the new root."""
        root.height = 1 + max(self.get_height(root.left), self.get_height(root.right))
        balance = self.get_balance(root)
        if balance > 1:
            # Left-heavy. If the left child leans right (Left-Right case),
            # rotate it left first so a single right rotation fixes this node.
            if self.get_balance(root.left) < 0:
                root.left = self.left_rotate(root.left)
            return self.right_rotate(root)
        if balance < -1:
            # Right-heavy. Mirror image: handle the Right-Left case first.
            if self.get_balance(root.right) > 0:
                root.right = self.right_rotate(root.right)
            return self.left_rotate(root)
        return root

    def left_rotate(self, z):
        """Rotate the subtree rooted at *z* left and return its new root.

        *z* must have a right child ``y``. ``y`` becomes the root, *z* becomes
        ``y``'s left child, and ``y``'s former left subtree becomes *z*'s right
        subtree. *z*'s height is recomputed before ``y``'s because *z* now sits
        below ``y``.
        """
        y = z.right
        t2 = y.left
        y.left = z
        z.right = t2
        z.height = 1 + max(self.get_height(z.left), self.get_height(z.right))
        y.height = 1 + max(self.get_height(y.left), self.get_height(y.right))
        return y

    def right_rotate(self, z):
        """Rotate the subtree rooted at *z* right and return its new root.

        *z* must have a left child ``y``. ``y`` becomes the root, *z* becomes
        ``y``'s right child, and ``y``'s former right subtree becomes *z*'s
        left subtree. *z*'s height is recomputed before ``y``'s because *z* now
        sits below ``y``.
        """
        y = z.left
        t3 = y.right
        y.right = z
        z.left = t3
        z.height = 1 + max(self.get_height(z.left), self.get_height(z.right))
        y.height = 1 + max(self.get_height(y.left), self.get_height(y.right))
        return y

    def get_height(self, root):
        """Return the cached height of *root*; an empty subtree has height 0."""
        if not root:
            return 0
        return root.height

    def get_balance(self, root):
        """Return left height minus right height of *root* (0 for ``None``)."""
        if not root:
            return 0
        return self.get_height(root.left) - self.get_height(root.right)

    def get_min_value_node(self, root):
        """Return the leftmost (smallest-valued) node under *root*, or ``None``."""
        node = root
        while node is not None and node.left is not None:
            node = node.left
        return node

    def preOrder(self, root):
        """Print the subtree in pre-order (node, left, right).

        Each value is printed followed by a space, one value per line.
        """
        if not root:
            return
        print("{0} ".format(root.val))
        self.preOrder(root.left)
        self.preOrder(root.right)
def search(root, key):
    """Return the first node on the search path whose value equals *key*.

    Walks down from *root*, going right while *key* is larger and left while
    it is smaller, so it relies on binary-search-tree ordering. Returns
    ``None`` when the walk falls off the tree (including an empty tree).
    """
    node = root
    while node is not None:
        if key > node.val:
            node = node.right
        elif key < node.val:
            node = node.left
        else:
            return node
    return None