1
2
3
4
5
6
7 package httputil
8
9 import (
10 "bufio"
11 "bytes"
12 "context"
13 "errors"
14 "fmt"
15 "io"
16 "log"
17 "net"
18 "net/http"
19 "net/http/httptest"
20 "net/http/httptrace"
21 "net/http/internal/ascii"
22 "net/textproto"
23 "net/url"
24 "os"
25 "reflect"
26 "runtime"
27 "slices"
28 "strconv"
29 "strings"
30 "sync"
31 "testing"
32 "time"
33 )
34
35 const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
36
37 func init() {
38 inOurTests = true
39 hopHeaders = append(hopHeaders, fakeHopHeader)
40 }
41
42 func TestReverseProxy(t *testing.T) {
43 const backendResponse = "I am the backend"
44 const backendStatus = 404
45 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46 if r.Method == "GET" && r.FormValue("mode") == "hangup" {
47 c, _, _ := w.(http.Hijacker).Hijack()
48 c.Close()
49 return
50 }
51 if len(r.TransferEncoding) > 0 {
52 t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
53 }
54 if r.Header.Get("X-Forwarded-For") == "" {
55 t.Errorf("didn't get X-Forwarded-For header")
56 }
57 if c := r.Header.Get("Connection"); c != "" {
58 t.Errorf("handler got Connection header value %q", c)
59 }
60 if c := r.Header.Get("Te"); c != "trailers" {
61 t.Errorf("handler got Te header value %q; want 'trailers'", c)
62 }
63 if c := r.Header.Get("Upgrade"); c != "" {
64 t.Errorf("handler got Upgrade header value %q", c)
65 }
66 if c := r.Header.Get("Proxy-Connection"); c != "" {
67 t.Errorf("handler got Proxy-Connection header value %q", c)
68 }
69 if g, e := r.Host, "some-name"; g != e {
70 t.Errorf("backend got Host header %q, want %q", g, e)
71 }
72 w.Header().Set("Trailers", "not a special header field name")
73 w.Header().Set("Trailer", "X-Trailer")
74 w.Header().Set("X-Foo", "bar")
75 w.Header().Set("Upgrade", "foo")
76 w.Header().Set(fakeHopHeader, "foo")
77 w.Header().Add("X-Multi-Value", "foo")
78 w.Header().Add("X-Multi-Value", "bar")
79 http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
80 w.WriteHeader(backendStatus)
81 w.Write([]byte(backendResponse))
82 w.Header().Set("X-Trailer", "trailer_value")
83 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
84 }))
85 defer backend.Close()
86 backendURL, err := url.Parse(backend.URL)
87 if err != nil {
88 t.Fatal(err)
89 }
90 proxyHandler := NewSingleHostReverseProxy(backendURL)
91 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
92 frontend := httptest.NewServer(proxyHandler)
93 defer frontend.Close()
94 frontendClient := frontend.Client()
95
96 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
97 getReq.Host = "some-name"
98 getReq.Header.Set("Connection", "close, TE")
99 getReq.Header.Add("Te", "foo")
100 getReq.Header.Add("Te", "bar, trailers")
101 getReq.Header.Set("Proxy-Connection", "should be deleted")
102 getReq.Header.Set("Upgrade", "foo")
103 getReq.Close = true
104 res, err := frontendClient.Do(getReq)
105 if err != nil {
106 t.Fatalf("Get: %v", err)
107 }
108 if g, e := res.StatusCode, backendStatus; g != e {
109 t.Errorf("got res.StatusCode %d; expected %d", g, e)
110 }
111 if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
112 t.Errorf("got X-Foo %q; expected %q", g, e)
113 }
114 if c := res.Header.Get(fakeHopHeader); c != "" {
115 t.Errorf("got %s header value %q", fakeHopHeader, c)
116 }
117 if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
118 t.Errorf("header Trailers = %q; want %q", g, e)
119 }
120 if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
121 t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
122 }
123 if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
124 t.Fatalf("got %d SetCookies, want %d", g, e)
125 }
126 if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
127 t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
128 }
129 if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
130 t.Errorf("unexpected cookie %q", cookie.Name)
131 }
132 bodyBytes, _ := io.ReadAll(res.Body)
133 if g, e := string(bodyBytes), backendResponse; g != e {
134 t.Errorf("got body %q; expected %q", g, e)
135 }
136 if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
137 t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
138 }
139 if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
140 t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
141 }
142 res.Body.Close()
143
144
145
146 getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
147 getReq.Close = true
148 res, err = frontendClient.Do(getReq)
149 if err != nil {
150 t.Fatal(err)
151 }
152 res.Body.Close()
153 if res.StatusCode != http.StatusBadGateway {
154 t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
155 }
156
157 }
158
159
160
161 func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
162 const fakeConnectionToken = "X-Fake-Connection-Token"
163 const backendResponse = "I am the backend"
164
165
166
167 const someConnHeader = "X-Some-Conn-Header"
168
169 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
170 if c := r.Header.Get("Connection"); c != "" {
171 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
172 }
173 if c := r.Header.Get(fakeConnectionToken); c != "" {
174 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
175 }
176 if c := r.Header.Get(someConnHeader); c != "" {
177 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
178 }
179 w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
180 w.Header().Add("Connection", someConnHeader)
181 w.Header().Set(someConnHeader, "should be deleted")
182 w.Header().Set(fakeConnectionToken, "should be deleted")
183 io.WriteString(w, backendResponse)
184 }))
185 defer backend.Close()
186 backendURL, err := url.Parse(backend.URL)
187 if err != nil {
188 t.Fatal(err)
189 }
190 proxyHandler := NewSingleHostReverseProxy(backendURL)
191 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
192 proxyHandler.ServeHTTP(w, r)
193 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
194 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
195 }
196 if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
197 t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
198 }
199 c := r.Header["Connection"]
200 var cf []string
201 for _, f := range c {
202 for sf := range strings.SplitSeq(f, ",") {
203 if sf = strings.TrimSpace(sf); sf != "" {
204 cf = append(cf, sf)
205 }
206 }
207 }
208 slices.Sort(cf)
209 expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
210 slices.Sort(expectedValues)
211 if !slices.Equal(cf, expectedValues) {
212 t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
213 }
214 }))
215 defer frontend.Close()
216
217 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
218 getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
219 getReq.Header.Add("Connection", someConnHeader)
220 getReq.Header.Set(someConnHeader, "should be deleted")
221 getReq.Header.Set(fakeConnectionToken, "should be deleted")
222 res, err := frontend.Client().Do(getReq)
223 if err != nil {
224 t.Fatalf("Get: %v", err)
225 }
226 defer res.Body.Close()
227 bodyBytes, err := io.ReadAll(res.Body)
228 if err != nil {
229 t.Fatalf("reading body: %v", err)
230 }
231 if got, want := string(bodyBytes), backendResponse; got != want {
232 t.Errorf("got body %q; want %q", got, want)
233 }
234 if c := res.Header.Get("Connection"); c != "" {
235 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
236 }
237 if c := res.Header.Get(someConnHeader); c != "" {
238 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
239 }
240 if c := res.Header.Get(fakeConnectionToken); c != "" {
241 t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
242 }
243 }
244
245 func TestReverseProxyStripEmptyConnection(t *testing.T) {
246
247 const backendResponse = "I am the backend"
248
249
250
251 const someConnHeader = "X-Some-Conn-Header"
252
253 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
254 if c := r.Header.Values("Connection"); len(c) != 0 {
255 t.Errorf("handler got header %q = %v; want empty", "Connection", c)
256 }
257 if c := r.Header.Get(someConnHeader); c != "" {
258 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
259 }
260 w.Header().Add("Connection", "")
261 w.Header().Add("Connection", someConnHeader)
262 w.Header().Set(someConnHeader, "should be deleted")
263 io.WriteString(w, backendResponse)
264 }))
265 defer backend.Close()
266 backendURL, err := url.Parse(backend.URL)
267 if err != nil {
268 t.Fatal(err)
269 }
270 proxyHandler := NewSingleHostReverseProxy(backendURL)
271 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
272 proxyHandler.ServeHTTP(w, r)
273 if c := r.Header.Get(someConnHeader); c != "should be deleted" {
274 t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
275 }
276 }))
277 defer frontend.Close()
278
279 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
280 getReq.Header.Add("Connection", "")
281 getReq.Header.Add("Connection", someConnHeader)
282 getReq.Header.Set(someConnHeader, "should be deleted")
283 res, err := frontend.Client().Do(getReq)
284 if err != nil {
285 t.Fatalf("Get: %v", err)
286 }
287 defer res.Body.Close()
288 bodyBytes, err := io.ReadAll(res.Body)
289 if err != nil {
290 t.Fatalf("reading body: %v", err)
291 }
292 if got, want := string(bodyBytes), backendResponse; got != want {
293 t.Errorf("got body %q; want %q", got, want)
294 }
295 if c := res.Header.Get("Connection"); c != "" {
296 t.Errorf("handler got header %q = %q; want empty", "Connection", c)
297 }
298 if c := res.Header.Get(someConnHeader); c != "" {
299 t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
300 }
301 }
302
303 func TestXForwardedFor(t *testing.T) {
304 const prevForwardedFor = "client ip"
305 const backendResponse = "I am the backend"
306 const backendStatus = 404
307 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
308 if r.Header.Get("X-Forwarded-For") == "" {
309 t.Errorf("didn't get X-Forwarded-For header")
310 }
311 if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
312 t.Errorf("X-Forwarded-For didn't contain prior data")
313 }
314 w.WriteHeader(backendStatus)
315 w.Write([]byte(backendResponse))
316 }))
317 defer backend.Close()
318 backendURL, err := url.Parse(backend.URL)
319 if err != nil {
320 t.Fatal(err)
321 }
322 proxyHandler := NewSingleHostReverseProxy(backendURL)
323 frontend := httptest.NewServer(proxyHandler)
324 defer frontend.Close()
325
326 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
327 getReq.Header.Set("Connection", "close")
328 getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
329 getReq.Close = true
330 res, err := frontend.Client().Do(getReq)
331 if err != nil {
332 t.Fatalf("Get: %v", err)
333 }
334 defer res.Body.Close()
335 if g, e := res.StatusCode, backendStatus; g != e {
336 t.Errorf("got res.StatusCode %d; expected %d", g, e)
337 }
338 bodyBytes, _ := io.ReadAll(res.Body)
339 if g, e := string(bodyBytes), backendResponse; g != e {
340 t.Errorf("got body %q; expected %q", g, e)
341 }
342 }
343
344
345 func TestXForwardedFor_Omit(t *testing.T) {
346 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
347 if v := r.Header.Get("X-Forwarded-For"); v != "" {
348 t.Errorf("got X-Forwarded-For header: %q", v)
349 }
350 w.Write([]byte("hi"))
351 }))
352 defer backend.Close()
353 backendURL, err := url.Parse(backend.URL)
354 if err != nil {
355 t.Fatal(err)
356 }
357 proxyHandler := NewSingleHostReverseProxy(backendURL)
358 frontend := httptest.NewServer(proxyHandler)
359 defer frontend.Close()
360
361 oldDirector := proxyHandler.Director
362 proxyHandler.Director = func(r *http.Request) {
363 r.Header["X-Forwarded-For"] = nil
364 oldDirector(r)
365 }
366
367 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
368 getReq.Host = "some-name"
369 getReq.Close = true
370 res, err := frontend.Client().Do(getReq)
371 if err != nil {
372 t.Fatalf("Get: %v", err)
373 }
374 res.Body.Close()
375 }
376
377 func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
378 headers := []string{
379 "Forwarded",
380 "X-Forwarded-For",
381 "X-Forwarded-Host",
382 "X-Forwarded-Proto",
383 }
384 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
385 for _, h := range headers {
386 if v := r.Header.Get(h); v != "" {
387 t.Errorf("got %v header: %q", h, v)
388 }
389 }
390 }))
391 defer backend.Close()
392 backendURL, err := url.Parse(backend.URL)
393 if err != nil {
394 t.Fatal(err)
395 }
396 proxyHandler := &ReverseProxy{
397 Rewrite: func(r *ProxyRequest) {
398 r.SetURL(backendURL)
399 },
400 }
401 frontend := httptest.NewServer(proxyHandler)
402 defer frontend.Close()
403
404 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
405 getReq.Host = "some-name"
406 getReq.Close = true
407 for _, h := range headers {
408 getReq.Header.Set(h, "x")
409 }
410 res, err := frontend.Client().Do(getReq)
411 if err != nil {
412 t.Fatalf("Get: %v", err)
413 }
414 res.Body.Close()
415 }
416
// proxyQueryTests drives TestReverseProxyQuery. baseSuffix is appended to the
// backend (target) URL and reqSuffix to the URL the client requests. want is
// the RawQuery the backend is expected to observe.
var proxyQueryTests = []struct {
	baseSuffix string
	reqSuffix  string
	want       string
}{
	{baseSuffix: "", reqSuffix: "", want: ""},
	{baseSuffix: "?sta=tic", reqSuffix: "?us=er", want: "sta=tic&us=er"},
	{baseSuffix: "", reqSuffix: "?us=er", want: "us=er"},
	{baseSuffix: "?sta=tic", reqSuffix: "", want: "sta=tic"},
}
427
428 func TestReverseProxyQuery(t *testing.T) {
429 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
430 w.Header().Set("X-Got-Query", r.URL.RawQuery)
431 w.Write([]byte("hi"))
432 }))
433 defer backend.Close()
434
435 for i, tt := range proxyQueryTests {
436 backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
437 if err != nil {
438 t.Fatal(err)
439 }
440 frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
441 req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
442 req.Close = true
443 res, err := frontend.Client().Do(req)
444 if err != nil {
445 t.Fatalf("%d. Get: %v", i, err)
446 }
447 if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
448 t.Errorf("%d. got query %q; expected %q", i, g, e)
449 }
450 res.Body.Close()
451 frontend.Close()
452 }
453 }
454
455 func TestReverseProxyFlushInterval(t *testing.T) {
456 const expected = "hi"
457 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
458 w.Write([]byte(expected))
459 }))
460 defer backend.Close()
461
462 backendURL, err := url.Parse(backend.URL)
463 if err != nil {
464 t.Fatal(err)
465 }
466
467 proxyHandler := NewSingleHostReverseProxy(backendURL)
468 proxyHandler.FlushInterval = time.Microsecond
469
470 frontend := httptest.NewServer(proxyHandler)
471 defer frontend.Close()
472
473 req, _ := http.NewRequest("GET", frontend.URL, nil)
474 req.Close = true
475 res, err := frontend.Client().Do(req)
476 if err != nil {
477 t.Fatalf("Get: %v", err)
478 }
479 defer res.Body.Close()
480 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
481 t.Errorf("got body %q; expected %q", bodyBytes, expected)
482 }
483 }
484
// mockFlusher embeds an http.ResponseWriter and satisfies http.Flusher. Its
// Flush only records the call and does not forward it to the embedded
// writer.
type mockFlusher struct {
	http.ResponseWriter
	flushed bool // becomes true on the first Flush call and stays true
}

// Flush sets m.flushed.
func (m *mockFlusher) Flush() { m.flushed = true }
493
// wrappedRW embeds an http.ResponseWriter and gives it back through Unwrap.
// This is the method http.ResponseController uses to reach optional
// interfaces, such as http.Flusher, of an underlying writer.
type wrappedRW struct {
	http.ResponseWriter
}

// Unwrap returns the embedded ResponseWriter unchanged.
func (rw *wrappedRW) Unwrap() http.ResponseWriter { return rw.ResponseWriter }
501
502 func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
503 const expected = "hi"
504 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
505 w.Write([]byte(expected))
506 }))
507 defer backend.Close()
508
509 backendURL, err := url.Parse(backend.URL)
510 if err != nil {
511 t.Fatal(err)
512 }
513
514 mf := &mockFlusher{}
515 proxyHandler := NewSingleHostReverseProxy(backendURL)
516 proxyHandler.FlushInterval = -1
517 proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
518 mf.ResponseWriter = w
519 w = &wrappedRW{mf}
520 proxyHandler.ServeHTTP(w, r)
521 })
522
523 frontend := httptest.NewServer(proxyWithMiddleware)
524 defer frontend.Close()
525
526 req, _ := http.NewRequest("GET", frontend.URL, nil)
527 req.Close = true
528 res, err := frontend.Client().Do(req)
529 if err != nil {
530 t.Fatalf("Get: %v", err)
531 }
532 defer res.Body.Close()
533 if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
534 t.Errorf("got body %q; expected %q", bodyBytes, expected)
535 }
536 if !mf.flushed {
537 t.Errorf("response writer was not flushed")
538 }
539 }
540
541 func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
542 const expected = "hi"
543 stopCh := make(chan struct{})
544 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
545 w.Header().Add("MyHeader", expected)
546 w.WriteHeader(200)
547 w.(http.Flusher).Flush()
548 <-stopCh
549 }))
550 defer backend.Close()
551 defer close(stopCh)
552
553 backendURL, err := url.Parse(backend.URL)
554 if err != nil {
555 t.Fatal(err)
556 }
557
558 proxyHandler := NewSingleHostReverseProxy(backendURL)
559 proxyHandler.FlushInterval = time.Microsecond
560
561 frontend := httptest.NewServer(proxyHandler)
562 defer frontend.Close()
563
564 req, _ := http.NewRequest("GET", frontend.URL, nil)
565 req.Close = true
566
567 ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
568 defer cancel()
569 req = req.WithContext(ctx)
570
571 res, err := frontend.Client().Do(req)
572 if err != nil {
573 t.Fatalf("Get: %v", err)
574 }
575 defer res.Body.Close()
576
577 if res.Header.Get("MyHeader") != expected {
578 t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
579 }
580 }
581
582 func TestReverseProxyCancellation(t *testing.T) {
583 const backendResponse = "I am the backend"
584
585 reqInFlight := make(chan struct{})
586 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
587 close(reqInFlight)
588
589 select {
590 case <-time.After(10 * time.Second):
591
592
593 t.Error("Handler never saw CloseNotify")
594 return
595 case <-w.(http.CloseNotifier).CloseNotify():
596 }
597
598 w.WriteHeader(http.StatusOK)
599 w.Write([]byte(backendResponse))
600 }))
601
602 defer backend.Close()
603
604 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
605
606 backendURL, err := url.Parse(backend.URL)
607 if err != nil {
608 t.Fatal(err)
609 }
610
611 proxyHandler := NewSingleHostReverseProxy(backendURL)
612
613
614
615 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
616
617 frontend := httptest.NewServer(proxyHandler)
618 defer frontend.Close()
619 frontendClient := frontend.Client()
620
621 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
622 go func() {
623 <-reqInFlight
624 frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
625 }()
626 res, err := frontendClient.Do(getReq)
627 if res != nil {
628 t.Errorf("got response %v; want nil", res.Status)
629 }
630 if err == nil {
631
632
633
634 t.Error("Server.Client().Do() returned nil error; want non-nil error")
635 }
636 }
637
// req parses v as a raw HTTP/1.x request (request line, headers and the
// terminating blank line) and returns it. If v does not parse, it fails the
// test with t.Fatal, so it must be called from the goroutine running the
// test.
func req(t *testing.T, v string) *http.Request {
	// Attribute a parse failure to the caller's line, not this helper's.
	t.Helper()
	r, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
	if err != nil {
		t.Fatal(err)
	}
	return r
}
645
646
647 func TestNilBody(t *testing.T) {
648 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
649 w.Write([]byte("hi"))
650 }))
651 defer backend.Close()
652
653 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
654 backURL, _ := url.Parse(backend.URL)
655 rp := NewSingleHostReverseProxy(backURL)
656 r := req(t, "GET / HTTP/1.0\r\n\r\n")
657 r.Body = nil
658 rp.ServeHTTP(w, r)
659 }))
660 defer frontend.Close()
661
662 res, err := http.Get(frontend.URL)
663 if err != nil {
664 t.Fatal(err)
665 }
666 defer res.Body.Close()
667 slurp, err := io.ReadAll(res.Body)
668 if err != nil {
669 t.Fatal(err)
670 }
671 if string(slurp) != "hi" {
672 t.Errorf("Got %q; want %q", slurp, "hi")
673 }
674 }
675
676
677 func TestUserAgentHeader(t *testing.T) {
678 var gotUA string
679 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
680 gotUA = r.Header.Get("User-Agent")
681 }))
682 defer backend.Close()
683 backendURL, err := url.Parse(backend.URL)
684 if err != nil {
685 t.Fatal(err)
686 }
687
688 proxyHandler := new(ReverseProxy)
689 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
690 proxyHandler.Director = func(req *http.Request) {
691 req.URL = backendURL
692 }
693 frontend := httptest.NewServer(proxyHandler)
694 defer frontend.Close()
695 frontendClient := frontend.Client()
696
697 for _, sentUA := range []string{"explicit UA", ""} {
698 getReq, _ := http.NewRequest("GET", frontend.URL, nil)
699 getReq.Header.Set("User-Agent", sentUA)
700 getReq.Close = true
701 res, err := frontendClient.Do(getReq)
702 if err != nil {
703 t.Fatalf("Get: %v", err)
704 }
705 res.Body.Close()
706 if got, want := gotUA, sentUA; got != want {
707 t.Errorf("got forwarded User-Agent %q, want %q", got, want)
708 }
709 }
710 }
711
// bufferPool lets a test supply ReverseProxy.BufferPool as a pair of
// closures, so that it can observe every Get and Put.
type bufferPool struct {
	get func() []byte  // called by Get
	put func([]byte)   // called by Put with the returned buffer
}

// Get returns whatever the get closure returns.
func (bp bufferPool) Get() []byte {
	return bp.get()
}

// Put passes v to the put closure.
func (bp bufferPool) Put(v []byte) {
	bp.put(v)
}
719
720 func TestReverseProxyGetPutBuffer(t *testing.T) {
721 const msg = "hi"
722 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
723 io.WriteString(w, msg)
724 }))
725 defer backend.Close()
726
727 backendURL, err := url.Parse(backend.URL)
728 if err != nil {
729 t.Fatal(err)
730 }
731
732 var (
733 mu sync.Mutex
734 log []string
735 )
736 addLog := func(event string) {
737 mu.Lock()
738 defer mu.Unlock()
739 log = append(log, event)
740 }
741 rp := NewSingleHostReverseProxy(backendURL)
742 const size = 1234
743 rp.BufferPool = bufferPool{
744 get: func() []byte {
745 addLog("getBuf")
746 return make([]byte, size)
747 },
748 put: func(p []byte) {
749 addLog("putBuf-" + strconv.Itoa(len(p)))
750 },
751 }
752 frontend := httptest.NewServer(rp)
753 defer frontend.Close()
754
755 req, _ := http.NewRequest("GET", frontend.URL, nil)
756 req.Close = true
757 res, err := frontend.Client().Do(req)
758 if err != nil {
759 t.Fatalf("Get: %v", err)
760 }
761 slurp, err := io.ReadAll(res.Body)
762 res.Body.Close()
763 if err != nil {
764 t.Fatalf("reading body: %v", err)
765 }
766 if string(slurp) != msg {
767 t.Errorf("msg = %q; want %q", slurp, msg)
768 }
769 wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
770 mu.Lock()
771 defer mu.Unlock()
772 if !slices.Equal(log, wantLog) {
773 t.Errorf("Log events = %q; want %q", log, wantLog)
774 }
775 }
776
777 func TestReverseProxy_Post(t *testing.T) {
778 const backendResponse = "I am the backend"
779 const backendStatus = 200
780 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
781 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
782 slurp, err := io.ReadAll(r.Body)
783 if err != nil {
784 t.Errorf("Backend body read = %v", err)
785 }
786 if len(slurp) != len(requestBody) {
787 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
788 }
789 if !bytes.Equal(slurp, requestBody) {
790 t.Error("Backend read wrong request body.")
791 }
792 w.Write([]byte(backendResponse))
793 }))
794 defer backend.Close()
795 backendURL, err := url.Parse(backend.URL)
796 if err != nil {
797 t.Fatal(err)
798 }
799 proxyHandler := NewSingleHostReverseProxy(backendURL)
800 frontend := httptest.NewServer(proxyHandler)
801 defer frontend.Close()
802
803 postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
804 res, err := frontend.Client().Do(postReq)
805 if err != nil {
806 t.Fatalf("Do: %v", err)
807 }
808 defer res.Body.Close()
809 if g, e := res.StatusCode, backendStatus; g != e {
810 t.Errorf("got res.StatusCode %d; expected %d", g, e)
811 }
812 bodyBytes, _ := io.ReadAll(res.Body)
813 if g, e := string(bodyBytes), backendResponse; g != e {
814 t.Errorf("got body %q; expected %q", g, e)
815 }
816 }
817
// RoundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface.
type RoundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip forwards req to the wrapped function and returns its results.
func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	res, err := f(req)
	return res, err
}
823
824
825 func TestReverseProxy_NilBody(t *testing.T) {
826 backendURL, _ := url.Parse("http://fake.tld/")
827 proxyHandler := NewSingleHostReverseProxy(backendURL)
828 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
829 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
830 if req.Body != nil {
831 t.Error("Body != nil; want a nil Body")
832 }
833 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
834 })
835 frontend := httptest.NewServer(proxyHandler)
836 defer frontend.Close()
837
838 res, err := frontend.Client().Get(frontend.URL)
839 if err != nil {
840 t.Fatal(err)
841 }
842 defer res.Body.Close()
843 if res.StatusCode != 502 {
844 t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
845 }
846 }
847
848
849 func TestReverseProxy_AllocatedHeader(t *testing.T) {
850 proxyHandler := new(ReverseProxy)
851 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
852 proxyHandler.Director = func(*http.Request) {}
853 proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
854 if req.Header == nil {
855 t.Error("Header == nil; want a non-nil Header")
856 }
857 return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
858 })
859
860 proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
861 Method: "GET",
862 URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
863 Proto: "HTTP/1.0",
864 ProtoMajor: 1,
865 })
866 }
867
868
869
870 func TestReverseProxyModifyResponse(t *testing.T) {
871 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
872 w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
873 }))
874 defer backendServer.Close()
875
876 rpURL, _ := url.Parse(backendServer.URL)
877 rproxy := NewSingleHostReverseProxy(rpURL)
878 rproxy.ErrorLog = log.New(io.Discard, "", 0)
879 rproxy.ModifyResponse = func(resp *http.Response) error {
880 if resp.Header.Get("X-Hit-Mod") != "true" {
881 return fmt.Errorf("tried to by-pass proxy")
882 }
883 return nil
884 }
885
886 frontendProxy := httptest.NewServer(rproxy)
887 defer frontendProxy.Close()
888
889 tests := []struct {
890 url string
891 wantCode int
892 }{
893 {frontendProxy.URL + "/mod", http.StatusOK},
894 {frontendProxy.URL + "/schedule", http.StatusBadGateway},
895 }
896
897 for i, tt := range tests {
898 resp, err := http.Get(tt.url)
899 if err != nil {
900 t.Fatalf("failed to reach proxy: %v", err)
901 }
902 if g, e := resp.StatusCode, tt.wantCode; g != e {
903 t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
904 }
905 resp.Body.Close()
906 }
907 }
908
// failingRoundTripper is an http.RoundTripper whose every round trip fails:
// it returns no response and an error with the text "some error".
type failingRoundTripper struct{}

// RoundTrip ignores its argument and reports failure.
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
	err := errors.New("some error")
	return nil, err
}
914
// staticResponseRoundTripper is an http.RoundTripper that answers every
// request with the same *http.Response and a nil error.
type staticResponseRoundTripper struct {
	res *http.Response // handed back as-is by every RoundTrip call
}

// RoundTrip ignores the request and returns s.res.
func (s staticResponseRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	return s.res, nil
}
920
921 func TestReverseProxyErrorHandler(t *testing.T) {
922 tests := []struct {
923 name string
924 wantCode int
925 errorHandler func(http.ResponseWriter, *http.Request, error)
926 transport http.RoundTripper
927 modifyResponse func(*http.Response) error
928 }{
929 {
930 name: "default",
931 wantCode: http.StatusBadGateway,
932 },
933 {
934 name: "errorhandler",
935 wantCode: http.StatusTeapot,
936 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
937 },
938 {
939 name: "modifyresponse_noerr",
940 transport: staticResponseRoundTripper{
941 &http.Response{StatusCode: 345, Body: http.NoBody},
942 },
943 modifyResponse: func(res *http.Response) error {
944 res.StatusCode++
945 return nil
946 },
947 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
948 wantCode: 346,
949 },
950 {
951 name: "modifyresponse_err",
952 transport: staticResponseRoundTripper{
953 &http.Response{StatusCode: 345, Body: http.NoBody},
954 },
955 modifyResponse: func(res *http.Response) error {
956 res.StatusCode++
957 return errors.New("some error to trigger errorHandler")
958 },
959 errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
960 wantCode: http.StatusTeapot,
961 },
962 }
963
964 for _, tt := range tests {
965 t.Run(tt.name, func(t *testing.T) {
966 target := &url.URL{
967 Scheme: "http",
968 Host: "dummy.tld",
969 Path: "/",
970 }
971 rproxy := NewSingleHostReverseProxy(target)
972 rproxy.Transport = tt.transport
973 rproxy.ModifyResponse = tt.modifyResponse
974 if rproxy.Transport == nil {
975 rproxy.Transport = failingRoundTripper{}
976 }
977 rproxy.ErrorLog = log.New(io.Discard, "", 0)
978 if tt.errorHandler != nil {
979 rproxy.ErrorHandler = tt.errorHandler
980 }
981 frontendProxy := httptest.NewServer(rproxy)
982 defer frontendProxy.Close()
983
984 resp, err := http.Get(frontendProxy.URL + "/test")
985 if err != nil {
986 t.Fatalf("failed to reach proxy: %v", err)
987 }
988 if g, e := resp.StatusCode, tt.wantCode; g != e {
989 t.Errorf("got res.StatusCode %d; expected %d", g, e)
990 }
991 resp.Body.Close()
992 })
993 }
994 }
995
996
997 func TestReverseProxy_CopyBuffer(t *testing.T) {
998 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
999 out := "this call was relayed by the reverse proxy"
1000
1001 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1002 fmt.Fprintln(w, out)
1003 }))
1004 defer backendServer.Close()
1005
1006 rpURL, err := url.Parse(backendServer.URL)
1007 if err != nil {
1008 t.Fatal(err)
1009 }
1010
1011 var proxyLog bytes.Buffer
1012 rproxy := NewSingleHostReverseProxy(rpURL)
1013 rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
1014 donec := make(chan bool, 1)
1015 frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1016 defer func() { donec <- true }()
1017 rproxy.ServeHTTP(w, r)
1018 }))
1019 defer frontendProxy.Close()
1020
1021 if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
1022 t.Fatalf("want non-nil error")
1023 }
1024
1025
1026
1027
1028 <-donec
1029
1030 expected := []string{
1031 "EOF",
1032 "read",
1033 }
1034 for _, phrase := range expected {
1035 if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
1036 t.Errorf("expected log to contain phrase %q", phrase)
1037 }
1038 }
1039 }
1040
// staticTransport is an http.RoundTripper that answers every request with the
// same preconfigured response and a nil error.
type staticTransport struct {
	res *http.Response // returned unchanged by RoundTrip
}

// RoundTrip ignores r and returns st.res.
func (st *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	res := st.res
	return res, nil
}
1048
1049 func BenchmarkServeHTTP(b *testing.B) {
1050 res := &http.Response{
1051 StatusCode: 200,
1052 Body: io.NopCloser(strings.NewReader("")),
1053 }
1054 proxy := &ReverseProxy{
1055 Director: func(*http.Request) {},
1056 Transport: &staticTransport{res},
1057 }
1058
1059 w := httptest.NewRecorder()
1060 r := httptest.NewRequest("GET", "/", nil)
1061
1062 b.ReportAllocs()
1063 for i := 0; i < b.N; i++ {
1064 proxy.ServeHTTP(w, r)
1065 }
1066 }
1067
1068 func TestServeHTTPDeepCopy(t *testing.T) {
1069 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1070 w.Write([]byte("Hello Gopher!"))
1071 }))
1072 defer backend.Close()
1073 backendURL, err := url.Parse(backend.URL)
1074 if err != nil {
1075 t.Fatal(err)
1076 }
1077
1078 type result struct {
1079 before, after string
1080 }
1081
1082 resultChan := make(chan result, 1)
1083 proxyHandler := NewSingleHostReverseProxy(backendURL)
1084 frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1085 before := r.URL.String()
1086 proxyHandler.ServeHTTP(w, r)
1087 after := r.URL.String()
1088 resultChan <- result{before: before, after: after}
1089 }))
1090 defer frontend.Close()
1091
1092 want := result{before: "/", after: "/"}
1093
1094 res, err := frontend.Client().Get(frontend.URL)
1095 if err != nil {
1096 t.Fatalf("Do: %v", err)
1097 }
1098 res.Body.Close()
1099
1100 got := <-resultChan
1101 if got != want {
1102 t.Errorf("got = %+v; want = %+v", got, want)
1103 }
1104 }
1105
1106
1107
1108 func TestClonesRequestHeaders(t *testing.T) {
1109 log.SetOutput(io.Discard)
1110 defer log.SetOutput(os.Stderr)
1111 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1112 req.RemoteAddr = "1.2.3.4:56789"
1113 rp := &ReverseProxy{
1114 Director: func(req *http.Request) {
1115 req.Header.Set("From-Director", "1")
1116 },
1117 Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
1118 if v := req.Header.Get("From-Director"); v != "1" {
1119 t.Errorf("From-Directory value = %q; want 1", v)
1120 }
1121 return nil, io.EOF
1122 }),
1123 }
1124 rp.ServeHTTP(httptest.NewRecorder(), req)
1125
1126 for _, h := range []string{
1127 "From-Director",
1128 "X-Forwarded-For",
1129 } {
1130 if req.Header.Get(h) != "" {
1131 t.Errorf("%v header mutation modified caller's request", h)
1132 }
1133 }
1134 }
1135
// roundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, in the same spirit as http.HandlerFunc.
type roundTripperFunc func(req *http.Request) (*http.Response, error)

// RoundTrip invokes the underlying function with r and returns its results
// unchanged.
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	call := (func(*http.Request) (*http.Response, error))(f)
	return call(r)
}
1141
1142 func TestModifyResponseClosesBody(t *testing.T) {
1143 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1144 req.RemoteAddr = "1.2.3.4:56789"
1145 closeCheck := new(checkCloser)
1146 logBuf := new(strings.Builder)
1147 outErr := errors.New("ModifyResponse error")
1148 rp := &ReverseProxy{
1149 Director: func(req *http.Request) {},
1150 Transport: &staticTransport{&http.Response{
1151 StatusCode: 200,
1152 Body: closeCheck,
1153 }},
1154 ErrorLog: log.New(logBuf, "", 0),
1155 ModifyResponse: func(*http.Response) error {
1156 return outErr
1157 },
1158 }
1159 rec := httptest.NewRecorder()
1160 rp.ServeHTTP(rec, req)
1161 res := rec.Result()
1162 if g, e := res.StatusCode, http.StatusBadGateway; g != e {
1163 t.Errorf("got res.StatusCode %d; expected %d", g, e)
1164 }
1165 if !closeCheck.closed {
1166 t.Errorf("body should have been closed")
1167 }
1168 if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
1169 t.Errorf("ErrorLog %q does not contain %q", g, e)
1170 }
1171 }
1172
// checkCloser is an endless io.ReadCloser that records whether Close has
// been called.
type checkCloser struct {
	closed bool // set by Close
}

// Close marks the reader as closed. It never fails.
func (cc *checkCloser) Close() error {
	cc.closed = true
	return nil
}

// Read claims to have filled all of b without error. The contents of b are
// left untouched, so the reader behaves like an infinite stream.
func (cc *checkCloser) Read(b []byte) (int, error) {
	n := len(b)
	return n, nil
}
1185
1186
1187 func TestReverseProxy_PanicBodyError(t *testing.T) {
1188 log.SetOutput(io.Discard)
1189 defer log.SetOutput(os.Stderr)
1190 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1191 out := "this call was relayed by the reverse proxy"
1192
1193 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1194 fmt.Fprintln(w, out)
1195 }))
1196 defer backendServer.Close()
1197
1198 rpURL, err := url.Parse(backendServer.URL)
1199 if err != nil {
1200 t.Fatal(err)
1201 }
1202
1203 rproxy := NewSingleHostReverseProxy(rpURL)
1204
1205
1206
1207 defer func() {
1208 err := recover()
1209 if err == nil {
1210 t.Fatal("handler should have panicked")
1211 }
1212 if err != http.ErrAbortHandler {
1213 t.Fatal("expected ErrAbortHandler, got", err)
1214 }
1215 }()
1216 req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
1217 rproxy.ServeHTTP(httptest.NewRecorder(), req)
1218 }
1219
1220
1221 func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
1222 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1223 out := "this call was relayed by the reverse proxy"
1224
1225 w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
1226 fmt.Fprintln(w, out)
1227 }))
1228 defer backend.Close()
1229 backendURL, err := url.Parse(backend.URL)
1230 if err != nil {
1231 t.Fatal(err)
1232 }
1233 proxyHandler := NewSingleHostReverseProxy(backendURL)
1234 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1235 frontend := httptest.NewServer(proxyHandler)
1236 defer frontend.Close()
1237 frontendClient := frontend.Client()
1238
1239 var wg sync.WaitGroup
1240 for i := 0; i < 2; i++ {
1241 wg.Add(1)
1242 go func() {
1243 defer wg.Done()
1244 for j := 0; j < 10; j++ {
1245 const reqLen = 6 * 1024 * 1024
1246 req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
1247 req.ContentLength = reqLen
1248 resp, _ := frontendClient.Transport.RoundTrip(req)
1249 if resp != nil {
1250 io.Copy(io.Discard, resp.Body)
1251 resp.Body.Close()
1252 }
1253 }
1254 }()
1255 }
1256 wg.Wait()
1257 }
1258
1259 func TestSelectFlushInterval(t *testing.T) {
1260 tests := []struct {
1261 name string
1262 p *ReverseProxy
1263 res *http.Response
1264 want time.Duration
1265 }{
1266 {
1267 name: "default",
1268 res: &http.Response{},
1269 p: &ReverseProxy{FlushInterval: 123},
1270 want: 123,
1271 },
1272 {
1273 name: "server-sent events overrides non-zero",
1274 res: &http.Response{
1275 Header: http.Header{
1276 "Content-Type": {"text/event-stream"},
1277 },
1278 },
1279 p: &ReverseProxy{FlushInterval: 123},
1280 want: -1,
1281 },
1282 {
1283 name: "server-sent events overrides zero",
1284 res: &http.Response{
1285 Header: http.Header{
1286 "Content-Type": {"text/event-stream"},
1287 },
1288 },
1289 p: &ReverseProxy{FlushInterval: 0},
1290 want: -1,
1291 },
1292 {
1293 name: "server-sent events with media-type parameters overrides non-zero",
1294 res: &http.Response{
1295 Header: http.Header{
1296 "Content-Type": {"text/event-stream;charset=utf-8"},
1297 },
1298 },
1299 p: &ReverseProxy{FlushInterval: 123},
1300 want: -1,
1301 },
1302 {
1303 name: "server-sent events with media-type parameters overrides zero",
1304 res: &http.Response{
1305 Header: http.Header{
1306 "Content-Type": {"text/event-stream;charset=utf-8"},
1307 },
1308 },
1309 p: &ReverseProxy{FlushInterval: 0},
1310 want: -1,
1311 },
1312 {
1313 name: "Content-Length: -1, overrides non-zero",
1314 res: &http.Response{
1315 ContentLength: -1,
1316 },
1317 p: &ReverseProxy{FlushInterval: 123},
1318 want: -1,
1319 },
1320 {
1321 name: "Content-Length: -1, overrides zero",
1322 res: &http.Response{
1323 ContentLength: -1,
1324 },
1325 p: &ReverseProxy{FlushInterval: 0},
1326 want: -1,
1327 },
1328 }
1329 for _, tt := range tests {
1330 t.Run(tt.name, func(t *testing.T) {
1331 got := tt.p.flushInterval(tt.res)
1332 if got != tt.want {
1333 t.Errorf("flushLatency = %v; want %v", got, tt.want)
1334 }
1335 })
1336 }
1337 }
1338
1339 func TestReverseProxyWebSocket(t *testing.T) {
1340 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1341 if upgradeType(r.Header) != "websocket" {
1342 t.Error("unexpected backend request")
1343 http.Error(w, "unexpected request", 400)
1344 return
1345 }
1346 c, _, err := w.(http.Hijacker).Hijack()
1347 if err != nil {
1348 t.Error(err)
1349 return
1350 }
1351 defer c.Close()
1352 io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
1353 bs := bufio.NewScanner(c)
1354 if !bs.Scan() {
1355 t.Errorf("backend failed to read line from client: %v", bs.Err())
1356 return
1357 }
1358 fmt.Fprintf(c, "backend got %q\n", bs.Text())
1359 }))
1360 defer backendServer.Close()
1361
1362 backURL, _ := url.Parse(backendServer.URL)
1363 rproxy := NewSingleHostReverseProxy(backURL)
1364 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1365 rproxy.ModifyResponse = func(res *http.Response) error {
1366 res.Header.Add("X-Modified", "true")
1367 return nil
1368 }
1369
1370 handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1371 rw.Header().Set("X-Header", "X-Value")
1372 rproxy.ServeHTTP(rw, req)
1373 if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
1374 t.Errorf("response writer X-Modified header = %q; want %q", got, want)
1375 }
1376 })
1377
1378 frontendProxy := httptest.NewServer(handler)
1379 defer frontendProxy.Close()
1380
1381 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1382 req.Header.Set("Connection", "Upgrade")
1383 req.Header.Set("Upgrade", "websocket")
1384
1385 c := frontendProxy.Client()
1386 res, err := c.Do(req)
1387 if err != nil {
1388 t.Fatal(err)
1389 }
1390 if res.StatusCode != 101 {
1391 t.Fatalf("status = %v; want 101", res.Status)
1392 }
1393
1394 got := res.Header.Get("X-Header")
1395 want := "X-Value"
1396 if got != want {
1397 t.Errorf("Header(XHeader) = %q; want %q", got, want)
1398 }
1399
1400 if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
1401 t.Fatalf("not websocket upgrade; got %#v", res.Header)
1402 }
1403 rwc, ok := res.Body.(io.ReadWriteCloser)
1404 if !ok {
1405 t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
1406 }
1407 defer rwc.Close()
1408
1409 if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1410 t.Errorf("response X-Modified header = %q; want %q", got, want)
1411 }
1412
1413 io.WriteString(rwc, "Hello\n")
1414 bs := bufio.NewScanner(rwc)
1415 if !bs.Scan() {
1416 t.Fatalf("Scan: %v", bs.Err())
1417 }
1418 got = bs.Text()
1419 want = `backend got "Hello"`
1420 if got != want {
1421 t.Errorf("got %#q, want %#q", got, want)
1422 }
1423 }
1424
// TestReverseProxyWebSocketCancellation checks that canceling the context of
// a proxied WebSocket (upgraded) request tears down the upgraded connection.
// The backend streams n numbered responses one second apart, followed by a
// terminal message. The client triggers cancellation as soon as it sees the
// first response, and must then reach EOF without ever seeing the terminal
// message.
func TestReverseProxyWebSocketCancellation(t *testing.T) {
	n := 5
	// Closed by the client once the first backend response arrives. Closing
	// it wakes the frontend goroutine that cancels the proxied request's
	// context, and lets the backend treat subsequent write failures as
	// expected rather than as test errors.
	// NOTE(review): the channel is only ever closed, never sent on, so its
	// buffer size appears to be unused.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// No trailing newline: ReadString('\n') can only return this text
	// together with an error.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first message ("Hello\n") before streaming.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				// A write error is expected only after the client has
				// triggered cancellation (triggerCancelCh closed).
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			bufrw.Flush()
			// Pace the responses so cancellation lands mid-stream.
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0)
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		// Derive a cancelable context for the proxied request and cancel it
		// once the client closes triggerCancelCh.
		// NOTE(review): if the channel is never closed (e.g. the test fails
		// early), this goroutine lingers until the process exits.
		ctx, cancel := context.WithCancel(req.Context())
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
		t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
	}

	// The upgraded connection is exposed as a writable response body.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	// This first message unblocks the backend's ReadLine so it starts
	// streaming responses.
	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read responses until the connection ends. Seeing the terminal message
	// means cancellation did not tear down the connection; EOF means it did.
	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		case line == terminalMsg:
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0):
			// The stream has started: trigger cancellation of the proxied
			// request.
			close(triggerCancelCh)
		}
	}
}
1555
1556 func TestReverseProxyWebSocketHalfTCP(t *testing.T) {
1557
1558
1559
1560
1561
1562
1563
1564 switch runtime.GOOS {
1565 case "plan9":
1566 t.Skipf("not supported on %s", runtime.GOOS)
1567 }
1568
1569 mustRead := func(t *testing.T, conn *net.TCPConn, msg string) {
1570 b := make([]byte, len(msg))
1571 if _, err := conn.Read(b); err != nil {
1572 t.Errorf("failed to read: %v", err)
1573 }
1574
1575 if got, want := string(b), msg; got != want {
1576 t.Errorf("got %#q, want %#q", got, want)
1577 }
1578 }
1579
1580 mustReadError := func(t *testing.T, conn *net.TCPConn, e error) {
1581 b := make([]byte, 1)
1582 if _, err := conn.Read(b); !errors.Is(err, e) {
1583 t.Errorf("failed to read error: %v", err)
1584 }
1585 }
1586
1587 mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) {
1588 if _, err := conn.Write([]byte(msg)); err != nil {
1589 t.Errorf("failed to write: %v", err)
1590 }
1591 }
1592
1593 mustCloseRead := func(t *testing.T, conn *net.TCPConn) {
1594 if err := conn.CloseRead(); err != nil {
1595 t.Errorf("failed to CloseRead: %v", err)
1596 }
1597 }
1598
1599 mustCloseWrite := func(t *testing.T, conn *net.TCPConn) {
1600 if err := conn.CloseWrite(); err != nil {
1601 t.Errorf("failed to CloseWrite: %v", err)
1602 }
1603 }
1604
1605 tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){
1606 "server close read": func(t *testing.T, cli, srv *net.TCPConn) {
1607 mustCloseRead(t, srv)
1608 mustWrite(t, srv, "server sends")
1609 mustRead(t, cli, "server sends")
1610 },
1611 "server close write": func(t *testing.T, cli, srv *net.TCPConn) {
1612 mustCloseWrite(t, srv)
1613 mustWrite(t, cli, "client sends")
1614 mustRead(t, srv, "client sends")
1615 mustReadError(t, cli, io.EOF)
1616 },
1617 "client close read": func(t *testing.T, cli, srv *net.TCPConn) {
1618 mustCloseRead(t, cli)
1619 mustWrite(t, cli, "client sends")
1620 mustRead(t, srv, "client sends")
1621 },
1622 "client close write": func(t *testing.T, cli, srv *net.TCPConn) {
1623 mustCloseWrite(t, cli)
1624 mustWrite(t, srv, "server sends")
1625 mustRead(t, cli, "server sends")
1626 mustReadError(t, srv, io.EOF)
1627 },
1628 }
1629
1630 for name, test := range tests {
1631 t.Run(name, func(t *testing.T) {
1632 var srv *net.TCPConn
1633
1634 backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1635 if g, ws := upgradeType(r.Header), "websocket"; g != ws {
1636 t.Fatalf("Unexpected upgrade type %q, want %q", g, ws)
1637 }
1638
1639 conn, _, err := w.(http.Hijacker).Hijack()
1640 if err != nil {
1641 conn.Close()
1642 t.Fatalf("hijack failed: %v", err)
1643 }
1644
1645 var ok bool
1646 if srv, ok = conn.(*net.TCPConn); !ok {
1647 conn.Close()
1648 t.Fatal("conn is not a TCPConn")
1649 }
1650
1651 upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
1652 if _, err := io.WriteString(srv, upgradeMsg); err != nil {
1653 srv.Close()
1654 t.Fatalf("backend upgrade failed: %v", err)
1655 }
1656 }))
1657 defer backendServer.Close()
1658
1659 backendURL, _ := url.Parse(backendServer.URL)
1660 rproxy := NewSingleHostReverseProxy(backendURL)
1661 rproxy.ErrorLog = log.New(io.Discard, "", 0)
1662 frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1663 rproxy.ServeHTTP(rw, req)
1664 }))
1665 defer frontendProxy.Close()
1666
1667 frontendURL, _ := url.Parse(frontendProxy.URL)
1668 addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host)
1669 if err != nil {
1670 t.Fatalf("failed to resolve TCP address: %v", err)
1671 }
1672 cli, err := net.DialTCP("tcp", nil, addr)
1673 if err != nil {
1674 t.Fatalf("failed to dial TCP address: %v", err)
1675 }
1676 defer cli.Close()
1677
1678 req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1679 req.Header.Set("Connection", "Upgrade")
1680 req.Header.Set("Upgrade", "websocket")
1681 if err := req.Write(cli); err != nil {
1682 t.Fatalf("failed to write request: %v", err)
1683 }
1684
1685 br := bufio.NewReader(cli)
1686 resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
1687 if err != nil {
1688 t.Fatalf("failed to read response: %v", err)
1689 }
1690 if resp.StatusCode != 101 {
1691 t.Fatalf("status code not 101: %v", resp.StatusCode)
1692 }
1693 if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
1694 strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
1695 t.Fatalf("frontend upgrade failed")
1696 }
1697 defer srv.Close()
1698
1699 test(t, cli, srv)
1700 })
1701 }
1702 }
1703
1704 func TestReverseProxyUpgradeNoCloseWrite(t *testing.T) {
1705
1706
1707
1708 backendDone := make(chan struct{})
1709 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1710 w.Header().Set("Connection", "upgrade")
1711 w.Header().Set("Upgrade", "u")
1712 w.WriteHeader(101)
1713 conn, _, err := http.NewResponseController(w).Hijack()
1714 if err != nil {
1715 t.Errorf("Hijack: %v", err)
1716 }
1717 io.Copy(io.Discard, conn)
1718 close(backendDone)
1719 }))
1720 backendURL, err := url.Parse(backend.URL)
1721 if err != nil {
1722 t.Fatal(err)
1723 }
1724
1725
1726
1727 proxyHandler := NewSingleHostReverseProxy(backendURL)
1728 proxyHandler.ModifyResponse = func(resp *http.Response) error {
1729 type readWriteCloserOnly struct {
1730 io.ReadWriteCloser
1731 }
1732 resp.Body = readWriteCloserOnly{resp.Body.(io.ReadWriteCloser)}
1733 return nil
1734 }
1735 frontend := httptest.NewServer(proxyHandler)
1736 defer frontend.Close()
1737
1738
1739 req, _ := http.NewRequest("GET", frontend.URL, nil)
1740 req.Header.Set("Connection", "upgrade")
1741 req.Header.Set("Upgrade", "u")
1742 resp, err := frontend.Client().Do(req)
1743 if err != nil {
1744 t.Fatal(err)
1745 }
1746 resp.Body.Close()
1747
1748
1749 <-backendDone
1750 }
1751
1752 func TestUnannouncedTrailer(t *testing.T) {
1753 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1754 w.WriteHeader(http.StatusOK)
1755 w.(http.Flusher).Flush()
1756 w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
1757 }))
1758 defer backend.Close()
1759 backendURL, err := url.Parse(backend.URL)
1760 if err != nil {
1761 t.Fatal(err)
1762 }
1763 proxyHandler := NewSingleHostReverseProxy(backendURL)
1764 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1765 frontend := httptest.NewServer(proxyHandler)
1766 defer frontend.Close()
1767 frontendClient := frontend.Client()
1768
1769 res, err := frontendClient.Get(frontend.URL)
1770 if err != nil {
1771 t.Fatalf("Get: %v", err)
1772 }
1773
1774 io.ReadAll(res.Body)
1775 res.Body.Close()
1776 if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
1777 t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
1778 }
1779
1780 }
1781
1782 func TestSetURL(t *testing.T) {
1783 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1784 w.Write([]byte(r.Host))
1785 }))
1786 defer backend.Close()
1787 backendURL, err := url.Parse(backend.URL)
1788 if err != nil {
1789 t.Fatal(err)
1790 }
1791 proxyHandler := &ReverseProxy{
1792 Rewrite: func(r *ProxyRequest) {
1793 r.SetURL(backendURL)
1794 },
1795 }
1796 frontend := httptest.NewServer(proxyHandler)
1797 defer frontend.Close()
1798 frontendClient := frontend.Client()
1799
1800 res, err := frontendClient.Get(frontend.URL)
1801 if err != nil {
1802 t.Fatalf("Get: %v", err)
1803 }
1804 defer res.Body.Close()
1805
1806 body, err := io.ReadAll(res.Body)
1807 if err != nil {
1808 t.Fatalf("Reading body: %v", err)
1809 }
1810
1811 if got, want := string(body), backendURL.Host; got != want {
1812 t.Errorf("backend got Host %q, want %q", got, want)
1813 }
1814 }
1815
1816 func TestSingleJoinSlash(t *testing.T) {
1817 tests := []struct {
1818 slasha string
1819 slashb string
1820 expected string
1821 }{
1822 {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
1823 {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
1824 {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
1825 {"https://www.google.com", "", "https://www.google.com/"},
1826 {"", "favicon.ico", "/favicon.ico"},
1827 }
1828 for _, tt := range tests {
1829 if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
1830 t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
1831 tt.slasha,
1832 tt.slashb,
1833 tt.expected,
1834 got)
1835 }
1836 }
1837 }
1838
1839 func TestJoinURLPath(t *testing.T) {
1840 tests := []struct {
1841 a *url.URL
1842 b *url.URL
1843 wantPath string
1844 wantRaw string
1845 }{
1846 {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
1847 {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
1848 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1849 {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
1850 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
1851 {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
1852 }
1853
1854 for _, tt := range tests {
1855 p, rp := joinURLPath(tt.a, tt.b)
1856 if p != tt.wantPath || rp != tt.wantRaw {
1857 t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
1858 tt.a.Path, tt.a.RawPath,
1859 tt.b.Path, tt.b.RawPath,
1860 tt.wantPath, tt.wantRaw,
1861 p, rp)
1862 }
1863 }
1864 }
1865
1866 func TestReverseProxyRewriteReplacesOut(t *testing.T) {
1867 const content = "response_content"
1868 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1869 w.Write([]byte(content))
1870 }))
1871 defer backend.Close()
1872 proxyHandler := &ReverseProxy{
1873 Rewrite: func(r *ProxyRequest) {
1874 r.Out, _ = http.NewRequest("GET", backend.URL, nil)
1875 },
1876 }
1877 frontend := httptest.NewServer(proxyHandler)
1878 defer frontend.Close()
1879
1880 res, err := frontend.Client().Get(frontend.URL)
1881 if err != nil {
1882 t.Fatalf("Get: %v", err)
1883 }
1884 defer res.Body.Close()
1885 body, _ := io.ReadAll(res.Body)
1886 if got, want := string(body), content; got != want {
1887 t.Errorf("got response %q, want %q", got, want)
1888 }
1889 }
1890
1891 func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
1892
1893
1894
1895
1896 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1897 for i := 0; i < 5; i++ {
1898 w.WriteHeader(103)
1899 }
1900 }))
1901 defer backend.Close()
1902 backendURL, err := url.Parse(backend.URL)
1903 if err != nil {
1904 t.Fatal(err)
1905 }
1906 proxyHandler := NewSingleHostReverseProxy(backendURL)
1907 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1908
1909 rw := &testResponseWriter{}
1910 func() {
1911
1912
1913 ctx, cancel := context.WithCancel(context.Background())
1914 defer cancel()
1915 ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
1916 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1917 cancel()
1918 return nil
1919 },
1920 })
1921
1922 req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
1923 proxyHandler.ServeHTTP(rw, req)
1924 }()
1925
1926
1927
1928 for _ = range rw.Header() {
1929 }
1930 }
1931
1932 func Test1xxResponses(t *testing.T) {
1933 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1934 h := w.Header()
1935 h.Add("Link", "</style.css>; rel=preload; as=style")
1936 h.Add("Link", "</script.js>; rel=preload; as=script")
1937 w.WriteHeader(http.StatusEarlyHints)
1938
1939 h.Add("Link", "</foo.js>; rel=preload; as=script")
1940 w.WriteHeader(http.StatusProcessing)
1941
1942 w.Write([]byte("Hello"))
1943 }))
1944 defer backend.Close()
1945 backendURL, err := url.Parse(backend.URL)
1946 if err != nil {
1947 t.Fatal(err)
1948 }
1949 proxyHandler := NewSingleHostReverseProxy(backendURL)
1950 proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
1951 frontend := httptest.NewServer(proxyHandler)
1952 defer frontend.Close()
1953 frontendClient := frontend.Client()
1954
1955 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1956 t.Helper()
1957
1958 if len(expected) != len(got) {
1959 t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
1960 }
1961
1962 for i := range expected {
1963 if i >= len(got) {
1964 t.Errorf("Expected %q link header; got nothing", expected[i])
1965
1966 continue
1967 }
1968
1969 if expected[i] != got[i] {
1970 t.Errorf("Expected %q link header; got %q", expected[i], got[i])
1971 }
1972 }
1973 }
1974
1975 var respCounter uint8
1976 trace := &httptrace.ClientTrace{
1977 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1978 switch code {
1979 case http.StatusEarlyHints:
1980 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1981 case http.StatusProcessing:
1982 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1983 default:
1984 t.Error("Unexpected 1xx response")
1985 }
1986
1987 respCounter++
1988
1989 return nil
1990 },
1991 }
1992 req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
1993
1994 res, err := frontendClient.Do(req)
1995 if err != nil {
1996 t.Fatalf("Get: %v", err)
1997 }
1998
1999 defer res.Body.Close()
2000
2001 if respCounter != 2 {
2002 t.Errorf("Expected 2 1xx responses; got %d", respCounter)
2003 }
2004 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
2005
2006 body, _ := io.ReadAll(res.Body)
2007 if string(body) != "Hello" {
2008 t.Errorf("Read body %q; want Hello", body)
2009 }
2010 }
2011
// Named values for the wantCleanQuery argument of
// testReverseProxyQueryParameterSmuggling.
const (
	// testWantsRawQuery: the backend must receive the client's query unchanged.
	testWantsRawQuery = false
	// testWantsCleanQuery: the backend must receive only the cleanly parsed query.
	testWantsCleanQuery = true
)
2016
2017 func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
2018 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
2019 proxyHandler := NewSingleHostReverseProxy(u)
2020 oldDirector := proxyHandler.Director
2021 proxyHandler.Director = func(r *http.Request) {
2022 oldDirector(r)
2023 }
2024 return proxyHandler
2025 })
2026 }
2027
2028 func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
2029 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
2030 proxyHandler := NewSingleHostReverseProxy(u)
2031 oldDirector := proxyHandler.Director
2032 proxyHandler.Director = func(r *http.Request) {
2033
2034
2035 r.FormValue("a")
2036 oldDirector(r)
2037 }
2038 return proxyHandler
2039 })
2040 }
2041
2042 func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
2043 testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
2044 return &ReverseProxy{
2045 Rewrite: func(r *ProxyRequest) {
2046 r.SetURL(u)
2047 },
2048 }
2049 })
2050 }
2051
2052 func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
2053 testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
2054 return &ReverseProxy{
2055 Rewrite: func(r *ProxyRequest) {
2056 r.SetURL(u)
2057 r.Out.URL.RawQuery = r.In.URL.RawQuery
2058 },
2059 }
2060 })
2061 }
2062
2063 func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
2064 const content = "response_content"
2065 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2066 w.Write([]byte(r.URL.RawQuery))
2067 }))
2068 defer backend.Close()
2069 backendURL, err := url.Parse(backend.URL)
2070 if err != nil {
2071 t.Fatal(err)
2072 }
2073 proxyHandler := newProxy(backendURL)
2074 frontend := httptest.NewServer(proxyHandler)
2075 defer frontend.Close()
2076
2077
2078 backend.Config.ErrorLog = log.New(io.Discard, "", 0)
2079 frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
2080
2081 for _, test := range []struct {
2082 rawQuery string
2083 cleanQuery string
2084 }{{
2085 rawQuery: "a=1&a=2;b=3",
2086 cleanQuery: "a=1",
2087 }, {
2088 rawQuery: "a=1&a=%zz&b=3",
2089 cleanQuery: "a=1&b=3",
2090 }} {
2091 res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
2092 if err != nil {
2093 t.Fatalf("Get: %v", err)
2094 }
2095 defer res.Body.Close()
2096 body, _ := io.ReadAll(res.Body)
2097 wantQuery := test.rawQuery
2098 if wantCleanQuery {
2099 wantQuery = test.cleanQuery
2100 }
2101 if got, want := string(body), wantQuery; got != want {
2102 t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
2103 }
2104 }
2105 }
2106
2107
2108
2109 func TestReverseProxyHijackCopyError(t *testing.T) {
2110 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2111 w.Header().Set("Upgrade", "someproto")
2112 w.WriteHeader(http.StatusSwitchingProtocols)
2113 }))
2114 defer backend.Close()
2115 backendURL, err := url.Parse(backend.URL)
2116 if err != nil {
2117 t.Fatal(err)
2118 }
2119 proxyHandler := &ReverseProxy{
2120 Rewrite: func(r *ProxyRequest) {
2121 r.SetURL(backendURL)
2122 },
2123 ModifyResponse: func(resp *http.Response) error {
2124 resp.Body = &testReadWriteCloser{
2125 read: func([]byte) (int, error) {
2126 return 0, errors.New("read error")
2127 },
2128 }
2129 return nil
2130 },
2131 }
2132
2133 hijacked := false
2134 rw := &testResponseWriter{
2135 writeHeader: func(statusCode int) {
2136 if hijacked {
2137 t.Errorf("WriteHeader(%v) called after Hijack", statusCode)
2138 }
2139 },
2140 hijack: func() (net.Conn, *bufio.ReadWriter, error) {
2141 hijacked = true
2142 cli, srv := net.Pipe()
2143 go io.Copy(io.Discard, cli)
2144 return srv, bufio.NewReadWriter(bufio.NewReader(srv), bufio.NewWriter(srv)), nil
2145 },
2146 }
2147 req, _ := http.NewRequest("GET", "http://example.tld/", nil)
2148 req.Header.Set("Upgrade", "someproto")
2149 proxyHandler.ServeHTTP(rw, req)
2150 }
2151
// testResponseWriter is a minimal http.ResponseWriter and http.Hijacker whose
// behavior comes from optional callbacks. When a callback is nil:
//   - WriteHeader does nothing;
//   - Write reports all bytes as written without storing them;
//   - Hijack fails with errors.ErrUnsupported.
type testResponseWriter struct {
	h           http.Header
	writeHeader func(int)
	write       func([]byte) (int, error)
	hijack      func() (net.Conn, *bufio.ReadWriter, error)
}

// Header returns the header map, allocating it on first use.
func (rw *testResponseWriter) Header() http.Header {
	if rw.h != nil {
		return rw.h
	}
	rw.h = http.Header{}
	return rw.h
}

// WriteHeader passes statusCode to the writeHeader callback, if one is set.
func (rw *testResponseWriter) WriteHeader(statusCode int) {
	if cb := rw.writeHeader; cb != nil {
		cb(statusCode)
	}
}

// Write delegates to the write callback, or pretends all of p was written.
func (rw *testResponseWriter) Write(p []byte) (int, error) {
	if rw.write == nil {
		return len(p), nil
	}
	return rw.write(p)
}

// Hijack delegates to the hijack callback, or reports errors.ErrUnsupported.
func (rw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if rw.hijack == nil {
		return nil, nil, errors.ErrUnsupported
	}
	return rw.hijack()
}
2185
// testReadWriteCloser is an io.ReadWriteCloser assembled from optional
// callbacks. When a callback is nil:
//   - Read returns (0, io.EOF);
//   - Write reports all of p as written;
//   - Close succeeds.
type testReadWriteCloser struct {
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
	close func() error
}

// Read delegates to rc.read, or reports end of stream.
func (rc *testReadWriteCloser) Read(p []byte) (int, error) {
	if rc.read == nil {
		return 0, io.EOF
	}
	return rc.read(p)
}

// Write delegates to rc.write, or pretends all of p was accepted.
func (rc *testReadWriteCloser) Write(p []byte) (int, error) {
	if rc.write == nil {
		return len(p), nil
	}
	return rc.write(p)
}

// Close delegates to rc.close, or succeeds without doing anything.
func (rc *testReadWriteCloser) Close() error {
	if rc.close == nil {
		return nil
	}
	return rc.close()
}
2212
View as plain text