grpc 认证鉴权 —— oauth2

前面我们说了 tls 认证,tls 保证了 client 和 server 通信的安全性,但是无法做到接口级别的权限控制。例如有 A、B、C、D 四个系统,存在下面两个场景: 1、我们希望 A 可以访问 B、C 系统,但是不能访问 D 系统 2、B 系统提供了 b1、b2、b3 三个接口,我们希望 A 系统可以访问 b1、b2 接口,但是不能访问 b3 接口。 此时 tls 认证肯定是无法实现上面两个诉求的,对于这两个场景,grpc 提供了 oauth2 的认证方式。对 oauth2 不了解的同学可以参考 http://www.ruanyifeng.com/blog/2019/04/oauth_design.html

oauth2 认证鉴权实现

grpc 官方提供了对 oauth2 认证鉴权的实现 demo,放在 examples 目录的 features 目录的 authentication 目录下,我们来看一下源码实现

server

server 端源码实现如下:

  1. func main() {
  2. flag.Parse()
  3. fmt.Printf("server starting on port %d...\n", *port)
  4. cert, err := tls.LoadX509KeyPair(testdata.Path("server1.pem"), testdata.Path("server1.key"))
  5. if err != nil {
  6. log.Fatalf("failed to load key pair: %s", err)
  7. }
  8. opts := []grpc.ServerOption{
  9. // The following grpc.ServerOption adds an interceptor for all unary
  10. // RPCs. To configure an interceptor for streaming RPCs, see:
  11. // https://godoc.org/google.golang.org/grpc#StreamInterceptor
  12. grpc.UnaryInterceptor(ensureValidToken),
  13. // Enable TLS for all incoming connections.
  14. grpc.Creds(credentials.NewServerTLSFromCert(&cert)),
  15. }
  16. s := grpc.NewServer(opts...)
  17. ecpb.RegisterEchoServer(s, &ecServer{})
  18. lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
  19. if err != nil {
  20. log.Fatalf("failed to listen: %v", err)
  21. }
  22. if err := s.Serve(lis); err != nil {
  23. log.Fatalf("failed to serve: %v", err)
  24. }
  25. }

server 端先调用了 tls 包下的 LoadX509KeyPair,通过 server 的公钥和私钥生成了一个 Certificate 结构体来保存证书信息。然后注册了一个校验 token 的方法到拦截器中,并将证书信息设置到 serverOption 中,构造 server 的时候层层透传进去,最终会被设置到 Server 里面 ServerOptions 结构中的 credentials.TransportCredentials 和 UnaryServerInterceptor 中。

我们来看看这两个结构什么时候会被调用,先梳理调用链路,在 s.Serve ——> s.handleRawConn ——> s.serveStreams ——> s.handleStream ——> s.processUnaryRPC 方法中有一行

  1. reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)

可以看到调用了 md.Handler 方法,将 s.opts.unaryInt 这个结构传入了进去。s.opts.unaryInt 就是我们之前注册的 UnaryServerInterceptor 拦截器。md 是一个 MethodDesc 这个结构,包括了 MethodName 和 Handler

  1. type MethodDesc struct {
  2. MethodName string
  3. Handler methodHandler
  4. }

这里会取出我们之前注册进去的结构,还记得我们介绍 helloworld 时 RegisterService 吗?至于如何取出 MethodName,源码中的设计非常复杂,经过了层层包装,这里不是本节重点就不赘述了。

  1. func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
  2. s.RegisterService(&_Greeter_serviceDesc, srv)
  3. }
  4. var _Greeter_serviceDesc = grpc.ServiceDesc{
  5. ServiceName: "helloworld.Greeter",
  6. HandlerType: (*GreeterServer)(nil),
  7. Methods: []grpc.MethodDesc{
  8. {
  9. MethodName: "SayHello",
  10. Handler: _Greeter_SayHello_Handler,
  11. },
  12. },
  13. Streams: []grpc.StreamDesc{},
  14. Metadata: "helloworld.proto",
  15. }

我们看到 md.Handler 其实是 _Greeter_SayHello_Handler 这个结构,它也是在 pb 文件中生成的。

  1. func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
  2. in := new(HelloRequest)
  3. if err := dec(in); err != nil {
  4. return nil, err
  5. }
  6. if interceptor == nil {
  7. return srv.(GreeterServer).SayHello(ctx, in)
  8. }
  9. info := &grpc.UnaryServerInfo{
  10. Server: srv,
  11. FullMethod: "/helloworld.Greeter/SayHello",
  12. }
  13. handler := func(ctx context.Context, req interface{}) (interface{}, error) {
  14. return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
  15. }
  16. return interceptor(ctx, in, info, handler)
  17. }

这里调用了我们传入的 interceptor 方法。回到我们的调用:

  1. reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)

所以其实是调用了 s.opts.unaryInt 这个拦截器。这个拦截器是我们之前在 创建 server 的时候赋值的。

  1. opts := []grpc.ServerOption{
  2. // The following grpc.ServerOption adds an interceptor for all unary
  3. // RPCs. To configure an interceptor for streaming RPCs, see:
  4. // https://godoc.org/google.golang.org/grpc#StreamInterceptor
  5. grpc.UnaryInterceptor(ensureValidToken),
  6. // Enable TLS for all incoming connections.
  7. grpc.Creds(credentials.NewServerTLSFromCert(&cert)),
  8. }
  9. s := grpc.NewServer(opts...)

看 grpc.UnaryInterceptor 这个方法,其实是将 ensureValidToken 这个函数赋值给了 s.opts.unaryInt

  1. func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
  2. return newFuncServerOption(func(o *serverOptions) {
  3. if o.unaryInt != nil {
  4. panic("The unary server interceptor was already set and may not be reset.")
  5. }
  6. o.unaryInt = i
  7. })
  8. }

所以之前我们执行的这一行

  1. return interceptor(ctx, in, info, handler)

其实是执行了 ensureValidToken 这个函数,这个函数就是我们在 server 端定义的 token 校验的函数。先取出我们传入的 metadata 数据,然后校验 token

  1. func ensureValidToken(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
  2. md, ok := metadata.FromIncomingContext(ctx)
  3. if !ok {
  4. return nil, errMissingMetadata
  5. }
  6. // The keys within metadata.MD are normalized to lowercase.
  7. // See: https://godoc.org/google.golang.org/grpc/metadata#New
  8. if !valid(md["authorization"]) {
  9. return nil, errInvalidToken
  10. }
  11. // Continue execution of handler after ensuring a valid token.
  12. return handler(ctx, req)
  13. }

校验完 token 后,最终执行了 handler(ctx, req)

  1. handler := func(ctx context.Context, req interface{}) (interface{}, error) {
  2. return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
  3. }
  4. return interceptor(ctx, in, info, handler)

可以看到最终其实执行了 GreeterServer 的 SayHello 这个函数,也就是我们在 main 函数中定义的,这个函数就是我们在 server 端定义的提供 SayHello 给客户端回消息的函数。

  1. // SayHello implements helloworld.GreeterServer
  2. func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
  3. log.Printf("Received: %v", in.Name)
  4. return &pb.HelloReply{Message: "Hello " + in.Name}, nil
  5. }

这里还可以额外说一下,md.Handler 执行完之后,其实 reply 就是 SayHello 的回包。

  1. reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)

获取到回包之后 server 执行了 sendResponse 方法,将回包发送给 client,这个方法我们之前已经剖析过了,最终会调用 http2Server 的 Write 方法。

  1. if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {

看到这里,server 端对 token 的校验在哪里执行的我们已经清楚了。假如还没有被绕晕,那么恭喜你!可以继续完成 client 的挑战了。

client

在 client 中,先看 main 函数

  1. // Set up the credentials for the connection.
  2. perRPC := oauth.NewOauthAccess(fetchToken())
  3. creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "x.test.youtube.com")
  4. if err != nil {
  5. log.Fatalf("failed to load credentials: %v", err)
  6. }
  7. opts := []grpc.DialOption{
  8. // In addition to the following grpc.DialOption, callers may also use
  9. // the grpc.CallOption grpc.PerRPCCredentials with the RPC invocation
  10. // itself.
  11. // See: https://godoc.org/google.golang.org/grpc#PerRPCCredentials
  12. grpc.WithPerRPCCredentials(perRPC),
  13. // oauth.NewOauthAccess requires the configuration of transport
  14. // credentials.
  15. grpc.WithTransportCredentials(creds),
  16. }
  17. conn, err := grpc.Dial(*addr, opts...)

可以看到 client 首先通过 NewOauthAccess 方法生成了包含 token 信息的 PerRPCCredentials 结构

  1. func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
  2. return oauthAccess{token: *token}
  3. }

然后再将 PerRPCCredentials 通过 grpc.WithPerRPCCredentials(perRPC) 添加到了到了 client 的 DialOptions 中的 transport.ConnectOptions 结构中的 [] credentials.PerRPCCredentials 结构中。

那么这个结构什么时候被使用呢,我们来看看。先梳理下调用链 ,在 client 调用的 Invoke ——> invoke ——> newClientStream ——> cs.newAttemptLocked ——> cs.cc.getTransport ——> pick ——> acw.getAddrConn().getReadyTransport() ——> ac.connect() ——> ac.resetTransport() ——> ac.tryAllAddrs ——> ac.createTransport ——> transport.NewClientTransport ——> newHTTP2Client 这个方法里面,有这么一段代码,先取出 []credentials.PerRPCCredentials 中的所有 PerRPCCredentials 添加到 perRPCCreds 中。

  1. transportCreds := opts.TransportCredentials
  2. perRPCCreds := opts.PerRPCCredentials
  3. if b := opts.CredsBundle; b != nil {
  4. if t := b.TransportCredentials(); t != nil {
  5. transportCreds = t
  6. }
  7. if t := b.PerRPCCredentials(); t != nil {
  8. perRPCCreds = append(perRPCCreds, t)
  9. }
  10. }

然后再将 perRPCCreds 赋值给 http2Client 的 perRPCCreds 属性

  1. t := &http2Client{
  2. ...
  3. perRPCCreds: perRPCCreds,
  4. ...
  5. }

那么 perRPCCreds 属性什么时候被用呢?来继续跟踪,newClientStream 方法中有一段代码

  1. op := func(a *csAttempt) error { return a.newStream() }
  2. if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil {
  3. cs.finish(err)
  4. return nil, err
  5. }

这里调用了 csAttempt 的 newStream ——> a.t.NewStream (http2Client 的 NewStream) ——> createHeaderFields ——> getTrAuthData 方法

  1. func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) {
  2. if len(t.perRPCCreds) == 0 {
  3. return nil, nil
  4. }
  5. authData := map[string]string{}
  6. for _, c := range t.perRPCCreds {
  7. data, err := c.GetRequestMetadata(ctx, audience)
  8. if err != nil {
  9. if _, ok := status.FromError(err); ok {
  10. return nil, err
  11. }
  12. return nil, status.Errorf(codes.Unauthenticated, "transport: %v", err)
  13. }
  14. for k, v := range data {
  15. // Capital header names are illegal in HTTP/2.
  16. k = strings.ToLower(k)
  17. authData[k] = v
  18. }
  19. }
  20. return authData, nil
  21. }

这个方法,通过调用 GetRequestMetadata 取出 token 信息,这里会调用 oauth 的 GetRequestMetadata 方法 ,按照指定格式拼装成一个 map[string]string{} 的形式

  1. func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
  2. s.mu.Lock()
  3. defer s.mu.Unlock()
  4. if !s.t.Valid() {
  5. var err error
  6. s.t, err = s.config.TokenSource(ctx).Token()
  7. if err != nil {
  8. return nil, err
  9. }
  10. }
  11. return map[string]string{
  12. "authorization": s.t.Type() + " " + s.t.AccessToken,
  13. }, nil
  14. }

然后将以 map[string]string{} 的形式组装成一个 string map 返回,如下:

  1. for k, v := range authData {
  2. headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
  3. }

返回的 map 会被遍历每个 key,并设置到 headerFields 中,以 http 头部的形式发送出去。数据最终会被 metadata.FromIncomingContext(ctx) 获取到,然后被取出 map 数据。

至此,client 和 server 的数据流转过程被打通