1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "log"
15 "mime"
16 "net"
17 "net/http"
18 "net/http/httptrace"
19 "net/http/internal/ascii"
20 "net/textproto"
21 "net/url"
22 "strings"
23 "sync"
24 "time"
25
26 "golang.org/x/net/http/httpguts"
27 )
28
29
// A ProxyRequest pairs the request a ReverseProxy received with the request
// it is about to send. ServeHTTP hands one to the Rewrite hook.
type ProxyRequest struct {
	// In is the request received by the proxy. ServeHTTP stores its own
	// incoming request here.
	In *http.Request

	// Out is the request the proxy will send. ServeHTTP derives it from a
	// clone of In and, once Rewrite returns, forwards whatever request this
	// field then holds.
	Out *http.Request
}
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56 func (r *ProxyRequest) SetURL(target *url.URL) {
57 rewriteRequestURL(r.Out, target)
58 r.Out.Host = ""
59 }
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80 func (r *ProxyRequest) SetXForwarded() {
81 clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
82 if err == nil {
83 prior := r.Out.Header["X-Forwarded-For"]
84 if len(prior) > 0 {
85 clientIP = strings.Join(prior, ", ") + ", " + clientIP
86 }
87 r.Out.Header.Set("X-Forwarded-For", clientIP)
88 } else {
89 r.Out.Header.Del("X-Forwarded-For")
90 }
91 r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
92 if r.In.TLS == nil {
93 r.Out.Header.Set("X-Forwarded-Proto", "http")
94 } else {
95 r.Out.Header.Set("X-Forwarded-Proto", "https")
96 }
97 }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// ReverseProxy is an HTTP handler that forwards each incoming request to
// another server through Transport and relays the response back to the
// client.
//
// Exactly one of Rewrite or Director must be set; otherwise ServeHTTP reports
// an error to the error handler without contacting a backend.
type ReverseProxy struct {
	// Rewrite, if set, is called with the inbound request and the outbound
	// request so that it can modify the latter.
	//
	// Before the call, ServeHTTP removes hop-by-hop headers from the outbound
	// request, deletes its Forwarded, X-Forwarded-For, X-Forwarded-Host and
	// X-Forwarded-Proto headers, and passes its raw query through
	// cleanQueryParams. ServeHTTP does not add X-Forwarded-For itself in this
	// mode; Rewrite can call ProxyRequest.SetXForwarded.
	Rewrite func(*ProxyRequest)

	// Director, if set, is called with the outbound request so that it can
	// modify it.
	//
	// Director runs before hop-by-hop headers are removed from the outbound
	// request. After it returns, the raw query is passed through
	// cleanQueryParams only when the outbound request's Form is non-nil.
	// ServeHTTP then appends the client IP to X-Forwarded-For, unless the
	// header map holds an explicit nil entry for that key.
	Director func(*http.Request)

	// Transport performs the outbound round trip. If nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval controls flushing to the client while the response body
	// is copied. Zero means the copy writes straight to the ResponseWriter
	// with no explicit flushes; a negative value flushes after every write;
	// a positive value schedules a flush that long after a write.
	//
	// The value is overridden with -1 (flush after every write) when the
	// response media type is text/event-stream or its ContentLength is -1.
	FlushInterval time.Duration

	// ErrorLog receives messages logged by the proxy. If nil, the standard
	// log package's Printf is used.
	ErrorLog *log.Logger

	// BufferPool optionally supplies the byte slice used while copying the
	// response body. If nil, or if Get returns an empty slice, a 32 KiB
	// buffer is allocated for the copy.
	BufferPool BufferPool

	// ModifyResponse, if set, is called with the backend's response before
	// it is relayed. For responses other than 101 Switching Protocols,
	// hop-by-hop headers have already been removed from it.
	//
	// If it returns an error, the response body is closed and the error
	// handler is called with that error. It is not called when the round
	// trip itself fails.
	ModifyResponse func(*http.Response) error

	// ErrorHandler, if set, handles errors that occur before a response has
	// been relayed: misconfiguration, round-trip failures, ModifyResponse
	// errors and protocol-upgrade failures. If nil, the error is logged and
	// a 502 Bad Gateway status is written.
	ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
214
215
216
// A BufferPool supplies reusable byte slices. copyResponse calls Get for a
// scratch buffer before copying a response body and hands the same slice back
// through Put when the copy is finished.
type BufferPool interface {
	Get() []byte
	Put([]byte)
}
221
// singleJoiningSlash concatenates a and b so that exactly one "/" sits at
// the seam: a duplicate slash is dropped, a missing one is inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		// Both sides bring a slash; keep only a's.
		return a + b[1:]
	}
	if trailing || leading {
		return a + b
	}
	return a + "/" + b
}
233
234 func joinURLPath(a, b *url.URL) (path, rawpath string) {
235 if a.RawPath == "" && b.RawPath == "" {
236 return singleJoiningSlash(a.Path, b.Path), ""
237 }
238
239
240 apath := a.EscapedPath()
241 bpath := b.EscapedPath()
242
243 aslash := strings.HasSuffix(apath, "/")
244 bslash := strings.HasPrefix(bpath, "/")
245
246 switch {
247 case aslash && bslash:
248 return a.Path + b.Path[1:], apath + bpath[1:]
249 case !aslash && !bslash:
250 return a.Path + "/" + b.Path, apath + "/" + bpath
251 }
252 return a.Path + b.Path, apath + bpath
253 }
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
276 director := func(req *http.Request) {
277 rewriteRequestURL(req, target)
278 }
279 return &ReverseProxy{Director: director}
280 }
281
282 func rewriteRequestURL(req *http.Request, target *url.URL) {
283 targetQuery := target.RawQuery
284 req.URL.Scheme = target.Scheme
285 req.URL.Host = target.Host
286 req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
287 if targetQuery == "" || req.URL.RawQuery == "" {
288 req.URL.RawQuery = targetQuery + req.URL.RawQuery
289 } else {
290 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
291 }
292 }
293
// copyHeader appends every value in src to dst under the same header name.
// Values already in dst are kept; names are canonicalized by Header.Add.
// A name whose value list is empty adds nothing to dst.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
301
302
303
304
305
306
// hopHeaders lists the header names that removeHopByHopHeaders deletes
// unconditionally, in addition to any names listed in the message's own
// Connection header.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
318
319 func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
320 p.logf("http: proxy error: %v", err)
321 rw.WriteHeader(http.StatusBadGateway)
322 }
323
324 func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
325 if p.ErrorHandler != nil {
326 return p.ErrorHandler
327 }
328 return p.defaultErrorHandler
329 }
330
331
332
333 func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
334 if p.ModifyResponse == nil {
335 return true
336 }
337 if err := p.ModifyResponse(res); err != nil {
338 res.Body.Close()
339 p.getErrorHandler()(rw, req, err)
340 return false
341 }
342 return true
343 }
344
// ServeHTTP implements http.Handler. It builds an outbound copy of req, lets
// Director or Rewrite adjust it, sends it through the configured transport,
// and relays the backend's response (headers, body and trailers) to rw.
//
// Errors that occur before a response has been relayed are passed to the
// handler returned by getErrorHandler. An error while copying the response
// body either panics with http.ErrAbortHandler or is only logged, depending
// on shouldPanicOnCopyError.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	if ctx.Done() != nil {
		// The request context can already signal cancellation, so there is
		// no need to also watch the ResponseWriter's CloseNotify channel.
	} else if cn, ok := rw.(http.CloseNotifier); ok {
		// The context can never be cancelled on its own: derive a
		// cancellable one and cancel it when the client connection goes
		// away. The goroutine exits on either event; the deferred cancel
		// guarantees ctx.Done() fires when ServeHTTP returns.
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		// A zero ContentLength is taken to mean there is no body to
		// forward, so the outbound request is sent without one.
		outreq.Body = nil
	}
	if outreq.Body != nil {
		// Close the outbound body when ServeHTTP returns, whatever path is
		// taken below.
		defer outreq.Body.Close()
	}
	if outreq.Header == nil {
		// Guarantee a non-nil header map for the hooks and the Set calls
		// that follow.
		outreq.Header = make(http.Header)
	}

	// Exactly one of Director and Rewrite must be configured.
	if (p.Director != nil) == (p.Rewrite != nil) {
		p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
		return
	}

	if p.Director != nil {
		p.Director(outreq)
		if outreq.Form != nil {
			// The query is cleaned only when Form has been populated on the
			// outbound request.
			outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
		}
	}
	outreq.Close = false

	// Capture the requested upgrade protocol before the hop-by-hop headers
	// that carry it are stripped.
	reqUpType := upgradeType(outreq.Header)
	if !ascii.IsPrint(reqUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
		return
	}
	removeHopByHopHeaders(outreq.Header)

	// "Te" was removed with the other hop-by-hop headers. If the client's
	// original request listed the "trailers" token, restore exactly that
	// value on the outbound request.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}

	// Likewise, put back the Connection and Upgrade headers needed for a
	// protocol upgrade.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}

	if p.Rewrite != nil {
		// Strip forwarding headers supplied by the client; Rewrite decides
		// whether to set them again.
		outreq.Header.Del("Forwarded")
		outreq.Header.Del("X-Forwarded-For")
		outreq.Header.Del("X-Forwarded-Host")
		outreq.Header.Del("X-Forwarded-Proto")

		// Clean the outbound query before Rewrite sees it.
		outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)

		pr := &ProxyRequest{
			In:  req,
			Out: outreq,
		}
		p.Rewrite(pr)
		// Rewrite may have replaced Out with a different request.
		outreq = pr.Out
	} else {
		if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			// Append the client IP to any existing X-Forwarded-For values,
			// folded into one comma-separated list. If the key is present
			// with a nil value, the header is left out entirely.
			prior, ok := outreq.Header["X-Forwarded-For"]
			omit := ok && prior == nil
			if len(prior) > 0 {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
			if !omit {
				outreq.Header.Set("X-Forwarded-For", clientIP)
			}
		}
	}

	if _, ok := outreq.Header["User-Agent"]; !ok {
		// The client sent no User-Agent. Set it explicitly to the empty
		// string, presumably so the transport does not add a default of
		// its own.
		outreq.Header.Set("User-Agent", "")
	}

	// roundTripDone is set once RoundTrip has returned. The mutex guards it
	// and serializes the 1xx callback with that transition.
	var (
		roundTripMutex sync.Mutex
		roundTripDone  bool
	)
	trace := &httptrace.ClientTrace{
		// Relay informational (1xx) responses from the backend to the
		// client as they arrive.
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			roundTripMutex.Lock()
			defer roundTripMutex.Unlock()
			if roundTripDone {
				// RoundTrip has already returned, so the ResponseWriter
				// must no longer be written to from this callback.
				return nil
			}
			h := rw.Header()
			copyHeader(h, http.Header(header))
			rw.WriteHeader(code)

			// Empty the header map so the 1xx headers are not carried over
			// into the final response.
			clear(h)
			return nil
		},
	}
	outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

	res, err := transport.RoundTrip(outreq)
	roundTripMutex.Lock()
	roundTripDone = true
	roundTripMutex.Unlock()
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}

	// A 101 Switching Protocols response is handed to the upgrade path after
	// ModifyResponse has run; its hop-by-hop headers are not stripped.
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}

	removeHopByHopHeaders(res.Header)

	if !p.modifyResponse(rw, res, outreq) {
		return
	}

	copyHeader(rw.Header(), res.Header)

	// Announce the trailer names known at this point in a "Trailer" header,
	// and remember how many there were for the comparison after the body
	// has been read.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)

	err = p.copyResponse(rw, res.Body, p.flushInterval(res))
	if err != nil {
		defer res.Body.Close()
		// The status and part of the body may already have been sent, so
		// the response can only be aborted. shouldPanicOnCopyError decides
		// between panicking with http.ErrAbortHandler and logging.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	// Closed here rather than deferred: res.Trailer is read below and is
	// expected to be complete only after the body has been consumed and
	// closed.
	res.Body.Close()

	if len(res.Trailer) > 0 {
		// Flush before adding trailers. Presumably this keeps the server
		// from sending a short body with a Content-Length, which would
		// leave no room for trailers.
		http.NewResponseController(rw).Flush()
	}

	if len(res.Trailer) == announcedTrailers {
		// Every trailer was announced in the "Trailer" header, so the
		// values can be set under their plain names.
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Trailer keys appeared that were not announced up front. Send all
	// trailers using the http.TrailerPrefix key convention instead.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
566
// inOurTests, when true, makes shouldPanicOnCopyError report true for every
// request. Nothing in this file sets it; presumably the package's own tests
// do.
var inOurTests bool
568
569
570
571
572
573
574 func shouldPanicOnCopyError(req *http.Request) bool {
575 if inOurTests {
576
577 return true
578 }
579 if req.Context().Value(http.ServerContextKey) != nil {
580
581
582 return true
583 }
584
585
586 return false
587 }
588
589
590 func removeHopByHopHeaders(h http.Header) {
591
592 for _, f := range h["Connection"] {
593 for sf := range strings.SplitSeq(f, ",") {
594 if sf = textproto.TrimString(sf); sf != "" {
595 h.Del(sf)
596 }
597 }
598 }
599
600
601
602 for _, f := range hopHeaders {
603 h.Del(f)
604 }
605 }
606
607
608
609 func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
610 resCT := res.Header.Get("Content-Type")
611
612
613
614 if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
615 return -1
616 }
617
618
619 if res.ContentLength == -1 {
620 return -1
621 }
622
623 return p.FlushInterval
624 }
625
626 func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
627 var w io.Writer = dst
628
629 if flushInterval != 0 {
630 mlw := &maxLatencyWriter{
631 dst: dst,
632 flush: http.NewResponseController(dst).Flush,
633 latency: flushInterval,
634 }
635 defer mlw.stop()
636
637
638 mlw.flushPending = true
639 mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
640
641 w = mlw
642 }
643
644 var buf []byte
645 if p.BufferPool != nil {
646 buf = p.BufferPool.Get()
647 defer p.BufferPool.Put(buf)
648 }
649 _, err := p.copyBuffer(w, src, buf)
650 return err
651 }
652
653
654
655 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
656 if len(buf) == 0 {
657 buf = make([]byte, 32*1024)
658 }
659 var written int64
660 for {
661 nr, rerr := src.Read(buf)
662 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
663 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
664 }
665 if nr > 0 {
666 nw, werr := dst.Write(buf[:nr])
667 if nw > 0 {
668 written += int64(nw)
669 }
670 if werr != nil {
671 return written, werr
672 }
673 if nr != nw {
674 return written, io.ErrShortWrite
675 }
676 }
677 if rerr != nil {
678 if rerr == io.EOF {
679 rerr = nil
680 }
681 return written, rerr
682 }
683 }
684 }
685
686 func (p *ReverseProxy) logf(format string, args ...any) {
687 if p.ErrorLog != nil {
688 p.ErrorLog.Printf(format, args...)
689 } else {
690 log.Printf(format, args...)
691 }
692 }
693
// maxLatencyWriter wraps an io.Writer and calls flush so that written data
// does not sit unflushed for longer than latency. A negative latency means
// flush after every write.
type maxLatencyWriter struct {
	dst     io.Writer     // destination of every Write
	flush   func() error  // flushes dst; its error is ignored
	latency time.Duration // negative: flush on each write; otherwise the delay before a flush

	mu           sync.Mutex  // guards t and flushPending, and serializes writes with flushes
	t            *time.Timer // timer that runs delayedFlush; nil until first armed
	flushPending bool        // whether a timed flush is scheduled and still wanted
}

// Write writes p to the underlying writer and returns that writer's result.
// With a negative latency it flushes right away. Otherwise, unless a flush
// is already pending, it arms the timer to flush after latency.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	n, err = m.dst.Write(p)
	switch {
	case m.latency < 0:
		m.flush()
	case m.flushPending:
		// The flush already scheduled will cover this write too.
	default:
		if m.t != nil {
			m.t.Reset(m.latency)
		} else {
			m.t = time.AfterFunc(m.latency, m.delayedFlush)
		}
		m.flushPending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes only if a flush is still
// pending, which makes it a no-op when stop has cleared the flag after the
// timer had already fired.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.flushPending {
		m.flush()
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer if one was created.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.flushPending = false
	if timer := m.t; timer != nil {
		timer.Stop()
	}
}
742
743 func upgradeType(h http.Header) string {
744 if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
745 return ""
746 }
747 return h.Get("Upgrade")
748 }
749
// handleUpgradeResponse completes a protocol upgrade after the backend has
// answered with 101 Switching Protocols. It checks that the backend switched
// to the protocol the request asked for, hijacks the client connection,
// writes the backend's response head to it, and then copies bytes in both
// directions between client and backend.
//
// Every failure is reported through the error handler from getErrorHandler.
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
	reqUpType := upgradeType(req.Header)
	resUpType := upgradeType(res.Header)
	if !ascii.IsPrint(resUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
		return
	}
	// The protocol names are compared case-insensitively.
	if !ascii.EqualFold(reqUpType, resUpType) {
		p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
		return
	}

	// The response body must double as the writable backend connection.
	backConn, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
		return
	}

	rc := http.NewResponseController(rw)
	conn, brw, hijackErr := rc.Hijack()
	if errors.Is(hijackErr, http.ErrNotSupported) {
		p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
		return
	}

	backConnCloseCh := make(chan bool)
	go func() {
		// Close the backend connection when the request is cancelled or
		// when this function returns, whichever happens first.
		select {
		case <-req.Context().Done():
		case <-backConnCloseCh:
		}
		backConn.Close()
	}()
	// Registered before the remaining Hijack error check so that the
	// backend connection is also closed on that failure path.
	defer close(backConnCloseCh)

	if hijackErr != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
		return
	}
	defer conn.Close()

	copyHeader(rw.Header(), res.Header)

	// Write the response head to the hijacked connection using the merged
	// headers. The body is set to nil so that res.Write emits no body; the
	// backend stream is still reachable through backConn.
	res.Header = rw.Header()
	res.Body = nil
	if err := res.Write(brw); err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
		return
	}
	if err := brw.Flush(); err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
		return
	}
	// Buffered by one: if this function returns after the first result, the
	// second copier can still send without blocking.
	errc := make(chan error, 1)
	spc := switchProtocolCopier{user: conn, backend: backConn}
	go spc.copyToBackend(errc)
	go spc.copyFromBackend(errc)

	// Wait for the first copier to finish. Only when it reports nil (it
	// copied to EOF and CloseWrite succeeded) wait for the other one too;
	// any non-nil result, including errCopyDone, ends the session and the
	// deferred closes tear down both connections.
	err := <-errc
	if err == nil {
		err = <-errc
	}
}
817
// errCopyDone is sent by a switchProtocolCopier direction that copied to EOF
// but could not half-close its destination.
var errCopyDone = errors.New("hijacked connection copy complete")

// switchProtocolCopier copies data between the client connection (user) and
// the backend connection after a protocol switch.
type switchProtocolCopier struct {
	user, backend io.ReadWriter
}

// copyFromBackend copies backend to user and sends the outcome on errc.
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
	errc <- relayOneWay(c.user, c.backend)
}

// copyToBackend copies user to backend and sends the outcome on errc.
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
	errc <- relayOneWay(c.backend, c.user)
}

// relayOneWay copies src to dst until EOF and returns the outcome: the copy
// error if there was one; otherwise the result of dst.CloseWrite when dst
// has that method; otherwise errCopyDone.
func relayOneWay(dst, src io.ReadWriter) error {
	if _, err := io.Copy(dst, src); err != nil {
		return err
	}
	// src is exhausted: half-close dst if it supports it, so the peer sees
	// EOF while the opposite direction can keep flowing.
	if half, ok := dst.(interface{ CloseWrite() error }); ok {
		return half.CloseWrite()
	}
	return errCopyDone
}
855
856 func cleanQueryParams(s string) string {
857 reencode := func(s string) string {
858 v, _ := url.ParseQuery(s)
859 return v.Encode()
860 }
861 for i := 0; i < len(s); {
862 switch s[i] {
863 case ';':
864 return reencode(s)
865 case '%':
866 if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
867 return reencode(s)
868 }
869 i += 3
870 default:
871 i++
872 }
873 }
874 return s
875 }
876
// ishex reports whether c is an ASCII hexadecimal digit (0-9, a-f or A-F).
func ishex(c byte) bool {
	return ('0' <= c && c <= '9') ||
		('a' <= c && c <= 'f') ||
		('A' <= c && c <= 'F')
}
888
View as plain text