8.22 不用递归实现访问者模式

问题

你使用访问者模式遍历一个很深的嵌套树形数据结构,并且因为超过嵌套层级限制而失败。你想消除递归,并同时保持访问者编程模式。

解决方案

通过巧妙的使用生成器可以在树遍历或搜索算法中消除递归。在8.21小节中,我们给出了一个访问者类。下面我们利用一个栈和生成器重新实现这个类:

  1. import types
  2.  
  3. class Node:
  4. pass
  5.  
  6. class NodeVisitor:
  7. def visit(self, node):
  8. stack = [node]
  9. last_result = None
  10. while stack:
  11. try:
  12. last = stack[-1]
  13. if isinstance(last, types.GeneratorType):
  14. stack.append(last.send(last_result))
  15. last_result = None
  16. elif isinstance(last, Node):
  17. stack.append(self._visit(stack.pop()))
  18. else:
  19. last_result = stack.pop()
  20. except StopIteration:
  21. stack.pop()
  22.  
  23. return last_result
  24.  
  25. def _visit(self, node):
  26. methname = 'visit_' + type(node).__name__
  27. meth = getattr(self, methname, None)
  28. if meth is None:
  29. meth = self.generic_visit
  30. return meth(node)
  31.  
  32. def generic_visit(self, node):
  33. raise RuntimeError('No {} method'.format('visit_' + type(node).__name__))

如果你使用这个类,也能达到相同的效果。事实上你完全可以将它作为上一节中的访问者模式的替代实现。考虑如下代码,遍历一个表达式的树:

  1. class UnaryOperator(Node):
  2. def __init__(self, operand):
  3. self.operand = operand
  4.  
  5. class BinaryOperator(Node):
  6. def __init__(self, left, right):
  7. self.left = left
  8. self.right = right
  9.  
  10. class Add(BinaryOperator):
  11. pass
  12.  
  13. class Sub(BinaryOperator):
  14. pass
  15.  
  16. class Mul(BinaryOperator):
  17. pass
  18.  
  19. class Div(BinaryOperator):
  20. pass
  21.  
  22. class Negate(UnaryOperator):
  23. pass
  24.  
  25. class Number(Node):
  26. def __init__(self, value):
  27. self.value = value
  28.  
  29. # A sample visitor class that evaluates expressions
  30. class Evaluator(NodeVisitor):
  31. def visit_Number(self, node):
  32. return node.value
  33.  
  34. def visit_Add(self, node):
  35. return self.visit(node.left) + self.visit(node.right)
  36.  
  37. def visit_Sub(self, node):
  38. return self.visit(node.left) - self.visit(node.right)
  39.  
  40. def visit_Mul(self, node):
  41. return self.visit(node.left) * self.visit(node.right)
  42.  
  43. def visit_Div(self, node):
  44. return self.visit(node.left) / self.visit(node.right)
  45.  
  46. def visit_Negate(self, node):
  47. return -self.visit(node.operand)
  48.  
  49. if __name__ == '__main__':
  50. # 1 + 2*(3-4) / 5
  51. t1 = Sub(Number(3), Number(4))
  52. t2 = Mul(Number(2), t1)
  53. t3 = Div(t2, Number(5))
  54. t4 = Add(Number(1), t3)
  55. # Evaluate it
  56. e = Evaluator()
  57. print(e.visit(t4)) # Outputs 0.6

如果嵌套层次太深那么上述的Evaluator就会失效:

  1. >>> a = Number(0)
  2. >>> for n in range(1, 100000):
  3. ... a = Add(a, Number(n))
  4. ...
  5. >>> e = Evaluator()
  6. >>> e.visit(a)
  7. Traceback (most recent call last):
  8. ...
  9. File "visitor.py", line 29, in _visit
  10. return meth(node)
  11. File "visitor.py", line 67, in visit_Add
  12. return self.visit(node.left) + self.visit(node.right)
  13. RuntimeError: maximum recursion depth exceeded
  14. >>>

现在我们稍微修改下上面的Evaluator:

  1. class Evaluator(NodeVisitor):
  2. def visit_Number(self, node):
  3. return node.value
  4.  
  5. def visit_Add(self, node):
  6. yield (yield node.left) + (yield node.right)
  7.  
  8. def visit_Sub(self, node):
  9. yield (yield node.left) - (yield node.right)
  10.  
  11. def visit_Mul(self, node):
  12. yield (yield node.left) * (yield node.right)
  13.  
  14. def visit_Div(self, node):
  15. yield (yield node.left) / (yield node.right)
  16.  
  17. def visit_Negate(self, node):
  18. yield - (yield node.operand)

再次运行,就不会报错了:

  1. >>> a = Number(0)
  2. >>> for n in range(1,100000):
  3. ... a = Add(a, Number(n))
  4. ...
  5. >>> e = Evaluator()
  6. >>> e.visit(a)
  7. 4999950000
  8. >>>

如果你还想添加其他自定义逻辑也没问题:

  1. class Evaluator(NodeVisitor):
  2. ...
  3. def visit_Add(self, node):
  4. print('Add:', node)
  5. lhs = yield node.left
  6. print('left=', lhs)
  7. rhs = yield node.right
  8. print('right=', rhs)
  9. yield lhs + rhs
  10. ...

下面是简单的测试:

  1. >>> e = Evaluator()
  2. >>> e.visit(t4)
  3. Add: <__main__.Add object at 0x1006a8d90>
  4. left= 1
  5. right= -0.4
  6. 0.6
  7. >>>

讨论

这一小节我们演示了生成器和协程在程序控制流方面的强大功能。避免递归的一个通常方法是使用一个栈或队列的数据结构。例如,深度优先的遍历算法,第一次碰到一个节点时将其压入栈中,处理完后弹出栈。visit() 方法的核心思路就是这样。

另外一个需要理解的就是生成器中yield语句。当碰到yield语句时,生成器会返回一个数据并暂时挂起。上面的例子使用这个技术来代替了递归。例如,之前我们是这样写递归:

  1. value = self.visit(node.left)

现在换成yield语句:

  1. value = yield node.left

它会将 node.left 返回给 visit() 方法,然后 visit() 方法调用那个节点相应的 visit_Name() 方法。yield暂时将程序控制器让出给调用者,当执行完后,结果会赋值给value,

看完这一小节,你也许想去寻找其它没有yield语句的方案。但是这么做没有必要,你必须处理很多棘手的问题。例如,为了消除递归,你必须要维护一个栈结构,如果不使用生成器,代码会变得很臃肿,到处都是栈操作语句、回调函数等。实际上,使用yield语句可以让你写出非常漂亮的代码,它消除了递归但是看上去又很像递归实现,代码很简洁。

原文:

http://python3-cookbook.readthedocs.io/zh_CN/latest/c08/p22_implementing_visitor_pattern_without_recursion.html