9.20 利用函数注解实现方法重载

问题

你已经学过怎样使用函数参数注解,那么你可能会想利用它来实现基于类型的方法重载。但是你不确定应该怎样去实现(或者到底行得通不)。

解决方案

本小节的技术是基于一个简单的技术,那就是Python允许参数注解,代码可以像下面这样写:

  1. class Spam:
  2. def bar(self, x:int, y:int):
  3. print('Bar 1:', x, y)
  4.  
  5. def bar(self, s:str, n:int = 0):
  6. print('Bar 2:', s, n)
  7.  
  8. s = Spam()
  9. s.bar(2, 3) # Prints Bar 1: 2 3
  10. s.bar('hello') # Prints Bar 2: hello 0

下面是我们第一步的尝试,使用到了一个元类和描述器:

  1. # multiple.py
  2. import inspect
  3. import types
  4.  
  5. class MultiMethod:
  6. '''
  7. Represents a single multimethod.
  8. '''
  9. def __init__(self, name):
  10. self._methods = {}
  11. self.__name__ = name
  12.  
  13. def register(self, meth):
  14. '''
  15. Register a new method as a multimethod
  16. '''
  17. sig = inspect.signature(meth)
  18.  
  19. # Build a type signature from the method's annotations
  20. types = []
  21. for name, parm in sig.parameters.items():
  22. if name == 'self':
  23. continue
  24. if parm.annotation is inspect.Parameter.empty:
  25. raise TypeError(
  26. 'Argument {} must be annotated with a type'.format(name)
  27. )
  28. if not isinstance(parm.annotation, type):
  29. raise TypeError(
  30. 'Argument {} annotation must be a type'.format(name)
  31. )
  32. if parm.default is not inspect.Parameter.empty:
  33. self._methods[tuple(types)] = meth
  34. types.append(parm.annotation)
  35.  
  36. self._methods[tuple(types)] = meth
  37.  
  38. def __call__(self, *args):
  39. '''
  40. Call a method based on type signature of the arguments
  41. '''
  42. types = tuple(type(arg) for arg in args[1:])
  43. meth = self._methods.get(types, None)
  44. if meth:
  45. return meth(*args)
  46. else:
  47. raise TypeError('No matching method for types {}'.format(types))
  48.  
  49. def __get__(self, instance, cls):
  50. '''
  51. Descriptor method needed to make calls work in a class
  52. '''
  53. if instance is not None:
  54. return types.MethodType(self, instance)
  55. else:
  56. return self
  57.  
  58. class MultiDict(dict):
  59. '''
  60. Special dictionary to build multimethods in a metaclass
  61. '''
  62. def __setitem__(self, key, value):
  63. if key in self:
  64. # If key already exists, it must be a multimethod or callable
  65. current_value = self[key]
  66. if isinstance(current_value, MultiMethod):
  67. current_value.register(value)
  68. else:
  69. mvalue = MultiMethod(key)
  70. mvalue.register(current_value)
  71. mvalue.register(value)
  72. super().__setitem__(key, mvalue)
  73. else:
  74. super().__setitem__(key, value)
  75.  
  76. class MultipleMeta(type):
  77. '''
  78. Metaclass that allows multiple dispatch of methods
  79. '''
  80. def __new__(cls, clsname, bases, clsdict):
  81. return type.__new__(cls, clsname, bases, dict(clsdict))
  82.  
  83. @classmethod
  84. def __prepare__(cls, clsname, bases):
  85. return MultiDict()

为了使用这个类,你可以像下面这样写:

  1. class Spam(metaclass=MultipleMeta):
  2. def bar(self, x:int, y:int):
  3. print('Bar 1:', x, y)
  4.  
  5. def bar(self, s:str, n:int = 0):
  6. print('Bar 2:', s, n)
  7.  
  8. # Example: overloaded __init__
  9. import time
  10.  
  11. class Date(metaclass=MultipleMeta):
  12. def __init__(self, year: int, month:int, day:int):
  13. self.year = year
  14. self.month = month
  15. self.day = day
  16.  
  17. def __init__(self):
  18. t = time.localtime()
  19. self.__init__(t.tm_year, t.tm_mon, t.tm_mday)

下面是一个交互示例来验证它能正确的工作:

  1. >>> s = Spam()
  2. >>> s.bar(2, 3)
  3. Bar 1: 2 3
  4. >>> s.bar('hello')
  5. Bar 2: hello 0
  6. >>> s.bar('hello', 5)
  7. Bar 2: hello 5
  8. >>> s.bar(2, 'hello')
  9. Traceback (most recent call last):
  10. File "<stdin>", line 1, in <module>
  11. File "multiple.py", line 42, in __call__
  12. raise TypeError('No matching method for types {}'.format(types))
  13. TypeError: No matching method for types (<class 'int'>, <class 'str'>)
  14. >>> # Overloaded __init__
  15. >>> d = Date(2012, 12, 21)
  16. >>> # Get today's date
  17. >>> e = Date()
  18. >>> e.year
  19. 2012
  20. >>> e.month
  21. 12
  22. >>> e.day
  23. 3
  24. >>>

讨论

坦白来讲,相对于通常的代码而已本节使用到了很多的魔法代码。但是,它却能让我们深入理解元类和描述器的底层工作原理,并能加深对这些概念的印象。因此,就算你并不会立即去应用本节的技术,它的一些底层思想却会影响到其它涉及到元类、描述器和函数注解的编程技术。

本节的实现中的主要思路其实是很简单的。MutipleMeta 元类使用它的 prepare() 方法来提供一个作为 MultiDict 实例的自定义字典。这个跟普通字典不一样的是,MultiDict 会在元素被设置的时候检查是否已经存在,如果存在的话,重复的元素会在 MultiMethod实例中合并。

MultiMethod 实例通过构建从类型签名到函数的映射来收集方法。在这个构建过程中,函数注解被用来收集这些签名然后构建这个映射。这个过程在 MultiMethod.register() 方法中实现。这种映射的一个关键特点是对于多个方法,所有参数类型都必须要指定,否则就会报错。

为了让 MultiMethod 实例模拟一个调用,它的 call() 方法被实现了。这个方法从所有排除 slef 的参数中构建一个类型元组,在内部map中查找这个方法,然后调用相应的方法。为了能让 MultiMethod 实例在类定义时正确操作,get() 是必须得实现的。它被用来构建正确的绑定方法。比如:

  1. >>> b = s.bar
  2. >>> b
  3. <bound method Spam.bar of <__main__.Spam object at 0x1006a46d0>>
  4. >>> b.__self__
  5. <__main__.Spam object at 0x1006a46d0>
  6. >>> b.__func__
  7. <__main__.MultiMethod object at 0x1006a4d50>
  8. >>> b(2, 3)
  9. Bar 1: 2 3
  10. >>> b('hello')
  11. Bar 2: hello 0
  12. >>>

不过本节的实现还有一些限制,其中一个是它不能使用关键字参数。例如:

  1. >>> s.bar(x=2, y=3)
  2. Traceback (most recent call last):
  3. File "<stdin>", line 1, in <module>
  4. TypeError: __call__() got an unexpected keyword argument 'y'
  5.  
  6. >>> s.bar(s='hello')
  7. Traceback (most recent call last):
  8. File "<stdin>", line 1, in <module>
  9. TypeError: __call__() got an unexpected keyword argument 's'
  10. >>>

也许有其他的方法能添加这种支持,但是它需要一个完全不同的方法映射方式。问题在于关键字参数的出现是没有顺序的。当它跟位置参数混合使用时,那你的参数就会变得比较混乱了,这时候你不得不在 call() 方法中先去做个排序。

同样对于继承也是有限制的,例如,类似下面这种代码就不能正常工作:

  1. class A:
  2. pass
  3.  
  4. class B(A):
  5. pass
  6.  
  7. class C:
  8. pass
  9.  
  10. class Spam(metaclass=MultipleMeta):
  11. def foo(self, x:A):
  12. print('Foo 1:', x)
  13.  
  14. def foo(self, x:C):
  15. print('Foo 2:', x)

原因是因为 x:A 注解不能成功匹配子类实例(比如B的实例),如下:

  1. >>> s = Spam()
  2. >>> a = A()
  3. >>> s.foo(a)
  4. Foo 1: <__main__.A object at 0x1006a5310>
  5. >>> c = C()
  6. >>> s.foo(c)
  7. Foo 2: <__main__.C object at 0x1007a1910>
  8. >>> b = B()
  9. >>> s.foo(b)
  10. Traceback (most recent call last):
  11. File "<stdin>", line 1, in <module>
  12. File "multiple.py", line 44, in __call__
  13. raise TypeError('No matching method for types {}'.format(types))
  14. TypeError: No matching method for types (<class '__main__.B'>,)
  15. >>>

作为使用元类和注解的一种替代方案,可以通过描述器来实现类似的效果。例如:

  1. import types
  2.  
  3. class multimethod:
  4. def __init__(self, func):
  5. self._methods = {}
  6. self.__name__ = func.__name__
  7. self._default = func
  8.  
  9. def match(self, *types):
  10. def register(func):
  11. ndefaults = len(func.__defaults__) if func.__defaults__ else 0
  12. for n in range(ndefaults+1):
  13. self._methods[types[:len(types) - n]] = func
  14. return self
  15. return register
  16.  
  17. def __call__(self, *args):
  18. types = tuple(type(arg) for arg in args[1:])
  19. meth = self._methods.get(types, None)
  20. if meth:
  21. return meth(*args)
  22. else:
  23. return self._default(*args)
  24.  
  25. def __get__(self, instance, cls):
  26. if instance is not None:
  27. return types.MethodType(self, instance)
  28. else:
  29. return self

为了使用描述器版本,你需要像下面这样写:

  1. class Spam:
  2. @multimethod
  3. def bar(self, *args):
  4. # Default method called if no match
  5. raise TypeError('No matching method for bar')
  6.  
  7. @bar.match(int, int)
  8. def bar(self, x, y):
  9. print('Bar 1:', x, y)
  10.  
  11. @bar.match(str, int)
  12. def bar(self, s, n = 0):
  13. print('Bar 2:', s, n)

描述器方案同样也有前面提到的限制(不支持关键字参数和继承)。

所有事物都是平等的,有好有坏,也许最好的办法就是在普通代码中避免使用方法重载。不过有些特殊情况下还是有意义的,比如基于模式匹配的方法重载程序中。举个例子,8.21小节中的访问者模式可以修改为一个使用方法重载的类。但是,除了这个以外,通常不应该使用方法重载(就简单的使用不同名称的方法就行了)。

在Python社区对于实现方法重载的讨论已经由来已久。对于引发这个争论的原因,可以参考下Guido van Rossum的这篇博客:Five-Minute Multimethods in Python

原文:

http://python3-cookbook.readthedocs.io/zh_CN/latest/c09/p20_implement_multiple_dispatch_with_function_annotations.html