Source file
src/net/rpc/server.go
1
2
3
4
5
129 package rpc
130
131 import (
132 "bufio"
133 "encoding/gob"
134 "errors"
135 "go/token"
136 "io"
137 "log"
138 "net"
139 "net/http"
140 "reflect"
141 "strings"
142 "sync"
143 )
144
145 const (
146
147 DefaultRPCPath = "/_goRPC_"
148 DefaultDebugPath = "/debug/rpc"
149 )
150
151
152 var typeOfError = reflect.TypeFor[error]()
153
// methodType describes one RPC-callable method of a registered service.
// method.Func is invoked with the receiver, argument and reply values;
// ArgType is the method's first argument type (pointer or value), ReplyType
// its second argument type (always a pointer), and numCalls counts
// invocations under the embedded mutex.
type methodType struct {
	sync.Mutex // protects counters
	method     reflect.Method
	ArgType    reflect.Type
	ReplyType  reflect.Type
	numCalls   uint
}
161
162 type service struct {
163 name string
164 rcvr reflect.Value
165 typ reflect.Type
166 method map[string]*methodType
167 }
168
169
170
171
// Request is the header the server reads before each call's argument body.
// ServiceMethod has the form "Service.Method"; Seq is echoed back in the
// matching Response.
type Request struct {
	ServiceMethod string   // format: "Service.Method"
	Seq           uint64   // sequence number copied into the Response
	next          *Request // link for the Server's free list
}
177
178
179
180
// Response is the header the server writes before each reply body. Its
// ServiceMethod and Seq are copied from the corresponding Request; Error is
// non-empty when the call failed.
type Response struct {
	ServiceMethod string    // echoes that of the Request
	Seq           uint64    // echoes that of the Request
	Error         string    // error message, if any
	next          *Response // link for the Server's free list
}
187
188
189 type Server struct {
190 serviceMap sync.Map
191 reqLock sync.Mutex
192 freeReq *Request
193 respLock sync.Mutex
194 freeResp *Response
195 }
196
197
198 func NewServer() *Server {
199 return &Server{}
200 }
201
202
203 var DefaultServer = NewServer()
204
205
// isExportedOrBuiltinType reports whether t, after removing every level of
// pointer indirection, either has no package path (predeclared types such as
// int, and unnamed composite types) or is a named type whose name is exported.
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Pointer {
		base = base.Elem()
	}
	if base.PkgPath() == "" {
		// Predeclared or unnamed: no package can hide it from callers.
		return true
	}
	return token.IsExported(base.Name())
}
214
215
216
217
218
219
220
221
222
223
224
225
226 func (server *Server) Register(rcvr any) error {
227 return server.register(rcvr, "", false)
228 }
229
230
231
232 func (server *Server) RegisterName(name string, rcvr any) error {
233 return server.register(rcvr, name, true)
234 }
235
236
237
// logRegisterError is passed to suitableMethods by register and controls
// whether the reason for rejecting each candidate method is logged.
const logRegisterError = false
239
240 func (server *Server) register(rcvr any, name string, useName bool) error {
241 s := new(service)
242 s.typ = reflect.TypeOf(rcvr)
243 s.rcvr = reflect.ValueOf(rcvr)
244 sname := name
245 if !useName {
246 sname = reflect.Indirect(s.rcvr).Type().Name()
247 }
248 if sname == "" {
249 s := "rpc.Register: no service name for type " + s.typ.String()
250 log.Print(s)
251 return errors.New(s)
252 }
253 if !useName && !token.IsExported(sname) {
254 s := "rpc.Register: type " + sname + " is not exported"
255 log.Print(s)
256 return errors.New(s)
257 }
258 s.name = sname
259
260
261 s.method = suitableMethods(s.typ, logRegisterError)
262
263 if len(s.method) == 0 {
264 str := ""
265
266
267 method := suitableMethods(reflect.PointerTo(s.typ), false)
268 if len(method) != 0 {
269 str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
270 } else {
271 str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
272 }
273 log.Print(str)
274 return errors.New(str)
275 }
276
277 if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
278 return errors.New("rpc: service already defined: " + sname)
279 }
280 return nil
281 }
282
283
284
285 func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
286 methods := make(map[string]*methodType)
287 for m := 0; m < typ.NumMethod(); m++ {
288 method := typ.Method(m)
289 mtype := method.Type
290 mname := method.Name
291
292 if !method.IsExported() {
293 continue
294 }
295
296 if mtype.NumIn() != 3 {
297 if logErr {
298 log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
299 }
300 continue
301 }
302
303 argType := mtype.In(1)
304 if !isExportedOrBuiltinType(argType) {
305 if logErr {
306 log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
307 }
308 continue
309 }
310
311 replyType := mtype.In(2)
312 if replyType.Kind() != reflect.Pointer {
313 if logErr {
314 log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
315 }
316 continue
317 }
318
319 if !isExportedOrBuiltinType(replyType) {
320 if logErr {
321 log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
322 }
323 continue
324 }
325
326 if mtype.NumOut() != 1 {
327 if logErr {
328 log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
329 }
330 continue
331 }
332
333 if returnType := mtype.Out(0); returnType != typeOfError {
334 if logErr {
335 log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
336 }
337 continue
338 }
339 methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
340 }
341 return methods
342 }
343
344
345
346
// invalidRequest is the empty placeholder body sent instead of a reply
// whenever a response carries an error.
var invalidRequest struct{}
348
349 func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
350 resp := server.getResponse()
351
352 resp.ServiceMethod = req.ServiceMethod
353 if errmsg != "" {
354 resp.Error = errmsg
355 reply = invalidRequest
356 }
357 resp.Seq = req.Seq
358 sending.Lock()
359 err := codec.WriteResponse(resp, reply)
360 if debugLog && err != nil {
361 log.Println("rpc: writing response:", err)
362 }
363 sending.Unlock()
364 server.freeResponse(resp)
365 }
366
367 func (m *methodType) NumCalls() (n uint) {
368 m.Lock()
369 n = m.numCalls
370 m.Unlock()
371 return n
372 }
373
374 func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
375 if wg != nil {
376 defer wg.Done()
377 }
378 mtype.Lock()
379 mtype.numCalls++
380 mtype.Unlock()
381 function := mtype.method.Func
382
383 returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
384
385 errInter := returnValues[0].Interface()
386 errmsg := ""
387 if errInter != nil {
388 errmsg = errInter.(error).Error()
389 }
390 server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
391 server.freeRequest(req)
392 }
393
// gobServerCodec is the ServerCodec used by ServeConn: it decodes requests
// from the connection and gob-encodes responses through a buffered writer.
type gobServerCodec struct {
	rwc    io.ReadWriteCloser // underlying connection
	dec    *gob.Decoder       // decodes requests read from rwc
	enc    *gob.Encoder       // encodes responses into encBuf
	encBuf *bufio.Writer      // buffers writes to rwc; flushed per response
	closed bool               // set by Close so rwc is closed only once
}
401
402 func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
403 return c.dec.Decode(r)
404 }
405
406 func (c *gobServerCodec) ReadRequestBody(body any) error {
407 return c.dec.Decode(body)
408 }
409
410 func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
411 if err = c.enc.Encode(r); err != nil {
412 if c.encBuf.Flush() == nil {
413
414
415 log.Println("rpc: gob error encoding response:", err)
416 c.Close()
417 }
418 return
419 }
420 if err = c.enc.Encode(body); err != nil {
421 if c.encBuf.Flush() == nil {
422
423
424 log.Println("rpc: gob error encoding body:", err)
425 c.Close()
426 }
427 return
428 }
429 return c.encBuf.Flush()
430 }
431
432 func (c *gobServerCodec) Close() error {
433 if c.closed {
434
435 return nil
436 }
437 c.closed = true
438 return c.rwc.Close()
439 }
440
441
442
443
444
445
446
447 func (server *Server) ServeConn(conn io.ReadWriteCloser) {
448 buf := bufio.NewWriter(conn)
449 srv := &gobServerCodec{
450 rwc: conn,
451 dec: gob.NewDecoder(conn),
452 enc: gob.NewEncoder(buf),
453 encBuf: buf,
454 }
455 server.ServeCodec(srv)
456 }
457
458
459
460 func (server *Server) ServeCodec(codec ServerCodec) {
461 sending := new(sync.Mutex)
462 wg := new(sync.WaitGroup)
463 for {
464 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
465 if err != nil {
466 if debugLog && err != io.EOF {
467 log.Println("rpc:", err)
468 }
469 if !keepReading {
470 break
471 }
472
473 if req != nil {
474 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
475 server.freeRequest(req)
476 }
477 continue
478 }
479 wg.Add(1)
480 go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
481 }
482
483
484 wg.Wait()
485 codec.Close()
486 }
487
488
489
490 func (server *Server) ServeRequest(codec ServerCodec) error {
491 sending := new(sync.Mutex)
492 service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
493 if err != nil {
494 if !keepReading {
495 return err
496 }
497
498 if req != nil {
499 server.sendResponse(sending, req, invalidRequest, codec, err.Error())
500 server.freeRequest(req)
501 }
502 return err
503 }
504 service.call(server, sending, nil, mtype, req, argv, replyv, codec)
505 return nil
506 }
507
508 func (server *Server) getRequest() *Request {
509 server.reqLock.Lock()
510 req := server.freeReq
511 if req == nil {
512 req = new(Request)
513 } else {
514 server.freeReq = req.next
515 *req = Request{}
516 }
517 server.reqLock.Unlock()
518 return req
519 }
520
521 func (server *Server) freeRequest(req *Request) {
522 server.reqLock.Lock()
523 req.next = server.freeReq
524 server.freeReq = req
525 server.reqLock.Unlock()
526 }
527
528 func (server *Server) getResponse() *Response {
529 server.respLock.Lock()
530 resp := server.freeResp
531 if resp == nil {
532 resp = new(Response)
533 } else {
534 server.freeResp = resp.next
535 *resp = Response{}
536 }
537 server.respLock.Unlock()
538 return resp
539 }
540
541 func (server *Server) freeResponse(resp *Response) {
542 server.respLock.Lock()
543 resp.next = server.freeResp
544 server.freeResp = resp
545 server.respLock.Unlock()
546 }
547
548 func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
549 service, mtype, req, keepReading, err = server.readRequestHeader(codec)
550 if err != nil {
551 if !keepReading {
552 return
553 }
554
555 codec.ReadRequestBody(nil)
556 return
557 }
558
559
560 argIsValue := false
561 if mtype.ArgType.Kind() == reflect.Pointer {
562 argv = reflect.New(mtype.ArgType.Elem())
563 } else {
564 argv = reflect.New(mtype.ArgType)
565 argIsValue = true
566 }
567
568 if err = codec.ReadRequestBody(argv.Interface()); err != nil {
569 return
570 }
571 if argIsValue {
572 argv = argv.Elem()
573 }
574
575 replyv = reflect.New(mtype.ReplyType.Elem())
576
577 switch mtype.ReplyType.Elem().Kind() {
578 case reflect.Map:
579 replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
580 case reflect.Slice:
581 replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
582 }
583 return
584 }
585
586 func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
587
588 req = server.getRequest()
589 err = codec.ReadRequestHeader(req)
590 if err != nil {
591 req = nil
592 if err == io.EOF || err == io.ErrUnexpectedEOF {
593 return
594 }
595 err = errors.New("rpc: server cannot decode request: " + err.Error())
596 return
597 }
598
599
600
601 keepReading = true
602
603 dot := strings.LastIndex(req.ServiceMethod, ".")
604 if dot < 0 {
605 err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
606 return
607 }
608 serviceName := req.ServiceMethod[:dot]
609 methodName := req.ServiceMethod[dot+1:]
610
611
612 svci, ok := server.serviceMap.Load(serviceName)
613 if !ok {
614 err = errors.New("rpc: can't find service " + req.ServiceMethod)
615 return
616 }
617 svc = svci.(*service)
618 mtype = svc.method[methodName]
619 if mtype == nil {
620 err = errors.New("rpc: can't find method " + req.ServiceMethod)
621 }
622 return
623 }
624
625
626
627
628
629 func (server *Server) Accept(lis net.Listener) {
630 for {
631 conn, err := lis.Accept()
632 if err != nil {
633 log.Print("rpc.Serve: accept:", err.Error())
634 return
635 }
636 go server.ServeConn(conn)
637 }
638 }
639
640
641 func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
642
643
644
645 func RegisterName(name string, rcvr any) error {
646 return DefaultServer.RegisterName(name, rcvr)
647 }
648
649
650
651
652
653
654
655
656
657 type ServerCodec interface {
658 ReadRequestHeader(*Request) error
659 ReadRequestBody(any) error
660 WriteResponse(*Response, any) error
661
662
663 Close() error
664 }
665
666
667
668
669
670
671
672 func ServeConn(conn io.ReadWriteCloser) {
673 DefaultServer.ServeConn(conn)
674 }
675
676
677
678 func ServeCodec(codec ServerCodec) {
679 DefaultServer.ServeCodec(codec)
680 }
681
682
683
684 func ServeRequest(codec ServerCodec) error {
685 return DefaultServer.ServeRequest(codec)
686 }
687
688
689
690
691 func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
692
693
// connected is the status text ServeHTTP writes, after "HTTP/1.0 ", once it
// has hijacked a CONNECT request's connection.
var connected = "200 Connected to Go RPC"
695
696
697 func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
698 if req.Method != "CONNECT" {
699 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
700 w.WriteHeader(http.StatusMethodNotAllowed)
701 io.WriteString(w, "405 must CONNECT\n")
702 return
703 }
704 conn, _, err := w.(http.Hijacker).Hijack()
705 if err != nil {
706 log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
707 return
708 }
709 io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
710 server.ServeConn(conn)
711 }
712
713
714
715
716 func (server *Server) HandleHTTP(rpcPath, debugPath string) {
717 http.Handle(rpcPath, server)
718 http.Handle(debugPath, debugHTTP{server})
719 }
720
721
722
723
724 func HandleHTTP() {
725 DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
726 }
727
View as plain text