5. 用 *args 和 **kwargs 自定义聚合函数
# 用inspect模块查看groupby对象的agg方法的签名
# NOTE: assumes pandas is already imported as pd (this cell does not import it).
In[31]: college = pd.read_csv('data/college.csv')
grouped = college.groupby(['STABBR', 'RELAFFIL'])
# The *args/**kwargs in agg's signature are what let extra arguments reach a custom aggregation function.
In[32]: import inspect
inspect.signature(grouped.agg)
Out[32]: <Signature (arg, *args, **kwargs)>
如何做
# 自定义一个函数,返回本科生人数在1000和3000之间的比例
# In[33]:
def pct_between_1_3k(s):
    """Return the fraction of values in *s* that lie between 1000 and 3000.

    The bounds are fixed on purpose; ``pct_between`` below makes them
    configurable.

    Parameters
    ----------
    s : pandas.Series
        Values of one group (here, a group's UGDS column).

    Returns
    -------
    float
        Mean of the boolean ``Series.between`` mask, i.e. the share of rows
        inside the range. Both ends are inclusive, and missing values count
        as outside the range.
    """
    return s.between(1000, 3000).mean()
# 用州和宗教分组,再聚合
# Each (STABBR, RELAFFIL) group's UGDS values are passed to pct_between_1_3k, which reduces them to one number per group.
In[34]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between_1_3k).head(9)
Out[34]:
STABBR RELAFFIL
AK 0 0.142857
1 0.000000
AL 0 0.236111
1 0.333333
AR 0 0.279412
1 0.111111
AS 0 1.000000
AZ 0 0.096774
1 0.000000
Name: UGDS, dtype: float64
# 但是这个函数不能让用户自定义上下限,再新写一个函数
# In[35]:
def pct_between(s, low, high):
    """Return the fraction of values in *s* that lie between *low* and *high*.

    Parameters
    ----------
    s : pandas.Series
        Values of one group (here, a group's UGDS column).
    low, high : scalar
        Lower and upper bounds forwarded to ``Series.between``. Both are
        required, so ``agg`` must supply them as extra arguments.

    Returns
    -------
    float
        Share of rows inside the inclusive range. Missing values count as
        outside the range.
    """
    return s.between(low, high).mean()
# 使用这个自定义聚合函数,并依次传入最小值和最大值
# Extra positional arguments after the function are forwarded to it, so 1000 binds to low and 10000 to high.
In[36]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between, 1000, 10000).head(9)
Out[36]:
STABBR RELAFFIL
AK 0 0.428571
1 0.000000
AL 0 0.458333
1 0.375000
AR 0 0.397059
1 0.166667
AS 0 1.000000
AZ 0 0.233871
1 0.111111
Name: UGDS, dtype: float64
原理
# 用关键字参数显式指定最大值和最小值(顺序无关)
# Keyword arguments are forwarded by name, so their order does not matter; the result matches Out[36].
In[37]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between, high=10000, low=1000).head(9)
Out[37]:
STABBR RELAFFIL
AK 0 0.428571
1 0.000000
AL 0 0.458333
1 0.375000
AR 0 0.397059
1 0.166667
AS 0 1.000000
AZ 0 0.233871
1 0.111111
Name: UGDS, dtype: float64
# 也可以混合使用关键字参数和非关键字(位置)参数,只要关键字参数放在位置参数后面
# 1000 binds to low positionally and high is passed by keyword; Python requires the positional argument to come first.
In[38]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between, 1000, high=10000).head(9)
Out[38]:
STABBR RELAFFIL
AK 0 0.428571
1 0.000000
AL 0 0.458333
1 0.375000
AR 0 0.397059
1 0.166667
AS 0 1.000000
AZ 0 0.233871
1 0.111111
Name: UGDS, dtype: float64
更多
# 同时使用多个聚合函数时,Pandas不支持向这些函数传入额外参数
# Expected to fail (in the pandas version used here): with a list of functions, agg calls each one without low/high, so pct_between raises TypeError.
In[39]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(['mean', pct_between], low=100, high=1000)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-39-3e3e18919cf9> in <module>()
----> 1 college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(['mean', pct_between], low=100, high=1000)
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in aggregate(self, func_or_funcs, *args, **kwargs)
2871 if hasattr(func_or_funcs, '__iter__'):
2872 ret = self._aggregate_multiple_funcs(func_or_funcs,
-> 2873 (_level or 0) + 1)
2874 else:
2875 cyfunc = self._is_cython_func(func_or_funcs)
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in _aggregate_multiple_funcs(self, arg, _level)
2944 obj._reset_cache()
2945 obj._selection = name
-> 2946 results[name] = obj.aggregate(func)
2947
2948 if isinstance(list(compat.itervalues(results))[0],
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in aggregate(self, func_or_funcs, *args, **kwargs)
2878
2879 if self.grouper.nkeys > 1:
-> 2880 return self._python_agg_general(func_or_funcs, *args, **kwargs)
2881
2882 try:
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in _python_agg_general(self, func, *args, **kwargs)
852
853 if len(output) == 0:
--> 854 return self._python_apply_general(f)
855
856 if self.grouper._filter_empty_groups:
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in _python_apply_general(self, f)
718 def _python_apply_general(self, f):
719 keys, values, mutated = self.grouper.apply(f, self._selected_obj,
--> 720 self.axis)
721
722 return self._wrap_applied_output(
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in apply(self, f, data, axis)
1800 # group might be modified
1801 group_axes = _get_axes(group)
-> 1802 res = f(group)
1803 if not _is_indexed_like(res, group_axes):
1804 mutated = True
/Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in <lambda>(x)
840 def _python_agg_general(self, func, *args, **kwargs):
841 func = self._is_builtin_func(func)
--> 842 f = lambda x: func(x, *args, **kwargs)
843
844 # iterate through "columns" ex exclusions to populate output dict
TypeError: pct_between() missing 2 required positional arguments: 'low' and 'high'
# 用闭包自定义聚合函数:把额外参数预先绑定到函数里,从而绕过上面的限制
# In[40]:
def make_agg_func(func, name, *args, **kwargs):
    """Return a one-argument wrapper around *func* with extra arguments pre-bound.

    When several functions are passed to ``agg`` as a list, each one is
    called with only the group data (at least in the pandas version used
    here). Capturing ``*args``/``**kwargs`` in a closure lets a
    parameterized function be used there anyway.

    Parameters
    ----------
    func : callable
        Aggregation function, called as ``func(x, *args, **kwargs)``.
    name : str
        Assigned to the wrapper's ``__name__``.
    *args, **kwargs
        Extra arguments forwarded to *func* on every call.

    Returns
    -------
    callable
        ``wrapper(x)``, which returns ``func(x, *args, **kwargs)``.
    """
    def wrapper(x):
        return func(x, *args, **kwargs)

    # A distinct __name__ lets several wrappers of the same func sit side by
    # side in one agg() list; pandas labels the result columns by function name.
    wrapper.__name__ = name
    return wrapper
# Pre-bind the bounds: my_agg1 passes them by keyword, my_agg2 by position.
# Each wrapper gets its own name, so both pct_between variants can appear in one agg() list.
my_agg1 = make_agg_func(pct_between, 'pct_1_3k', low=1000, high=3000)
my_agg2 = make_agg_func(pct_between, 'pct_10_30k', 10000, 30000)
In[41]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(['mean', my_agg1, my_agg2]).head()
Out[41]: