用python实现基本数据结构和算法
1章:ADT抽象数据类型,定义数据和其操作
什么是ADT: 抽象数据类型,学过数据结构的应该都知道。
How to select datastructures for ADT
- Dose the data structure provide for the storage requirements as specified by the domain of the ADT?
- Does the data structure provide the data access and manipulation functionality to fully implement the ADT?
- Effcient implemention? based on complexity analysis.
下边代码是个简单的示例,比如实现一个简单的Bag类,先定义其具有的操作,然后我们再用类的magic method来实现这些方法:
class Bag:
"""
constructor: 构造函数
size
contains
append
remove
iter
"""
def __init__(self):
self._items = list()
def __len__(self):
return len(self._items)
def __contains__(self, item):
return item in self._items
def add(self, item):
self._items.append(item)
def remove(self, item):
assert item in self._items, 'item must in the bag'
return self._items.remove(item)
def __iter__(self):
return _BagIterator(self._items)
class _BagIterator:
""" 注意这里实现了迭代器类 """
def __init__(self, seq):
self._bag_items = seq
self._cur_item = 0
def __iter__(self):
return self
def __next__(self):
if self._cur_item < len(self._bag_items):
item = self._bag_items[self._cur_item]
self._cur_item += 1
return item
else:
raise StopIteration
b = Bag()
b.add(1)
b.add(2)
for i in b: # for使用__iter__构建,用__next__迭代
print(i)
"""
# for 语句等价于
i = b.__iter__()
while True:
try:
item = i.__next__()
print(item)
except StopIteration:
break
"""
2章:array vs list
array: 定长,操作有限,但是节省内存;貌似我的生涯中还没用过,不过python3.5中我试了确实有array类,可以用import array直接导入
list: 会预先分配内存,操作丰富,但是耗费内存。我用sys.getsizeof做了实验。我个人理解很类似C++ STL里的vector,是使用最频繁的数据结构。
- list.append: 如果之前没有分配够内存,会重新开辟新区域,然后复制之前的数据,复杂度退化
- list.insert: 会移动被插入区域后所有元素,O(n)
- list.pop: pop不同位置需要的复杂度不同pop(0)是O(1)复杂度,pop()首位O(n)复杂度
- list[]: slice操作copy数据(预留空间)到另一个list
来实现一个array的ADT:
import ctypes
class Array:
def __init__(self, size):
assert size > 0, 'array size must be > 0'
self._size = size
PyArrayType = ctypes.py_object * size
self._elements = PyArrayType()
self.clear(None)
def __len__(self):
return self._size
def __getitem__(self, index):
assert index >= 0 and index < len(self), 'out of range'
return self._elements[index]
def __setitem__(self, index, value):
assert index >= 0 and index < len(self), 'out of range'
self._elements[index] = value
def clear(self, value):
""" 设置每个元素为value """
for i in range(len(self)):
self._elements[i] = value
def __iter__(self):
return _ArrayIterator(self._elements)
class _ArrayIterator:
def __init__(self, items):
self._items = items
self._idx = 0
def __iter__(self):
return self
def __next__(self):
if self._idex < len(self._items):
val = self._items[self._idx]
self._idex += 1
return val
else:
raise StopIteration
Two-Demensional Arrays
class Array2D:
""" 要实现的方法
Array2D(nrows, ncols): constructor
numRows()
numCols()
clear(value)
getitem(i, j)
setitem(i, j, val)
"""
def __init__(self, numrows, numcols):
self._the_rows = Array(numrows) # 数组的数组
for i in range(numrows):
self._the_rows[i] = Array(numcols)
@property
def numRows(self):
return len(self._the_rows)
@property
def NumCols(self):
return len(self._the_rows[0])
def clear(self, value):
for row in self._the_rows:
row.clear(value)
def __getitem__(self, ndx_tuple): # ndx_tuple: (x, y)
assert len(ndx_tuple) == 2
row, col = ndx_tuple[0], ndx_tuple[1]
assert (row >= 0 and row < self.numRows and
col >= 0 and col < self.NumCols)
the_1d_array = self._the_rows[row]
return the_1d_array[col]
def __setitem__(self, ndx_tuple, value):
assert len(ndx_tuple) == 2
row, col = ndx_tuple[0], ndx_tuple[1]
assert (row >= 0 and row < self.numRows and
col >= 0 and col < self.NumCols)
the_1d_array = self._the_rows[row]
the_1d_array[col] = value
The Matrix ADT, m行,n列。这个最好用还是用pandas处理矩阵,自己实现比较*疼
class Matrix:
""" 最好用pandas的DataFrame
Matrix(rows, ncols): constructor
numCols()
getitem(row, col)
setitem(row, col, val)
scaleBy(scalar): 每个元素乘scalar
transpose(): 返回transpose转置
add(rhsMatrix): size must be the same
subtract(rhsMatrix)
multiply(rhsMatrix)
"""
def __init__(self, numRows, numCols):
self._theGrid = Array2D(numRows, numCols)
self._theGrid.clear(0)
@property
def numRows(self):
return self._theGrid.numRows
@property
def NumCols(self):
return self._theGrid.numCols
def __getitem__(self, ndxTuple):
return self._theGrid[ndxTuple[0], ndxTuple[1]]
def __setitem__(self, ndxTuple, scalar):
self._theGrid[ndxTuple[0], ndxTuple[1]] = scalar
def scaleBy(self, scalar):
for r in range(self.numRows):
for c in range(self.numCols):
self[r, c] *= scalar
def __add__(self, rhsMatrix):
assert (rhsMatrix.numRows == self.numRows and
rhsMatrix.numCols == self.numCols)
newMartrix = Matrix(self.numRows, self.numCols)
for r in range(self.numRows):
for c in range(self.numCols):
newMartrix[r, c] = self[r, c] + rhsMatrix[r, c]
3章:Sets and Maps
除了list之外,最常用的应该就是python内置的set和dict了。
sets ADT
A set is a container that stores a collection of unique values over a given comparable domain in which the stored values have no particular ordering.
class Set:
""" 使用list实现set ADT
Set()
length()
contains(element)
add(element)
remove(element)
equals(element)
isSubsetOf(setB)
union(setB)
intersect(setB)
difference(setB)
iterator()
"""
def __init__(self):
self._theElements = list()
def __len__(self):
return len(self._theElements)
def __contains__(self, element):
return element in self._theElements
def add(self, element):
if element not in self:
self._theElements.append(element)
def remove(self, element):
assert element in self, 'The element must be set'
self._theElements.remove(element)
def __eq__(self, setB):
if len(self) != len(setB):
return False
else:
return self.isSubsetOf(setB)
def isSubsetOf(self, setB):
for element in self:
if element not in setB:
return False
return True
def union(self, setB):
newSet = Set()
newSet._theElements.extend(self._theElements)
for element in setB:
if element not in self:
newSet._theElements.append(element)
return newSet
Maps or Dict: 键值对,python内部采用hash实现。
class Map:
""" Map ADT list implemention
Map()
length()
contains(key)
add(key, value)
remove(key)
valudOf(key)
iterator()
"""
def __init__(self):
self._entryList = list()
def __len__(self):
return len(self._entryList)
def __contains__(self, key):
ndx = self._findPosition(key)
return ndx is not None
def add(self, key, value):
ndx = self._findPosition(key)
if ndx is not None:
self._entryList[ndx].value = value
return False
else:
entry = _MapEntry(key, value)
self._entryList.append(entry)
return True
def valueOf(self, key):
ndx = self._findPosition(key)
assert ndx is not None, 'Invalid map key'
return self._entryList[ndx].value
def remove(self, key):
ndx = self._findPosition(key)
assert ndx is not None, 'Invalid map key'
self._entryList.pop(ndx)
def __iter__(self):
return _MapIterator(self._entryList)
def _findPosition(self, key):
for i in range(len(self)):
if self._entryList[i].key == key:
return i
return None
class _MapEntry: # or use collections.namedtuple('_MapEntry', 'key,value')
def __init__(self, key, value):
self.key = key
self.value = value
The multiArray ADT, 多维数组,一般是使用一个一维数组模拟,然后通过计算下标获取元素
class MultiArray:
""" row-major or column-marjor ordering, this is row-major ordering
MultiArray(d1, d2, ...dn)
dims(): the number of dimensions
length(dim): the length of given array dimension
clear(value)
getitem(i1, i2, ... in), index(i1,i2,i3) = i1*(d2*d3) + i2*d3 + i3
setitem(i1, i2, ... in)
计算下标:index(i1,i2,...in) = i1*f1 + i2*f2 + ... + i(n-1)*f(n-1) + in*1
"""
def __init__(self, *dimensions):
# Implementation of MultiArray ADT using a 1-D # array,数组的数组的数组。。。
assert len(dimensions) > 1, 'The array must have 2 or more dimensions'
self._dims = dimensions
# Compute to total number of elements in the array
size = 1
for d in dimensions:
assert d > 0, 'Dimensions must be > 0'
size *= d
# Create the 1-D array to store the elements
self._elements = Array(size)
# Create a 1-D array to store the equation factors
self._factors = Array(len(dimensions))
self._computeFactors()
@property
def numDims(self):
return len(self._dims)
def length(self, dim):
assert dim > 0 and dim < len(self._dims), 'Dimension component out of range'
return self._dims[dim-1]
def clear(self, value):
self._elements.clear(value)
def __getitem__(self, ndxTuple):
assert len(ndxTuple) == self.numDims, 'Invalid # of array subscripts'
index = self._computeIndex(ndxTuple)
assert index is not None, 'Array subscript out of range'
return self._elements[index]
def __setitem__(self, ndxTuple, value):
assert len(ndxTuple) == self.numDims, 'Invalid # of array subscripts'
index = self._computeIndex(ndxTuple)
assert index is not None, 'Array subscript out of range'
self._elements[index] = value
def _computeIndex(self, ndxTuple):
# using the equation: i1*f1 + i2*f2 + ... + in*fn
offset = 0
for j in range(len(ndxTuple)):
if ndxTuple[j] < 0 or ndxTuple[j] >= self._dims[j]:
return None
else:
offset += ndexTuple[j] * self._factors[j]
return offset
4章:Algorithm Analysis
一般使用大O标记法来衡量算法的平均时间复杂度, 1 < log(n) < n < nlog(n) < n^2 < n^3 < a^n。 了解常用数据结构操作的平均时间复杂度有利于使用更高效的数据结构,当然有时候需要在时间和空间上进行衡量,有些操作甚至还会退化,比如list的append操作,如果list空间不够,会去开辟新的空间,操作复杂度退化到O(n),有时候还需要使用均摊分析(amortized)
5章:Searching and Sorting
排序和查找是最基础和频繁的操作,python内置了in操作符和bisect二分操作模块实现查找,内置了sorted方法来实现排序操作。二分和快排也是面试中经常考到的,本章讲的是基本的排序和查找。
def binary_search(sorted_seq, val):
""" 实现标准库中的bisect.bisect_left """
low = 0
high = len(sorted_seq) - 1
while low <= high:
mid = (high + low) // 2
if sorted_seq[mid] == val:
return mid
elif val < sorted_seq[mid]:
high = mid - 1
else:
low = mid + 1
return low
def bubble_sort(seq): # O(n^2), n(n-1)/2 = 1/2(n^2 + n)
n = len(seq)
for i in range(n-1):
for j in range(n-1-i): # 这里之所以 n-1 还需要 减去 i 是因为每一轮冒泡最大的元素都会冒泡到最后,无需再比较
if seq[j] > seq[j+1]:
seq[j], seq[j+1] = seq[j+1], seq[j]
def select_sort(seq):
"""可以看作是冒泡的改进,每次找一个最小的元素交换,每一轮只需要交换一次"""
n = len(seq)
for i in range(n-1):
min_idx = i # assume the ith element is the smallest
for j in range(i+1, n):
if seq[j] < seq[min_idx]: # find the minist element index
min_idx = j
if min_idx != i: # swap
seq[i], seq[min_idx] = seq[min_idx], seq[i]
def insertion_sort(seq):
""" 每次挑选下一个元素插入已经排序的数组中,初始时已排序数组只有一个元素"""
n = len(seq)
for i in range(1, n):
value = seq[i] # save the value to be positioned
# find the position where value fits in the ordered part of the list
pos = i
while pos > 0 and value < seq[pos-1]:
# Shift the items to the right during the search
seq[pos] = seq[pos-1]
pos -= 1
seq[pos] = value
def merge_sorted_list(listA, listB):
""" 归并两个有序数组 """
new_list = list()
a = b = 0
while a < len(listA) and b < len(listB):
if listA[a] < listB[b]:
new_list.append(listA[a])
a += 1
else:
new_list.append(listB[b])
b += 1
while a < len(listA):
new_list.append(listA[a])
a += 1
while b < len(listB):
new_list.append(listB[b])
b += 1
return new_list
6章: Linked Structure
list是最常用的数据结构,但是list在中间增减元素的时候效率会很低,这时候linked list会更适合,缺点就是获取元素的平均时间复杂度变成了O(n)
# 单链表实现
class ListNode:
def __init__(self, data):
self.data = data
self.next = None
def travsersal(head, callback):
curNode = head
while curNode is not None:
callback(curNode.data)
curNode = curNode.next
def unorderdSearch(head, target):
curNode = head
while curNode is not None and curNode.data != target:
curNode = curNode.next
return curNode is not None
# Given the head pointer, prepend an item to an unsorted linked list.
def prepend(head, item):
newNode = ListNode(item)
newNode.next = head
head = newNode
# Given the head reference, remove a target from a linked list
def remove(head, target):
predNode = None
curNode = head
while curNode is not None and curNode.data != target:
# 寻找目标
predNode = curNode
curNode = curNode.data
if curNode is not None:
if curNode is head:
head = curNode.next
else:
predNode.next = curNode.next
7章:Stacks
栈也是计算机里用得比较多的数据结构,栈是一种后进先出的数据结构,可以理解为往一个桶里放盘子,先放进去的会被压在地下,拿盘子的时候,后放的会被先拿出来。
class Stack:
""" Stack ADT, using a python list
Stack()
isEmpty()
length()
pop(): assert not empty
peek(): assert not empty, return top of non-empty stack without removing it
push(item)
"""
def __init__(self):
self._items = list()
def isEmpty(self):
return len(self) == 0
def __len__(self):
return len(self._items)
def peek(self):
assert not self.isEmpty()
return self._items[-1]
def pop(self):
assert not self.isEmpty()
return self._items.pop()
def push(self, item):
self._items.append(item)
class Stack:
""" Stack ADT, use linked list
使用list实现很简单,但是如果涉及大量push操作,list的空间不够时复杂度退化到O(n)
而linked list可以保证最坏情况下仍是O(1)
"""
def __init__(self):
self._top = None # top节点, _StackNode or None
self._size = 0 # int
def isEmpty(self):
return self._top is None
def __len__(self):
return self._size
def peek(self):
assert not self.isEmpty()
return self._top.item
def pop(self):
assert not self.isEmpty()
node = self._top
self.top = self._top.next
self._size -= 1
return node.item
def _push(self, item):
self._top = _StackNode(item, self._top)
self._size += 1
class _StackNode:
def __init__(self, item, link):
self.item = item
self.next = link
8章:Queues
队列也是经常使用的数据结构,比如发送消息等,celery可以使用redis提供的list实现消息队列。 本章我们用list和linked list来实现队列和优先级队列。
class Queue:
""" Queue ADT, use list。list实现,简单但是push和pop效率最差是O(n)
Queue()
isEmpty()
length()
enqueue(item)
dequeue()
"""
def __init__(self):
self._qList = list()
def isEmpty(self):
return len(self) == 0
def __len__(self):
return len(self._qList)
def enquue(self, item):
self._qList.append(item)
def dequeue(self):
assert not self.isEmpty()
return self._qList.pop(0)
from array import Array # Array那一章实现的Array ADT
class Queue:
"""
circular Array ,通过头尾指针实现。list内置append和pop复杂度会退化,使用
环数组实现可以使得入队出队操作时间复杂度为O(1),缺点是数组长度需要固定。
"""
def __init__(self, maxSize):
self._count = 0
self._front = 0
self._back = maxSize - 1
self._qArray = Array(maxSize)
def isEmpty(self):
return self._count == 0
def isFull(self):
return self._count == len(self._qArray)
def __len__(self):
return len(self._count)
def enqueue(self, item):
assert not self.isFull()
maxSize = len(self._qArray)
self._back = (self._back + 1) % maxSize # 移动尾指针
self._qArray[self._back] = item
self._count += 1
def dequeue(self):
assert not self.isFull()
item = self._qArray[self._front]
maxSize = len(self._qArray)
self._front = (self._front + 1) % maxSize
self._count -= 1
return item
class _QueueNode:
def __init__(self, item):
self.item = item
class Queue:
""" Queue ADT, linked list 实现。为了改进环型数组有最大数量的限制,改用
带有头尾节点的linked list实现。
"""
def __init__(self):
self._qhead = None
self._qtail = None
self._qsize = 0
def isEmpty(self):
return self._qhead is None
def __len__(self):
return self._count
def enqueue(self, item):
node = _QueueNode(item) # 创建新的节点并用尾节点指向他
if self.isEmpty():
self._qhead = node
else:
self._qtail.next = node
self._qtail = node
self._qcount += 1
def dequeue(self):
assert not self.isEmpty(), 'Can not dequeue from an empty queue'
node = self._qhead
if self._qhead is self._qtail:
self._qtail = None
self._qhead = self._qhead.next # 前移头节点
self._count -= 1
return node.item
class UnboundedPriorityQueue:
""" PriorityQueue ADT: 给每个item加上优先级p,高优先级先dequeue
分为两种:
- bounded PriorityQueue: 限制优先级在一个区间[0...p)
- unbounded PriorityQueue: 不限制优先级
PriorityQueue()
BPriorityQueue(numLevels): create a bounded PriorityQueue with priority in range
[0, numLevels-1]
isEmpty()
length()
enqueue(item, priority): 如果是bounded PriorityQueue, priority必须在区间内
dequeue(): 最高优先级的出队,同优先级的按照FIFO顺序
- 两种实现方式:
1.入队的时候都是到队尾,出队操作找到最高优先级的出队,出队操作O(n)
2.始终维持队列有序,每次入队都找到该插入的位置,出队操作是O(1)
(注意如果用list实现list.append和pop操作复杂度会因内存分配退化)
"""
from collections import namedtuple
_PriorityQEntry = namedtuple('_PriorityQEntry', 'item, priority')
# 采用方式1,用内置list实现unbounded PriorityQueue
def __init__(self):
self._qlist = list()
def isEmpty(self):
return len(self) == 0
def __len__(self):
return len(self._qlist)
def enqueue(self, item, priority):
entry = UnboundedPriorityQueue._PriorityQEntry(item, priority)
self._qlist.append(entry)
def deque(self):
assert not self.isEmpty(), 'can not deque from an empty queue'
highest = self._qlist[0].priority
for i in range(len(self)): # 出队操作O(n),遍历找到最高优先级
if self._qlist[i].priority < highest:
highest = self._qlist[i].priority
entry = self._qlist.pop(highest)
return entry.item
class BoundedPriorityQueue:
""" BoundedPriorityQueue ADT,用linked list实现。上一个地方提到了 BoundedPriorityQueue
但是为什么需要 BoundedPriorityQueue呢? BoundedPriorityQueue 的优先级限制在[0, maxPriority-1]
对于 UnboundedPriorityQueue,出队操作由于要遍历寻找优先级最高的item,所以平均
是O(n)的操作,但是对于 BoundedPriorityQueue,用队列数组实现可以达到常量时间,
用空间换时间。比如要弹出一个元素,直接找到第一个非空队列弹出 元素就可以了。
(小数字代表高优先级,先出队)
qlist
[0] -> ["white"]
[1]
[2] -> ["black", "green"]
[3] -> ["purple", "yellow"]
"""
# Implementation of the bounded Priority Queue ADT using an array of #
# queues in which the queues are implemented using a linked list.
from array import Array # 第二章定义的ADT
def __init__(self, numLevels):
self._qSize = 0
self._qLevels = Array(numLevels)
for i in range(numLevels):
self._qLevels[i] = Queue() # 上一节讲到用linked list实现的Queue
def isEmpty(self):
return len(self) == 0
def __len__(self):
return len(self._qSize)
def enqueue(self, item, priority):
assert priority >= 0 and priority < len(self._qLevels), 'invalid priority'
self._qLevel[priority].enquue(item) # 直接找到 priority 对应的槽入队
def deque(self):
assert not self.isEmpty(), 'can not deque from an empty queue'
i = 0
p = len(self._qLevels)
while i < p and not self._qLevels[i].isEmpty(): # 找到第一个非空队列
i += 1
return self._qLevels[i].dequeue()
9章:Advanced Linked Lists
之前曾经介绍过单链表,一个链表节点只有data和next字段,本章介绍高级的链表。
Doubly Linked List,双链表,每个节点多了个prev指向前一个节点。双链表可以用来编写文本编辑器的buffer。
class DListNode:
def __init__(self, data):
self.data = data
self.prev = None
self.next = None
def revTraversa(tail):
curNode = tail
while cruNode is not None:
print(curNode.data)
curNode = curNode.prev
def search_sorted_doubly_linked_list(head, tail, probe, target):
""" probing technique探查法,改进直接遍历,不过最坏时间复杂度仍是O(n)
searching a sorted doubly linked list using the probing technique
Args:
head (DListNode obj)
tail (DListNode obj)
probe (DListNode or None)
target (DListNode.data): data to search
"""
if head is None: # make sure list is not empty
return False
if probe is None: # if probe is null, initialize it to first node
probe = head
# if the target comes before the probe node, we traverse backward, otherwise
# traverse forward
if target < probe.data:
while probe is not None and target <= probe.data:
if target == probe.dta:
return True
else:
probe = probe.prev
else:
while probe is not None and target >= probe.data:
if target == probe.data:
return True
else:
probe = probe.next
return False
def insert_node_into_ordered_doubly_linekd_list(value):
""" 最好画个图看,链表操作很容易绕晕,注意赋值顺序"""
newnode = DListNode(value)
if head is None: # empty list
head = newnode
tail = head
elif value < head.data: # insert before head
newnode.next = head
head.prev = newnode
head = newnode
elif value > tail.data: # insert after tail
newnode.prev = tail
tail.next = newnode
tail = newnode
else: # insert into middle
node = head
while node is not None and node.data < value:
node = node.next
newnode.next = node
newnode.prev = node.prev
node.prev.next = newnode
node.prev = newnode
循环链表
def travrseCircularList(listRef):
curNode = listRef
done = listRef is None
while not None:
curNode = curNode.next
print(curNode.data)
done = curNode is listRef # 回到遍历起始点
def searchCircularList(listRef, target):
curNode = listRef
done = listRef is None
while not done:
curNode = curNode.next
if curNode.data == target:
return True
else:
done = curNode is listRef or curNode.data > target
return False
def add_newnode_into_ordered_circular_linked_list(listRef, value):
""" 插入并维持顺序
1.插入空链表;2.插入头部;3.插入尾部;4.按顺序插入中间
"""
newnode = ListNode(value)
if listRef is None: # empty list
listRef = newnode
newnode.next = newnode
elif value < listRef.next.data: # insert in front
newnode.next = listRef.next
listRef.next = newnode
elif value > listRef.data: # insert in back
newnode.next = listRef.next
listRef.next = newnode
listRef = newnode
else: # insert in the middle
preNode = None
curNode = listRef
done = listRef is None
while not done:
preNode = curNode
preNode = curNode.next
done = curNode is listRef or curNode.data > value
newnode.next = curNode
preNode.next = newnode
利用循环双端链表我们可以实现一个经典的缓存失效算法,lru:
# -*- coding: utf-8 -*-
class Node(object):
def __init__(self, prev=None, next=None, key=None, value=None):
self.prev, self.next, self.key, self.value = prev, next, key, value
class CircularDoubleLinkedList(object):
def __init__(self):
node = Node()
node.prev, node.next = node, node
self.rootnode = node
def headnode(self):
return self.rootnode.next
def tailnode(self):
return self.rootnode.prev
def remove(self, node):
if node is self.rootnode:
return
else:
node.prev.next = node.next
node.next.prev = node.prev
def append(self, node):
tailnode = self.tailnode()
tailnode.next = node
node.next = self.rootnode
self.rootnode.prev = node
class LRUCache(object):
def __init__(self, maxsize=16):
self.maxsize = maxsize
self.cache = {}
self.access = CircularDoubleLinkedList()
self.isfull = len(self.cache) >= self.maxsize
def __call__(self, func):
def wrapper(n):
cachenode = self.cache.get(n)
if cachenode is not None: # hit
self.access.remove(cachenode)
self.access.append(cachenode)
return cachenode.value
else: # miss
value = func(n)
if not self.isfull:
tailnode = self.access.tailnode()
newnode = Node(tailnode, self.access.rootnode, n, value)
self.access.append(newnode)
self.cache[n] = newnode
self.isfull = len(self.cache) >= self.maxsize
return value
else: # full
lru_node = self.access.headnode()
del self.cache[lru_node.key]
self.access.remove(lru_node)
tailnode = self.access.tailnode()
newnode = Node(tailnode, self.access.rootnode, n, value)
self.access.append(newnode)
self.cache[n] = newnode
return value
return wrapper
@LRUCache()
def fib(n):
if n <= 2:
return 1
else:
return fib(n - 1) + fib(n - 2)
for i in range(1, 35):
print(fib(i))
10章:Recursion
Recursion is a process for solving problems by subdividing a larger problem into smaller cases of the problem itself and then solving the smaller, more trivial parts.
递归函数:调用自己的函数
# 递归函数:调用自己的函数,看一个最简单的递归函数,倒序打印一个数
def printRev(n):
if n > 0:
print(n)
printRev(n-1)
printRev(3) # 从10输出到1
# 稍微改一下,print放在最后就得到了正序打印的函数
def printInOrder(n):
if n > 0:
printInOrder(n-1)
print(n) # 之所以最小的先打印是因为函数一直递归到n==1时候的最深栈,此时不再
# 递归,开始执行print语句,这时候n==1,之后每跳出一层栈,打印更大的值
printInOrder(3) # 正序输出
Properties of Recursion: 使用stack解决的问题都能用递归解决
- A recursive solution must contain a base case; 递归出口,代表最小子问题(n == 0退出打印)
- A recursive solution must contain a recursive case; 可以分解的子问题
- A recursive solution must make progress toward the base case. 递减n使得n像递归出口靠近
Tail Recursion: occurs when a function includes a single recursive call as the last statement of the function. In this case, a stack is not needed to store values to te used upon the return of the recursive call and thus a solution can be implemented using a iterative loop instead.
# Recursive Binary Search
def recBinarySearch(target, theSeq, first, last):
# 你可以写写单元测试来验证这个函数的正确性
if first > last: # 递归出口1
return False
else:
mid = (first + last) // 2
if theSeq[mid] == target:
return True # 递归出口2
elif theSeq[mid] > target:
return recBinarySearch(target, theSeq, first, mid - 1)
else:
return recBinarySearch(target, theSeq, mid + 1, last)
11章:Hash Tables
基于比较的搜索(线性搜索,有序数组的二分搜索)最好的时间复杂度只能达到O(logn),利用hash可以实现O(1)查找,python内置dict的实现方式就是hash,你会发现dict的key必须要是实现了 __hash__
和 __eq__
方法的。
Hashing: hashing is the process of mapping a search a key to a limited range of array indeices with the goal of providing direct access to the keys.
hash方法有个hash函数用来给key计算一个hash值,作为数组下标,放到该下标对应的槽中。当不同key根据hash函数计算得到的下标相同时,就出现了冲突。解决冲突有很多方式,比如让每个槽成为链表,每次冲突以后放到该槽链表的尾部,但是查询时间就会退化,不再是O(1)。还有一种探查方式,当key的槽冲突时候,就会根据一种计算方式去寻找下一个空的槽存放,探查方式有线性探查,二次方探查法等,cpython解释器使用的是二次方探查法。还有一个问题就是当python使用的槽数量大于预分配的2/3时候,会重新分配内存并拷贝以前的数据,所以有时候dict的add操作代价还是比较高的,牺牲空间但是可以始终保证O(1)的查询效率。如果有大量的数据,建议还是使用bloomfilter或者redis提供的HyperLogLog。
如果你感兴趣,可以看看这篇文章,介绍c解释器如何实现的python dict对象:Python dictionary implementation。我们使用Python来实现一个类似的hash结构。
import ctypes
class Array: # 第二章曾经定义过的ADT,这里当做HashMap的槽数组使用
def __init__(self, size):
assert size > 0, 'array size must be > 0'
self._size = size
PyArrayType = ctypes.py_object * size
self._elements = PyArrayType()
self.clear(None)
def __len__(self):
return self._size
def __getitem__(self, index):
assert index >= 0 and index < len(self), 'out of range'
return self._elements[index]
def __setitem__(self, index, value):
assert index >= 0 and index < len(self), 'out of range'
self._elements[index] = value
def clear(self, value):
""" 设置每个元素为value """
for i in range(len(self)):
self._elements[i] = value
def __iter__(self):
return _ArrayIterator(self._elements)
class _ArrayIterator:
def __init__(self, items):
self._items = items
self._idx = 0
def __iter__(self):
return self
def __next__(self):
if self._idx < len(self._items):
val = self._items[self._idx]
self._idx += 1
return val
else:
raise StopIteration
class HashMap:
""" HashMap ADT实现,类似于python内置的dict
一个槽有三种状态:
1.从未使用 HashMap.UNUSED。此槽没有被使用和冲突过,查找时只要找到UNUSEd就不用再继续探查了
2.使用过但是remove了,此时是 HashMap.EMPTY,该探查点后边的元素扔可能是有key
3.槽正在使用 _MapEntry节点
"""
class _MapEntry: # 槽里存储的数据
def __init__(self, key, value):
self.key = key
self.value = value
UNUSED = None # 没被使用过的槽,作为该类变量的一个单例,下边都是is 判断
EMPTY = _MapEntry(None, None) # 使用过但是被删除的槽
def __init__(self):
self._table = Array(7) # 初始化7个槽
self._count = 0
# 超过2/3空间被使用就重新分配,load factor = 2/3
self._maxCount = len(self._table) - len(self._table) // 3
def __len__(self):
return self._count
def __contains__(self, key):
slot = self._findSlot(key, False)
return slot is not None
def add(self, key, value):
if key in self: # 覆盖原有value
slot = self._findSlot(key, False)
self._table[slot].value = value
return False
else:
slot = self._findSlot(key, True)
self._table[slot] = HashMap._MapEntry(key, value)
self._count += 1
if self._count == self._maxCount: # 超过2/3使用就rehash
self._rehash()
return True
def valueOf(self, key):
slot = self._findSlot(key, False)
assert slot is not None, 'Invalid map key'
return self._table[slot].value
def remove(self, key):
""" remove操作把槽置为EMPTY"""
assert key in self, 'Key error %s' % key
slot = self._findSlot(key, forInsert=False)
value = self._table[slot].value
self._count -= 1
self._table[slot] = HashMap.EMPTY
return value
def __iter__(self):
return _HashMapIteraotr(self._table)
def _slot_can_insert(self, slot):
return (self._table[slot] is HashMap.EMPTY or
self._table[slot] is HashMap.UNUSED)
def _findSlot(self, key, forInsert=False):
""" 注意原书有错误,代码根本不能运行,这里我自己改写的
Args:
forInsert (bool): if the search is for an insertion
Returns:
slot or None
"""
slot = self._hash1(key)
step = self._hash2(key)
_len = len(self._table)
if not forInsert: # 查找是否存在key
while self._table[slot] is not HashMap.UNUSED:
# 如果一个槽是UNUSED,直接跳出
if self._table[slot] is HashMap.EMPTY:
slot = (slot + step) % _len
continue
elif self._table[slot].key == key:
return slot
slot = (slot + step) % _len
return None
else: # 为了插入key
while not self._slot_can_insert(slot): # 循环直到找到一个可以插入的槽
slot = (slot + step) % _len
return slot
def _rehash(self): # 当前使用槽数量大于2/3时候重新创建新的table
origTable = self._table
newSize = len(self._table) * 2 + 1 # 原来的2*n+1倍
self._table = Array(newSize)
self._count = 0
self._maxCount = newSize - newSize // 3
# 将原来的key value添加到新的table
for entry in origTable:
if entry is not HashMap.UNUSED and entry is not HashMap.EMPTY:
slot = self._findSlot(entry.key, True)
self._table[slot] = entry
self._count += 1
def _hash1(self, key):
""" 计算key的hash值"""
return abs(hash(key)) % len(self._table)
def _hash2(self, key):
""" key冲突时候用来计算新槽的位置"""
return 1 + abs(hash(key)) % (len(self._table)-2)
class _HashMapIteraotr:
def __init__(self, array):
self._array = array
self._idx = 0
def __iter__(self):
return self
def __next__(self):
if self._idx < len(self._array):
if self._array[self._idx] is not None and self._array[self._idx].key is not None:
key = self._array[self._idx].key
self._idx += 1
return key
else:
self._idx += 1
else:
raise StopIteration
def print_h(h):
for idx, i in enumerate(h):
print(idx, i)
print('\n')
def test_HashMap():
""" 一些简单的单元测试,不过测试用例覆盖不是很全面 """
h = HashMap()
assert len(h) == 0
h.add('a', 'a')
assert h.valueOf('a') == 'a'
assert len(h) == 1
a_v = h.remove('a')
assert a_v == 'a'
assert len(h) == 0
h.add('a', 'a')
h.add('b', 'b')
assert len(h) == 2
assert h.valueOf('b') == 'b'
b_v = h.remove('b')
assert b_v == 'b'
assert len(h) == 1
h.remove('a')
assert len(h) == 0
n = 10
for i in range(n):
h.add(str(i), i)
assert len(h) == n
print_h(h)
for i in range(n):
assert str(i) in h
for i in range(n):
h.remove(str(i))
assert len(h) == 0
12章: Advanced Sorting
第5章介绍了基本的排序算法,本章介绍高级排序算法。
归并排序(mergesort): 分治法
def merge_sorted_list(listA, listB):
""" 归并两个有序数组,O(max(m, n)) ,m和n是数组长度"""
print('merge left right list', listA, listB, end='')
new_list = list()
a = b = 0
while a < len(listA) and b < len(listB):
if listA[a] < listB[b]:
new_list.append(listA[a])
a += 1
else:
new_list.append(listB[b])
b += 1
while a < len(listA):
new_list.append(listA[a])
a += 1
while b < len(listB):
new_list.append(listB[b])
b += 1
print(' ->', new_list)
return new_list
def mergesort(theList):
""" O(nlogn),log层调用,每层n次操作
mergesort: divided and conquer 分治
1. 把原数组分解成越来越小的子数组
2. 合并子数组来创建一个有序数组
"""
print(theList) # 我把关键步骤打出来了,你可以运行下看看整个过程
if len(theList) <= 1: # 递归出口
return theList
else:
mid = len(theList) // 2
# 递归分解左右两边数组
left_half = mergesort(theList[:mid])
right_half = mergesort(theList[mid:])
# 合并两边的有序子数组
newList = merge_sorted_list(left_half, right_half)
return newList
""" 这是我调用一次打出来的排序过程
[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
[10, 9, 8, 7, 6]
[10, 9]
[10]
[9]
merge left right list [10] [9] -> [9, 10]
[8, 7, 6]
[8]
[7, 6]
[7]
[6]
merge left right list [7] [6] -> [6, 7]
merge left right list [8] [6, 7] -> [6, 7, 8]
merge left right list [9, 10] [6, 7, 8] -> [6, 7, 8, 9, 10]
[5, 4, 3, 2, 1]
[5, 4]
[5]
[4]
merge left right list [5] [4] -> [4, 5]
[3, 2, 1]
[3]
[2, 1]
[2]
[1]
merge left right list [2] [1] -> [1, 2]
merge left right list [3] [1, 2] -> [1, 2, 3]
merge left right list [4, 5] [1, 2, 3] -> [1, 2, 3, 4, 5]
"""
快速排序
def quicksort(theSeq, first, last): # average: O(nlog(n))
"""
quicksort :也是分而治之,但是和归并排序不同的是,采用选定主元(pivot)而不是从中间
进行数组划分
1. 第一步选定pivot用来划分数组,pivot左边元素都比它小,右边元素都大于等于它
2. 对划分的左右两边数组递归,直到递归出口(数组元素数目小于2)
3. 对pivot和左右划分的数组合并成一个有序数组
"""
if first < last:
pos = partitionSeq(theSeq, first, last)
# 对划分的子数组递归操作
quicksort(theSeq, first, pos - 1)
quicksort(theSeq, pos + 1, last)
def partitionSeq(theSeq, first, last):
""" 快排中的划分操作,把比pivot小的挪到左边,比pivot大的挪到右边"""
pivot = theSeq[first]
print('before partitionSeq', theSeq)
left = first + 1
right = last
while True:
# 找到第一个比pivot大的
while left <= right and theSeq[left] < pivot:
left += 1
# 从右边开始找到比pivot小的
while right >= left and theSeq[right] >= pivot:
right -= 1
if right < left:
break
else:
theSeq[left], theSeq[right] = theSeq[right], theSeq[left]
# 把pivot放到合适的位置
theSeq[first], theSeq[right] = theSeq[right], theSeq[first]
print('after partitionSeq {}: {}\t'.format(theSeq, pivot))
return right # 返回pivot的位置
def test_partitionSeq():
l = [0,1,2,3,4]
assert partitionSeq(l, 0, len(l)-1) == 0
l = [4,3,2,1,0]
assert partitionSeq(l, 0, len(l)-1) == 4
l = [2,3,0,1,4]
assert partitionSeq(l, 0, len(l)-1) == 2
test_partitionSeq()
def test_quicksort():
def _is_sorted(seq):
for i in range(len(seq)-1):
if seq[i] > seq[i+1]:
return False
return True
from random import randint
for i in range(100):
_len = randint(1, 100)
to_sort = []
for i in range(_len):
to_sort.append(randint(0, 100))
quicksort(to_sort, 0, len(to_sort)-1) # 注意这里用了原地排序,直接更改了数组
print(to_sort)
assert _is_sorted(to_sort)
test_quicksort()
利用快排中的partitionSeq操作,我们还能实现另一个算法,nth_element,快速查找一个无序数组中的第k大元素
def nth_element(seq, beg, end, k):
if beg == end:
return seq[beg]
pivot_index = partitionSeq(seq, beg, end)
if pivot_index == k:
return seq[k]
elif pivot_index > k:
return nth_element(seq, beg, pivot_index-1, k)
else:
return nth_element(seq, pivot_index+1, end, k)
def test_nth_element():
from random import shuffle
n = 10
l = list(range(n))
shuffle(l)
print(l)
for i in range(len(l)):
assert nth_element(l, 0, len(l)-1, i) == i
test_nth_element()
13章: Binary Tree
The binary Tree: 二叉树,每个节点做多只有两个子节点
class _BinTreeNode:
def __init__(self, data):
self.data = data
self.left = None
self.right = None
# 三种depth-first遍历
def preorderTrav(subtree):
""" 先(根)序遍历"""
if subtree is not None:
print(subtree.data)
preorderTrav(subtree.left)
preorderTrav(subtree.right)
def inorderTrav(subtree):
""" 中(根)序遍历"""
if subtree is not None:
preorderTrav(subtree.left)
print(subtree.data)
preorderTrav(subtree.right)
def postorderTrav(subtree):
""" 后(根)序遍历"""
if subtree is not None:
preorderTrav(subtree.left)
preorderTrav(subtree.right)
print(subtree.data)
# 宽度优先遍历(bradth-First Traversal): 一层一层遍历, 使用queue
def breadthFirstTrav(bintree):
from queue import Queue # py3
q = Queue()
q.put(bintree)
while not q.empty():
node = q.get()
print(node.data)
if node.left is not None:
q.put(node.left)
if node.right is not None:
q.put(node.right)
class _ExpTreeNode:
__slots__ = ('element', 'left', 'right')
def __init__(self, data):
self.element = data
self.left = None
self.right = None
def __repr__(self):
return '<_ExpTreeNode: {} {} {}>'.format(
self.element, self.left, self.right)
from queue import Queue
class ExpressionTree:
"""
表达式树: 操作符存储在内节点操作数存储在叶子节点的二叉树。(符号树真难打出来)
*
/ \
+ -
/ \ / \
9 3 8 4
(9+3) * (8-4)
Expression Tree Abstract Data Type,可以实现二元操作符
ExpressionTree(expStr): user string as constructor param
evaluate(varDict): evaluates the expression and returns the numeric result
toString(): constructs and retutns a string represention of the expression
Usage:
vars = {'a': 5, 'b': 12}
expTree = ExpressionTree("(a/(b-3))")
print('The result = ', expTree.evaluate(vars))
"""
def __init__(self, expStr):
self._expTree = None
self._buildTree(expStr)
def evaluate(self, varDict):
return self._evalTree(self._expTree, varDict)
def __str__(self):
return self._buildString(self._expTree)
def _buildString(self, treeNode):
""" 在一个子树被遍历之前添加做括号,在子树被遍历之后添加右括号 """
# print(treeNode)
if treeNode.left is None and treeNode.right is None:
return str(treeNode.element) # 叶子节点是操作数直接返回
else:
expStr = '('
expStr += self._buildString(treeNode.left)
expStr += str(treeNode.element)
expStr += self._buildString(treeNode.right)
expStr += ')'
return expStr
def _evalTree(self, subtree, varDict):
# 是不是叶子节点, 是的话说明是操作数,直接返回
if subtree.left is None and subtree.right is None:
# 操作数是合法数字吗
if subtree.element >= '0' and subtree.element <= '9':
return int(subtree.element)
else: # 操作数是个变量
assert subtree.element in varDict, 'invalid variable.'
return varDict[subtree.element]
else: # 操作符则计算其子表达式
lvalue = self._evalTree(subtree.left, varDict)
rvalue = self._evalTree(subtree.right, varDict)
print(subtree.element)
return self._computeOp(lvalue, subtree.element, rvalue)
def _computeOp(self, left, op, right):
assert op
op_func = {
'+': lambda left, right: left + right, # or import operator, operator.add
'-': lambda left, right: left - right,
'*': lambda left, right: left * right,
'/': lambda left, right: left / right,
'%': lambda left, right: left % right,
}
return op_func[op](left, right)
def _buildTree(self, expStr):
expQ = Queue()
for token in expStr: # 遍历表达式字符串的每个字符
expQ.put(token)
self._expTree = _ExpTreeNode(None) # 创建root节点
self._recBuildTree(self._expTree, expQ)
def _recBuildTree(self, curNode, expQ):
token = expQ.get()
if token == '(':
curNode.left = _ExpTreeNode(None)
self._recBuildTree(curNode.left, expQ)
# next token will be an operator: + = * / %
curNode.element = expQ.get()
curNode.right = _ExpTreeNode(None)
self._recBuildTree(curNode.right, expQ)
# the next token will be ')', remmove it
expQ.get()
else: # the token is a digit that has to be converted to an int.
curNode.element = token
vars = {'a': 5, 'b': 12}
expTree = ExpressionTree("((2*7)+8)")
print(expTree)
print('The result = ', expTree.evaluate(vars))
Heap(堆):二叉树最直接的一个应用就是实现堆。堆就是一颗完全二叉树,最大堆的非叶子节点的值都比孩子大,最小堆的非叶子结点的值都比孩子小。 python内置了heapq模块帮助我们实现堆操作,比如用内置的heapq模块实现个堆排序:
# 使用python内置的heapq实现heap sort
def heapsort(iterable):
from heapq import heappush, heappop
h = []
for value in iterable:
heappush(h, value)
return [heappop(h) for i in range(len(h))]
但是一般实现堆的时候实际上并不是用数节点来实现的,而是使用数组实现,效率比较高。为什么可以用数组实现呢?因为完全二叉树的性质, 可以用下标之间的关系表示节点之间的关系,MaxHeap的docstring中已经说明了
class MaxHeap:
"""
Heaps:
完全二叉树,最大堆的非叶子节点的值都比孩子大,最小堆的非叶子结点的值都比孩子小
Heap包含两个属性,order property 和 shape property(a complete binary tree),在插入
一个新节点的时候,始终要保持这两个属性
插入操作:保持堆属性和完全二叉树属性, sift-up 操作维持堆属性
extract操作:只获取根节点数据,并把树最底层最右节点copy到根节点后,sift-down操作维持堆属性
用数组实现heap,从根节点开始,从上往下从左到右给每个节点编号,则根据完全二叉树的
性质,给定一个节点i, 其父亲和孩子节点的编号分别是:
parent = (i-1) // 2
left = 2 * i + 1
rgiht = 2 * i + 2
使用数组实现堆一方面效率更高,节省树节点的内存占用,一方面还可以避免复杂的指针操作,减少
调试难度。
"""
def __init__(self, maxSize):
self._elements = Array(maxSize) # 第二章实现的Array ADT
self._count = 0
def __len__(self):
return self._count
def capacity(self):
return len(self._elements)
def add(self, value):
assert self._count < self.capacity(), 'can not add to full heap'
self._elements[self._count] = value
self._count += 1
self._siftUp(self._count - 1)
self.assert_keep_heap() # 确定每一步add操作都保持堆属性
def extract(self):
assert self._count > 0, 'can not extract from an empty heap'
value = self._elements[0] # save root value
self._count -= 1
self._elements[0] = self._elements[self._count] # 最右下的节点放到root后siftDown
self._siftDown(0)
self.assert_keep_heap()
return value
def _siftUp(self, ndx):
if ndx > 0:
parent = (ndx - 1) // 2
# print(ndx, parent)
if self._elements[ndx] > self._elements[parent]: # swap
self._elements[ndx], self._elements[parent] = self._elements[parent], self._elements[ndx]
self._siftUp(parent) # 递归
def _siftDown(self, ndx):
left = 2 * ndx + 1
right = 2 * ndx + 2
# determine which node contains the larger value
largest = ndx
if (left < self._count and
self._elements[left] >= self._elements[largest] and
self._elements[left] >= self._elements[right]): # 原书这个地方没写实际上找的未必是largest
largest = left
elif right < self._count and self._elements[right] >= self._elements[largest]:
largest = right
if largest != ndx:
self._elements[ndx], self._elements[largest] = self._elements[largest], self._elements[ndx]
self._siftDown(largest)
def __repr__(self):
return ' '.join(map(str, self._elements))
def assert_keep_heap(self):
""" 我加了这个函数是用来验证每次add或者extract之后,仍保持最大堆的性质"""
_len = len(self)
for i in range(0, int((_len-1)/2)): # 内部节点(非叶子结点)
l = 2 * i + 1
r = 2 * i + 2
if l < _len and r < _len:
assert self._elements[i] >= self._elements[l] and self._elements[i] >= self._elements[r]
def test_MaxHeap():
""" 最大堆实现的单元测试用例 """
_len = 10
h = MaxHeap(_len)
for i in range(_len):
h.add(i)
h.assert_keep_heap()
for i in range(_len):
# 确定每次出来的都是最大的数字,添加的时候是从小到大添加的
assert h.extract() == _len-i-1
test_MaxHeap()
def simpleHeapSort(theSeq):
""" 用自己实现的MaxHeap实现堆排序,直接修改原数组实现inplace排序"""
if not theSeq:
return theSeq
_len = len(theSeq)
heap = MaxHeap(_len)
for i in theSeq:
heap.add(i)
for i in reversed(range(_len)):
theSeq[i] = heap.extract()
return theSeq
def test_simpleHeapSort():
""" 用一些测试用例证明实现的堆排序是可以工作的 """
def _is_sorted(seq):
for i in range(len(seq)-1):
if seq[i] > seq[i+1]:
return False
return True
from random import randint
assert simpleHeapSort([]) == []
for i in range(1000):
_len = randint(1, 100)
to_sort = []
for i in range(_len):
to_sort.append(randint(0, 100))
simpleHeapSort(to_sort) # 注意这里用了原地排序,直接更改了数组
assert _is_sorted(to_sort)
test_simpleHeapSort()
14章: Search Trees
二叉差找树性质:对每个内部节点V, 1. 所有key小于V.key的存储在V的左子树。 2. 所有key大于V.key的存储在V的右子树 对BST进行中序遍历会得到升序的key序列
class _BSTMapNode:
__slots__ = ('key', 'value', 'left', 'right')
def __init__(self, key, value):
self.key = key
self.value = value
self.left = None
self.right = None
def __repr__(self):
return '<{}:{}> left:{}, right:{}'.format(
self.key, self.value, self.left, self.right)
__str__ = __repr__
class BSTMap:
""" BST,树节点包含key可payload。用BST来实现之前用hash实现过的Map ADT.
性质:对每个内部节点V,
1.对于节点V,所有key小于V.key的存储在V的左子树。
2.所有key大于V.key的存储在V的右子树
对BST进行中序遍历会得到升序的key序列
"""
def __init__(self):
self._root = None
self._size = 0
self._rval = None # 作为remove的返回值
def __len__(self):
return self._size
def __iter__(self):
return _BSTMapIterator(self._root, self._size)
def __contains__(self, key):
return self._bstSearch(self._root, key) is not None
def valueOf(self, key):
node = self._bstSearch(self._root, key)
assert node is not None, 'Invalid map key.'
return node.value
def _bstSearch(self, subtree, target):
if subtree is None: # 递归出口,遍历到树底没有找到key或是空树
return None
elif target < subtree.key:
return self._bstSearch(subtree.left, target)
elif target > subtree.key:
return self._bstSearch(subtree.right, target)
return subtree # 返回引用
def _bstMinumum(self, subtree):
""" 顺着树一直往左下角递归找就是最小的,向右下角递归就是最大的 """
if subtree is None:
return None
elif subtree.left is None:
return subtree
else:
return subtree._bstMinumum(self, subtree.left)
def add(self, key, value):
""" 添加或者替代一个key的value, O(N) """
node = self._bstSearch(self._root, key)
if node is not None: # if key already exists, update value
node.value = value
return False
else: # insert a new entry
self._root = self._bstInsert(self._root, key, value)
self._size += 1
return True
def _bstInsert(self, subtree, key, value):
""" 新的节点总是插入在树的叶子结点上 """
if subtree is None:
subtree = _BSTMapNode(key, value)
elif key < subtree.key:
subtree.left = self._bstInsert(subtree.left, key, value)
elif key > subtree.key:
subtree.right = self._bstInsert(subtree.right, key, value)
# 注意这里没有else语句了,应为在被调用处add函数里先判断了是否有重复key
return subtree
def remove(self, key):
""" O(N)
被删除的节点分为三种:
1.叶子结点:直接把其父亲指向该节点的指针置None
2.该节点有一个孩子: 删除该节点后,父亲指向一个合适的该节点的孩子
3.该节点有俩孩子:
(1)找到要删除节点N和其后继S(中序遍历后该节点下一个)
(2)复制S的key到N
(3)从N的右子树中删除后继S(即在N的右子树中最小的)
"""
assert key in self, 'invalid map key'
self._root = self._bstRemove(self._root, key)
self._size -= 1
return self._rval
def _bstRemove(self, subtree, target):
# search for the item in the tree
if subtree is None:
return subtree
elif target < subtree.key:
subtree.left = self._bstRemove(subtree.left, target)
return subtree
elif target > subtree.key:
subtree.right = self._bstRemove(subtree.right, target)
return subtree
else: # found the node containing the item
self._rval = subtree.value
if subtree.left is None and subtree.right is None:
# 叶子node
return None
elif subtree.left is None or subtree.right is None:
# 有一个孩子节点
if subtree.left is not None:
return subtree.left
else:
return subtree.right
else: # 有俩孩子节点
successor = self._bstMinumum(subtree.right)
subtree.key = successor.key
subtree.value = successor.value
subtree.right = self._bstRemove(subtree.right, successor.key)
return subtree
def __repr__(self):
return '->'.join([str(i) for i in self])
def assert_keep_bst_property(self, subtree):
""" 写这个函数为了验证add和delete操作始终维持了bst的性质 """
if subtree is None:
return
if subtree.left is not None and subtree.right is not None:
assert subtree.left.value <= subtree.value
assert subtree.right.value >= subtree.value
self.assert_keep_bst_property(subtree.left)
self.assert_keep_bst_property(subtree.right)
elif subtree.left is None and subtree.right is not None:
assert subtree.right.value >= subtree.value
self.assert_keep_bst_property(subtree.right)
elif subtree.left is not None and subtree.right is None:
assert subtree.left.value <= subtree.value
self.assert_keep_bst_property(subtree.left)
class _BSTMapIterator:
def __init__(self, root, size):
self._theKeys = Array(size)
self._curItem = 0
self._bstTraversal(root)
self._curItem = 0
def __iter__(self):
return self
def __next__(self):
if self._curItem < len(self._theKeys):
key = self._theKeys[self._curItem]
self._curItem += 1
return key
else:
raise StopIteration
def _bstTraversal(self, subtree):
if subtree is not None:
self._bstTraversal(subtree.left)
self._theKeys[self._curItem] = subtree.key
self._curItem += 1
self._bstTraversal(subtree.right)
def test_BSTMap():
l = [60, 25, 100, 35, 17, 80]
bst = BSTMap()
for i in l:
bst.add(i)
def test_HashMap():
""" 之前用来测试用hash实现的map,改为用BST实现的Map测试 """
# h = HashMap()
h = BSTMap()
assert len(h) == 0
h.add('a', 'a')
assert h.valueOf('a') == 'a'
assert len(h) == 1
a_v = h.remove('a')
assert a_v == 'a'
assert len(h) == 0
h.add('a', 'a')
h.add('b', 'b')
assert len(h) == 2
assert h.valueOf('b') == 'b'
b_v = h.remove('b')
assert b_v == 'b'
assert len(h) == 1
h.remove('a')
assert len(h) == 0
_len = 10
for i in range(_len):
h.add(str(i), i)
assert len(h) == _len
for i in range(_len):
assert str(i) in h
for i in range(_len):
print(len(h))
print('bef', h)
_ = h.remove(str(i))
assert _ == i
print('aft', h)
print(len(h))
assert len(h) == 0
test_HashMap()