Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "crypto/x509"
17 "encoding/json"
18 "errors"
19 "fmt"
20 "internal/synctest"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "time"
47 )
48
// dummyAddr is a trivial net.Addr: both its network name and its string
// form are the underlying string.
type dummyAddr string

// Network returns the address text itself.
func (a dummyAddr) Network() string {
	return string(a)
}

// String returns the address text itself.
func (a dummyAddr) String() string {
	return string(a)
}

// oneConnListener is a net.Listener that yields a single preset
// connection: the first Accept hands out conn (when non-nil) and every
// later Accept fails with io.EOF.
type oneConnListener struct {
	conn net.Conn
}

// Accept returns the pending connection and forgets it, or io.EOF once
// there is none left.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close does nothing; the listener holds no resources of its own.
func (l *oneConnListener) Close() error {
	return nil
}

// Addr reports the placeholder address "test-address".
func (l *oneConnListener) Addr() net.Addr {
	return dummyAddr("test-address")
}

// noopConn implements the address and deadline methods of net.Conn with
// trivial bodies, so it can be embedded in fake connections that only
// supply Read, Write and Close.
type noopConn struct{}

// LocalAddr reports the placeholder address "local-addr".
func (noopConn) LocalAddr() net.Addr {
	return dummyAddr("local-addr")
}

// RemoteAddr reports the placeholder address "remote-addr".
func (noopConn) RemoteAddr() net.Addr {
	return dummyAddr("remote-addr")
}

// SetDeadline ignores t and always succeeds.
func (noopConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores t and always succeeds.
func (noopConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores t and always succeeds.
func (noopConn) SetWriteDeadline(t time.Time) error {
	return nil
}
88
89 type rwTestConn struct {
90 io.Reader
91 io.Writer
92 noopConn
93
94 closeFunc func() error
95 closec chan bool
96 }
97
98 func (c *rwTestConn) Close() error {
99 if c.closeFunc != nil {
100 return c.closeFunc()
101 }
102 select {
103 case c.closec <- true:
104 default:
105 }
106 return nil
107 }
108
109 type testConn struct {
110 readMu sync.Mutex
111 readBuf bytes.Buffer
112 writeBuf bytes.Buffer
113 closec chan bool
114 noopConn
115 }
116
117 func newTestConn() *testConn {
118 return &testConn{closec: make(chan bool, 1)}
119 }
120
121 func (c *testConn) Read(b []byte) (int, error) {
122 c.readMu.Lock()
123 defer c.readMu.Unlock()
124 return c.readBuf.Read(b)
125 }
126
127 func (c *testConn) Write(b []byte) (int, error) {
128 return c.writeBuf.Write(b)
129 }
130
131 func (c *testConn) Close() error {
132 select {
133 case c.closec <- true:
134 default:
135 }
136 return nil
137 }
138
139
140
// reqBytes turns a readable, newline-separated request into wire format:
// surrounding whitespace is trimmed, every "\n" becomes "\r\n", and a
// blank line ("\r\n\r\n") terminates the header block.
func reqBytes(req string) []byte {
	lines := strings.Split(strings.TrimSpace(req), "\n")
	wire := strings.Join(lines, "\r\n") + "\r\n\r\n"
	return []byte(wire)
}
144
145 type handlerTest struct {
146 logbuf bytes.Buffer
147 handler Handler
148 }
149
150 func newHandlerTest(h Handler) handlerTest {
151 return handlerTest{handler: h}
152 }
153
154 func (ht *handlerTest) rawResponse(req string) string {
155 reqb := reqBytes(req)
156 var output strings.Builder
157 conn := &rwTestConn{
158 Reader: bytes.NewReader(reqb),
159 Writer: &output,
160 closec: make(chan bool, 1),
161 }
162 ln := &oneConnListener{conn: conn}
163 srv := &Server{
164 ErrorLog: log.New(&ht.logbuf, "", 0),
165 Handler: ht.handler,
166 }
167 go srv.Serve(ln)
168 <-conn.closec
169 return output.String()
170 }
171
172 func TestConsumingBodyOnNextConn(t *testing.T) {
173 t.Parallel()
174 defer afterTest(t)
175 conn := new(testConn)
176 for i := 0; i < 2; i++ {
177 conn.readBuf.Write([]byte(
178 "POST / HTTP/1.1\r\n" +
179 "Host: test\r\n" +
180 "Content-Length: 11\r\n" +
181 "\r\n" +
182 "foo=1&bar=1"))
183 }
184
185 reqNum := 0
186 ch := make(chan *Request)
187 servech := make(chan error)
188 listener := &oneConnListener{conn}
189 handler := func(res ResponseWriter, req *Request) {
190 reqNum++
191 ch <- req
192 }
193
194 go func() {
195 servech <- Serve(listener, HandlerFunc(handler))
196 }()
197
198 var req *Request
199 req = <-ch
200 if req == nil {
201 t.Fatal("Got nil first request.")
202 }
203 if req.Method != "POST" {
204 t.Errorf("For request #1's method, got %q; expected %q",
205 req.Method, "POST")
206 }
207
208 req = <-ch
209 if req == nil {
210 t.Fatal("Got nil first request.")
211 }
212 if req.Method != "POST" {
213 t.Errorf("For request #2's method, got %q; expected %q",
214 req.Method, "POST")
215 }
216
217 if serveerr := <-servech; serveerr != io.EOF {
218 t.Errorf("Serve returned %q; expected EOF", serveerr)
219 }
220 }
221
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

// ServeHTTP sets the Result header to s; the request is ignored.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
227
// handlers are the patterns registered by testHostHandlers; each is served
// by a stringHandler that echoes msg in the Result header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests are the requests sent by testHostHandlers. expected is the
// Result header of a 200 response, or the Location header of a 301.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// A subtree pattern requested without its trailing slash redirects
	// to the slash-terminated path.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
253
254 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
255 func testHostHandlers(t *testing.T, mode testMode) {
256 mux := NewServeMux()
257 for _, h := range handlers {
258 mux.Handle(h.pattern, stringHandler(h.msg))
259 }
260 ts := newClientServerTest(t, mode, mux).ts
261
262 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
263 if err != nil {
264 t.Fatal(err)
265 }
266 defer conn.Close()
267 cc := httputil.NewClientConn(conn, nil)
268 for _, vt := range vtests {
269 var r *Response
270 var req Request
271 if req.URL, err = url.Parse(vt.url); err != nil {
272 t.Errorf("cannot parse url: %v", err)
273 continue
274 }
275 if err := cc.Write(&req); err != nil {
276 t.Errorf("writing request: %v", err)
277 continue
278 }
279 r, err := cc.Read(&req)
280 if err != nil {
281 t.Errorf("reading response: %v", err)
282 continue
283 }
284 switch r.StatusCode {
285 case StatusOK:
286 s := r.Header.Get("Result")
287 if s != vt.expected {
288 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
289 }
290 case StatusMovedPermanently:
291 s := r.Header.Get("Location")
292 if s != vt.expected {
293 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
294 }
295 default:
296 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
297 }
298 }
299 }
300
301 var serveMuxRegister = []struct {
302 pattern string
303 h Handler
304 }{
305 {"/dir/", serve(200)},
306 {"/search", serve(201)},
307 {"codesearch.google.com/search", serve(202)},
308 {"codesearch.google.com/", serve(203)},
309 {"example.com/", HandlerFunc(checkQueryStringHandler)},
310 }
311
312
// serve returns a handler that replies with the given status code and
// writes no body.
func serve(code int) HandlerFunc {
	reply := func(w ResponseWriter, r *Request) {
		w.WriteHeader(code)
	}
	return reply
}
318
319
320
321
// checkQueryStringHandler replies 200 when "http://" followed by the raw
// query string equals the request URL rebuilt with scheme "http", host
// r.Host and no query. Otherwise it replies 500. This lets a test encode
// the URL it expects the server to see inside the query string itself.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	stripped := *r.URL
	stripped.Scheme, stripped.Host, stripped.RawQuery = "http", r.Host, ""
	want := "http://" + r.URL.RawQuery
	if want != stripped.String() {
		w.WriteHeader(500)
		return
	}
	w.WriteHeader(200)
}
333
// serveMuxTests drives TestServeMuxHandler. Each request (method, host,
// path) is looked up on a mux built from serveMuxRegister. code is the
// expected status and pattern is the pattern Handler is expected to
// report (empty for not-found and for path-cleaning redirects).
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// CONNECT requests still get the /dir -> /dir/ redirect, but their
	// paths are not cleaned of "." and ".." segments.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
369
370 func TestServeMuxHandler(t *testing.T) {
371 setParallel(t)
372 mux := NewServeMux()
373 for _, e := range serveMuxRegister {
374 mux.Handle(e.pattern, e.h)
375 }
376
377 for _, tt := range serveMuxTests {
378 r := &Request{
379 Method: tt.method,
380 Host: tt.host,
381 URL: &url.URL{
382 Path: tt.path,
383 },
384 }
385 h, pattern := mux.Handler(r)
386 rr := httptest.NewRecorder()
387 h.ServeHTTP(rr, r)
388 if pattern != tt.pattern || rr.Code != tt.code {
389 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
390 }
391 }
392 }
393
394
395 func TestServeMuxHandlerTrailingSlash(t *testing.T) {
396 setParallel(t)
397 mux := NewServeMux()
398 const original = "/{x}/"
399 mux.Handle(original, NotFoundHandler())
400 r, _ := NewRequest("POST", "/foo", nil)
401 _, p := mux.Handler(r)
402 if p != original {
403 t.Errorf("got %q, want %q", p, original)
404 }
405 }
406
407
408 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
409 setParallel(t)
410 defer func() {
411 if err := recover(); err == nil {
412 t.Error("expected call to mux.HandleFunc to panic")
413 }
414 }()
415 mux := NewServeMux()
416 mux.HandleFunc("/", nil)
417 }
418
// serveMuxTests2 drives TestServeMuxHandlerRedirects. code is the final
// expected status. redirOk marks entries whose first response may be a
// 301 that the test then follows.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
430
431
432
433 func TestServeMuxHandlerRedirects(t *testing.T) {
434 setParallel(t)
435 mux := NewServeMux()
436 for _, e := range serveMuxRegister {
437 mux.Handle(e.pattern, e.h)
438 }
439
440 for _, tt := range serveMuxTests2 {
441 tries := 1
442 turl := tt.url
443 for {
444 u, e := url.Parse(turl)
445 if e != nil {
446 t.Fatal(e)
447 }
448 r := &Request{
449 Method: tt.method,
450 Host: tt.host,
451 URL: u,
452 }
453 h, _ := mux.Handler(r)
454 rr := httptest.NewRecorder()
455 h.ServeHTTP(rr, r)
456 if rr.Code != 301 {
457 if rr.Code != tt.code {
458 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
459 }
460 break
461 }
462 if !tt.redirOk {
463 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
464 break
465 }
466 turl = rr.HeaderMap.Get("Location")
467 tries--
468 }
469 if tries < 0 {
470 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
471 }
472 }
473 }
474
475
476 func TestMuxRedirectLeadingSlashes(t *testing.T) {
477 setParallel(t)
478 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
479 for _, path := range paths {
480 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
481 if err != nil {
482 t.Errorf("%s", err)
483 }
484 mux := NewServeMux()
485 resp := httptest.NewRecorder()
486
487 mux.ServeHTTP(resp, req)
488
489 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
490 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
491 return
492 }
493
494 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
495 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
496 return
497 }
498 }
499 }
500
501
502
503
504
505 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
506 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
507 }
508 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
509 writeBackQuery := func(w ResponseWriter, r *Request) {
510 fmt.Fprintf(w, "%s", r.URL.RawQuery)
511 }
512
513 mux := NewServeMux()
514 mux.HandleFunc("/testOne", writeBackQuery)
515 mux.HandleFunc("/testTwo/", writeBackQuery)
516 mux.HandleFunc("/testThree", writeBackQuery)
517 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
518 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
519 })
520
521 ts := newClientServerTest(t, mode, mux).ts
522
523 tests := [...]struct {
524 path string
525 method string
526 want string
527 statusOk bool
528 }{
529 0: {"/testOne?this=that", "GET", "this=that", true},
530 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
531 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
532 3: {"/testTwo?", "GET", "", true},
533 4: {"/testThree?foo", "GET", "foo", true},
534 5: {"/testThree/?foo", "GET", "foo:bar", true},
535 6: {"/testThree?foo", "CONNECT", "foo", true},
536 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
537
538
539 8: {"/testOne/foo/..?foo", "GET", "foo", true},
540 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
541 }
542
543 for i, tt := range tests {
544 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
545 res, err := ts.Client().Do(req)
546 if err != nil {
547 continue
548 }
549 slurp, _ := io.ReadAll(res.Body)
550 res.Body.Close()
551 if !tt.statusOk {
552 if got, want := res.StatusCode, 404; got != want {
553 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
554 }
555 }
556 if got, want := string(slurp), tt.want; got != want {
557 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
558 }
559 }
560 }
561
562 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
563 setParallel(t)
564
565 mux := NewServeMux()
566 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
567 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
568 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
569 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
570 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
571 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
572
573 tests := []struct {
574 method string
575 url string
576 code int
577 loc string
578 want string
579 }{
580 {"GET", "http://example.com/", 404, "", ""},
581 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
582 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
583 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
584 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
585 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
586 {"CONNECT", "http://example.com/", 404, "", ""},
587 {"CONNECT", "http://example.com:3000/", 404, "", ""},
588 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
589 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
590 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
591 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
592 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
593 }
594
595 for i, tt := range tests {
596 req, _ := NewRequest(tt.method, tt.url, nil)
597 w := httptest.NewRecorder()
598 mux.ServeHTTP(w, req)
599
600 if got, want := w.Code, tt.code; got != want {
601 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
602 }
603
604 if tt.code == 301 {
605 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
606 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
607 }
608 } else {
609 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
610 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
611 }
612 }
613 }
614 }
615
616
617
618
// TestMuxNoSlashRedirectWithTrailingSlash checks that when "/{x}/" is the
// only pattern, a request for "/" gets a 404, not a trailing-slash
// redirect.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)
	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
631
632
633
634
// TestMuxNoSlash405WithTrailingSlash checks that when "GET /{x}/" is the
// only pattern, a GET for "/" gets a 404, not a 405 or a redirect.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)
	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
647
648 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
649 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
650 mux := NewServeMux()
651 newClientServerTest(t, mode, mux)
652 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
653 }
654
655 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
656 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
657 func benchmarkServeMux(b *testing.B, runHandler bool) {
658 type test struct {
659 path string
660 code int
661 req *Request
662 }
663
664
665 var tests []test
666 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
667 for _, e := range endpoints {
668 for i := 200; i < 230; i++ {
669 p := fmt.Sprintf("/%s/%d/", e, i)
670 tests = append(tests, test{
671 path: p,
672 code: i,
673 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
674 })
675 }
676 }
677 mux := NewServeMux()
678 for _, tt := range tests {
679 mux.Handle(tt.path, serve(tt.code))
680 }
681
682 rw := httptest.NewRecorder()
683 b.ReportAllocs()
684 b.ResetTimer()
685 for i := 0; i < b.N; i++ {
686 for _, tt := range tests {
687 *rw = httptest.ResponseRecorder{}
688 h, pattern := mux.Handler(tt.req)
689 if runHandler {
690 h.ServeHTTP(rw, tt.req)
691 if pattern != tt.path || rw.Code != tt.code {
692 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
693 }
694 }
695 }
696 }
697 }
698
699 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
700 func testServerTimeouts(t *testing.T, mode testMode) {
701 runTimeSensitiveTest(t, []time.Duration{
702 10 * time.Millisecond,
703 50 * time.Millisecond,
704 100 * time.Millisecond,
705 500 * time.Millisecond,
706 1 * time.Second,
707 }, func(t *testing.T, timeout time.Duration) error {
708 return testServerTimeoutsWithTimeout(t, timeout, mode)
709 })
710 }
711
712 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
713 var reqNum atomic.Int32
714 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
715 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
716 }), func(ts *httptest.Server) {
717 ts.Config.ReadTimeout = timeout
718 ts.Config.WriteTimeout = timeout
719 })
720 defer cst.close()
721 ts := cst.ts
722
723
724 c := ts.Client()
725 r, err := c.Get(ts.URL)
726 if err != nil {
727 return fmt.Errorf("http Get #1: %v", err)
728 }
729 got, err := io.ReadAll(r.Body)
730 expected := "req=1"
731 if string(got) != expected || err != nil {
732 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
733 string(got), err, expected)
734 }
735
736
737 t1 := time.Now()
738 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
739 if err != nil {
740 return fmt.Errorf("Dial: %v", err)
741 }
742 buf := make([]byte, 1)
743 n, err := conn.Read(buf)
744 conn.Close()
745 latency := time.Since(t1)
746 if n != 0 || err != io.EOF {
747 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
748 }
749 minLatency := timeout / 5 * 4
750 if latency < minLatency {
751 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
752 }
753
754
755
756
757 r, err = c.Get(ts.URL)
758 if err != nil {
759 return fmt.Errorf("http Get #2: %v", err)
760 }
761 got, err = io.ReadAll(r.Body)
762 r.Body.Close()
763 expected = "req=2"
764 if string(got) != expected || err != nil {
765 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
766 }
767
768 if !testing.Short() {
769 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
770 if err != nil {
771 return fmt.Errorf("long Dial: %v", err)
772 }
773 defer conn.Close()
774 go io.Copy(io.Discard, conn)
775 for i := 0; i < 5; i++ {
776 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
777 if err != nil {
778 return fmt.Errorf("on write %d: %v", i, err)
779 }
780 time.Sleep(timeout / 2)
781 }
782 }
783 return nil
784 }
785
786 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
787 func testServerReadTimeout(t *testing.T, mode testMode) {
788 respBody := "response body"
789 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
790 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
791 _, err := io.Copy(io.Discard, req.Body)
792 if !errors.Is(err, os.ErrDeadlineExceeded) {
793 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
794 }
795 res.Write([]byte(respBody))
796 }), func(ts *httptest.Server) {
797 ts.Config.ReadHeaderTimeout = -1
798 ts.Config.ReadTimeout = timeout
799 t.Logf("Server.Config.ReadTimeout = %v", timeout)
800 })
801
802 var retries atomic.Int32
803 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
804 if retries.Add(1) != 1 {
805 return nil, errors.New("too many retries")
806 }
807 return nil, nil
808 }
809
810 pr, pw := io.Pipe()
811 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
812 if err != nil {
813 t.Logf("Get error, retrying: %v", err)
814 cst.close()
815 continue
816 }
817 defer res.Body.Close()
818 got, err := io.ReadAll(res.Body)
819 if string(got) != respBody || err != nil {
820 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
821 }
822 pw.Close()
823 break
824 }
825 }
826
827 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
828 func testServerNoReadTimeout(t *testing.T, mode testMode) {
829 reqBody := "Hello, Gophers!"
830 resBody := "Hi, Gophers!"
831 for _, timeout := range []time.Duration{0, -1} {
832 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
833 ctl := NewResponseController(res)
834 ctl.EnableFullDuplex()
835 res.WriteHeader(StatusOK)
836
837
838 if err := ctl.Flush(); err != nil {
839 t.Errorf("server flush response: %v", err)
840 return
841 }
842 got, err := io.ReadAll(req.Body)
843 if string(got) != reqBody || err != nil {
844 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
845 }
846 res.Write([]byte(resBody))
847 }), func(ts *httptest.Server) {
848 ts.Config.ReadTimeout = timeout
849 t.Logf("Server.Config.ReadTimeout = %d", timeout)
850 })
851
852 pr, pw := io.Pipe()
853 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
854 if err != nil {
855 t.Fatal(err)
856 }
857 defer res.Body.Close()
858
859
860 time.Sleep(10 * time.Millisecond)
861 pw.Write([]byte(reqBody))
862 pw.Close()
863
864 got, err := io.ReadAll(res.Body)
865 if string(got) != resBody || err != nil {
866 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
867 }
868 }
869 }
870
871 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
872 func testServerWriteTimeout(t *testing.T, mode testMode) {
873 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
874 errc := make(chan error, 2)
875 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
876 errc <- nil
877 _, err := io.Copy(res, neverEnding('a'))
878 errc <- err
879 }), func(ts *httptest.Server) {
880 ts.Config.WriteTimeout = timeout
881 t.Logf("Server.Config.WriteTimeout = %v", timeout)
882 })
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901 var retries atomic.Int32
902 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
903 if retries.Add(1) != 1 {
904 return nil, errors.New("too many retries")
905 }
906 return nil, nil
907 }
908
909 res, err := cst.c.Get(cst.ts.URL)
910 if err != nil {
911
912 t.Logf("Get error, retrying: %v", err)
913 cst.close()
914 continue
915 }
916 defer res.Body.Close()
917 _, err = io.Copy(io.Discard, res.Body)
918 if err == nil {
919 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
920 }
921 select {
922 case <-errc:
923 err = <-errc
924 if !errors.Is(err, os.ErrDeadlineExceeded) {
925 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
926 }
927 return
928 default:
929
930 t.Logf("handler didn't run, retrying")
931 cst.close()
932 }
933 }
934 }
935
936 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
937 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
938 for _, timeout := range []time.Duration{0, -1} {
939 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
940 _, err := io.Copy(res, neverEnding('a'))
941 t.Logf("server write response: %v", err)
942 }), func(ts *httptest.Server) {
943 ts.Config.WriteTimeout = timeout
944 t.Logf("Server.Config.WriteTimeout = %d", timeout)
945 })
946
947 res, err := cst.c.Get(cst.ts.URL)
948 if err != nil {
949 t.Fatal(err)
950 }
951 defer res.Body.Close()
952 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
953 if n != 1<<20 || err != nil {
954 t.Errorf("client read response body: %d, %v", n, err)
955 }
956
957
958 res.Body.Close()
959 cst.ts.Config.Shutdown(context.Background())
960 }
961 }
962
963
964 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
965 run(t, testWriteDeadlineExtendedOnNewRequest)
966 }
967 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
968 if testing.Short() {
969 t.Skip("skipping in short mode")
970 }
971 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
972 func(ts *httptest.Server) {
973 ts.Config.WriteTimeout = 250 * time.Millisecond
974 },
975 ).ts
976
977 c := ts.Client()
978
979 for i := 1; i <= 3; i++ {
980 req, err := NewRequest("GET", ts.URL, nil)
981 if err != nil {
982 t.Fatal(err)
983 }
984
985 r, err := c.Do(req)
986 if err != nil {
987 t.Fatalf("http2 Get #%d: %v", i, err)
988 }
989 r.Body.Close()
990 time.Sleep(ts.Config.WriteTimeout / 2)
991 }
992 }
993
994
995
// tryTimeouts calls testFunc with successively longer timeouts (250ms,
// 500ms, 1s) and stops at the first call that returns nil. Each failure
// is logged. If every attempt fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	last := len(tries) - 1
	for i := range tries {
		err := testFunc(tries[i])
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", tries[i], err)
		if i < last {
			t.Logf("retrying at %v ...", tries[i+1])
		}
	}
	t.Fatal("all attempts failed")
}
1010
1011
1012 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1013 if testing.Short() {
1014 t.Skip("skipping in short mode")
1015 }
1016 setParallel(t)
1017 run(t, func(t *testing.T, mode testMode) {
1018 tryTimeouts(t, func(timeout time.Duration) error {
1019 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1020 })
1021 })
1022 }
1023
1024 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1025 firstRequest := make(chan bool, 1)
1026 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1027 select {
1028 case firstRequest <- true:
1029
1030 default:
1031
1032 time.Sleep(timeout)
1033 }
1034 }), func(ts *httptest.Server) {
1035 ts.Config.WriteTimeout = timeout / 2
1036 })
1037 defer cst.close()
1038 ts := cst.ts
1039
1040 c := ts.Client()
1041
1042 req, err := NewRequest("GET", ts.URL, nil)
1043 if err != nil {
1044 return fmt.Errorf("NewRequest: %v", err)
1045 }
1046 r, err := c.Do(req)
1047 if err != nil {
1048 return fmt.Errorf("Get #1: %v", err)
1049 }
1050 r.Body.Close()
1051
1052 req, err = NewRequest("GET", ts.URL, nil)
1053 if err != nil {
1054 return fmt.Errorf("NewRequest: %v", err)
1055 }
1056 r, err = c.Do(req)
1057 if err == nil {
1058 r.Body.Close()
1059 return fmt.Errorf("Get #2 expected error, got nil")
1060 }
1061 if mode == http2Mode {
1062 expected := "stream ID 3; INTERNAL_ERROR"
1063 if !strings.Contains(err.Error(), expected) {
1064 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1065 }
1066 }
1067 return nil
1068 }
1069
1070
1071 func TestNoWriteDeadline(t *testing.T) {
1072 if testing.Short() {
1073 t.Skip("skipping in short mode")
1074 }
1075 setParallel(t)
1076 defer afterTest(t)
1077 run(t, func(t *testing.T, mode testMode) {
1078 tryTimeouts(t, func(timeout time.Duration) error {
1079 return testNoWriteDeadline(t, mode, timeout)
1080 })
1081 })
1082 }
1083
1084 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1085 firstRequest := make(chan bool, 1)
1086 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1087 select {
1088 case firstRequest <- true:
1089
1090 default:
1091
1092 time.Sleep(timeout)
1093 }
1094 }))
1095 defer cst.close()
1096 ts := cst.ts
1097
1098 c := ts.Client()
1099
1100 for i := 0; i < 2; i++ {
1101 req, err := NewRequest("GET", ts.URL, nil)
1102 if err != nil {
1103 return fmt.Errorf("NewRequest: %v", err)
1104 }
1105 r, err := c.Do(req)
1106 if err != nil {
1107 return fmt.Errorf("Get #%d: %v", i, err)
1108 }
1109 r.Body.Close()
1110 }
1111 return nil
1112 }
1113
1114
1115
1116
1117 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1118 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1119 var (
1120 mu sync.RWMutex
1121 conn net.Conn
1122 )
1123 var afterTimeoutErrc = make(chan error, 1)
1124 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1125 buf := make([]byte, 512<<10)
1126 _, err := w.Write(buf)
1127 if err != nil {
1128 t.Errorf("handler Write error: %v", err)
1129 return
1130 }
1131 mu.RLock()
1132 defer mu.RUnlock()
1133 if conn == nil {
1134 t.Error("no established connection found")
1135 return
1136 }
1137 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1138 _, err = w.Write(buf)
1139 afterTimeoutErrc <- err
1140 }), func(ts *httptest.Server) {
1141 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1142 }).ts
1143
1144 c := ts.Client()
1145
1146 err := func() error {
1147 res, err := c.Get(ts.URL)
1148 if err != nil {
1149 return err
1150 }
1151 _, err = io.Copy(io.Discard, res.Body)
1152 res.Body.Close()
1153 return err
1154 }()
1155 if err == nil {
1156 t.Errorf("expected an error copying body from Get request")
1157 }
1158
1159 if err := <-afterTimeoutErrc; err == nil {
1160 t.Error("expected write error after timeout")
1161 }
1162 }
1163
1164
// trackLastConnListener wraps a net.Listener and records each successfully
// accepted connection in *last. The write happens under mu, so readers
// holding mu.RLock see the most recent connection.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // guards *last
	last *net.Conn     // most recently accepted connection
}

// Accept delegates to the wrapped listener and, on success, stores the
// new connection in *l.last before returning it.
func (l trackLastConnListener) Accept() (c net.Conn, err error) {
	c, err = l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1181
1182
1183 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1184 func testIdentityResponse(t *testing.T, mode testMode) {
1185 if mode == http2Mode {
1186 t.Skip("https://go.dev/issue/56019")
1187 }
1188
1189 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1190 rw.Header().Set("Content-Length", "3")
1191 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1192 switch {
1193 case req.FormValue("overwrite") == "1":
1194 _, err := rw.Write([]byte("foo TOO LONG"))
1195 if err != ErrContentLength {
1196 t.Errorf("expected ErrContentLength; got %v", err)
1197 }
1198 case req.FormValue("underwrite") == "1":
1199 rw.Header().Set("Content-Length", "500")
1200 rw.Write([]byte("too short"))
1201 default:
1202 rw.Write([]byte("foo"))
1203 }
1204 })
1205
1206 ts := newClientServerTest(t, mode, handler).ts
1207 c := ts.Client()
1208
1209
1210
1211
1212
1213 for _, te := range []string{"", "identity"} {
1214 url := ts.URL + "/?te=" + te
1215 res, err := c.Get(url)
1216 if err != nil {
1217 t.Fatalf("error with Get of %s: %v", url, err)
1218 }
1219 if cl, expected := res.ContentLength, int64(3); cl != expected {
1220 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1221 }
1222 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1223 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1224 }
1225 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1226 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1227 url, expected, tl, res.TransferEncoding)
1228 }
1229 res.Body.Close()
1230 }
1231
1232
1233 url := ts.URL + "/?overwrite=1"
1234 res, err := c.Get(url)
1235 if err != nil {
1236 t.Fatalf("error with Get of %s: %v", url, err)
1237 }
1238 res.Body.Close()
1239
1240 if mode != http1Mode {
1241 return
1242 }
1243
1244
1245
1246 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1247 if err != nil {
1248 t.Fatalf("error dialing: %v", err)
1249 }
1250 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1251 if err != nil {
1252 t.Fatalf("error writing: %v", err)
1253 }
1254
1255
1256 got, _ := io.ReadAll(conn)
1257 expectedSuffix := "\r\n\r\ntoo short"
1258 if !strings.HasSuffix(string(got), expectedSuffix) {
1259 t.Errorf("Expected output to end with %q; got response body %q",
1260 expectedSuffix, string(got))
1261 }
1262 }
1263
1264 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1265 setParallel(t)
1266 s := newClientServerTest(t, http1Mode, h).ts
1267
1268 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1269 if err != nil {
1270 t.Fatal("dial error:", err)
1271 }
1272 defer conn.Close()
1273
1274 _, err = fmt.Fprint(conn, req)
1275 if err != nil {
1276 t.Fatal("print error:", err)
1277 }
1278
1279 r := bufio.NewReader(conn)
1280 res, err := ReadResponse(r, &Request{Method: "GET"})
1281 if err != nil {
1282 t.Fatal("ReadResponse error:", err)
1283 }
1284
1285 _, err = io.ReadAll(r)
1286 if err != nil {
1287 t.Fatal("read error:", err)
1288 }
1289
1290 if !res.Close {
1291 t.Errorf("Response.Close = false; want true")
1292 }
1293 }
1294
1295 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1296 setParallel(t)
1297 ts := newClientServerTest(t, http1Mode, handler).ts
1298 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1299 if err != nil {
1300 t.Fatal(err)
1301 }
1302 defer conn.Close()
1303 br := bufio.NewReader(conn)
1304 for i := 0; i < 2; i++ {
1305 if _, err := io.WriteString(conn, req); err != nil {
1306 t.Fatal(err)
1307 }
1308 res, err := ReadResponse(br, nil)
1309 if err != nil {
1310 t.Fatalf("res %d: %v", i+1, err)
1311 }
1312 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1313 t.Fatalf("res %d body copy: %v", i+1, err)
1314 }
1315 res.Body.Close()
1316 }
1317 }
1318
1319
1320 func TestServeHTTP10Close(t *testing.T) {
1321 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1322 ServeFile(w, r, "testdata/file")
1323 }))
1324 }
1325
1326
1327 func TestClientCanClose(t *testing.T) {
1328 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1329
1330 }))
1331 }
1332
1333
1334
1335 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1336 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1337 w.Header().Set("Connection", "close")
1338 }))
1339 }
1340
1341 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1342 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1343 w.Header().Set("Connection", "close")
1344 }))
1345 }
1346
1347 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1348 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1349
1350
1351 }))
1352 }
1353
// send204 writes only a 204 No Content status.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 writes only a 304 Not Modified status.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1356
1357
1358 func TestHTTP10KeepAlive204Response(t *testing.T) {
1359 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1360 }
1361
1362 func TestHTTP11KeepAlive204Response(t *testing.T) {
1363 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1364 }
1365
1366 func TestHTTP10KeepAlive304Response(t *testing.T) {
1367 testTCPConnectionStaysOpen(t,
1368 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1369 HandlerFunc(send304))
1370 }
1371
1372
1373 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1374 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1375 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1376 w.(Flusher).Flush()
1377 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1378 }))
1379 type data struct {
1380 Addr string
1381 }
1382 var addrs [2]data
1383 for i := range addrs {
1384 res, err := cst.c.Get(cst.ts.URL)
1385 if err != nil {
1386 t.Fatal(err)
1387 }
1388 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1389 t.Fatal(err)
1390 }
1391 if addrs[i].Addr == "" {
1392 t.Fatal("no address")
1393 }
1394 res.Body.Close()
1395 }
1396 if addrs[0] != addrs[1] {
1397 t.Fatalf("connection not reused")
1398 }
1399 }
1400
1401 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1402 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1403 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1404 fmt.Fprintf(w, "%s", r.RemoteAddr)
1405 }))
1406
1407 res, err := cst.c.Get(cst.ts.URL)
1408 if err != nil {
1409 t.Fatalf("Get error: %v", err)
1410 }
1411 body, err := io.ReadAll(res.Body)
1412 if err != nil {
1413 t.Fatalf("ReadAll error: %v", err)
1414 }
1415 ip := string(body)
1416 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1417 t.Fatalf("Expected local addr; got %q", ip)
1418 }
1419 }
1420
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection is a *blockingRemoteAddrConn. Each wrapped connection is also
// sent on conns before Accept returns, which lets the test control its
// RemoteAddr.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

// Accept wraps the next accepted connection and publishes it on l.conns. If
// conns is unbuffered, this blocks until the wrapped connection is received.
// It then returns the connection. Accept errors are returned unchanged with a
// nil conn.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until an
// address is sent on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr waits for, and returns, the next address sent on c.addrs.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1447
1448
// TestServerAllowsBlockingRemoteAddr checks that the server keeps accepting
// and serving new connections while an earlier connection's RemoteAddr call
// is still blocked.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	// conns is unbuffered, so the listener's Accept cannot return a
	// connection until this test has received it.
	conns := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    conns,
		}
	}).ts

	c := ts.Client()
	// Use a new connection per request so that each fetch gets its own
	// blockingRemoteAddrConn.
	c.Transport.(*Transport).DisableKeepAlives = true

	// fetch GETs ts.URL and sends the body on response, or "" on error.
	fetch := func(num int, response chan<- string) {
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Start a request. Its server-side conn's RemoteAddr blocks until an
	// address is sent on that conn's addrs channel.
	response1c := make(chan string, 1)
	go fetch(1, response1c)

	// Wait for the server to accept the first connection.
	conn1 := <-conns

	// Start a second request while the first is still blocked. Receiving
	// conn2 shows that the accept loop was not held up by conn1.
	response2c := make(chan string, 1)
	go fetch(2, response2c)
	conn2 := <-conns

	// Unblock the second connection first.
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("12.12.12.12"), Port: 12}

	// The second request must complete while conn1 is still blocked.
	response2 := <-response2c
	if g, e := response2, "RA:12.12.12.12:12"; g != e {
		t.Fatalf("response 2 addr = %q; want %q", g, e)
	}

	// Now unblock the first connection.
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("21.21.21.21"), Port: 21}

	// And wait for the first request to finish.
	response1 := <-response1c
	if g, e := response1, "RA:21.21.21.21:21"; g != e {
		t.Fatalf("response 1 addr = %q; want %q", g, e)
	}
}
1516
1517
1518
1519 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1520 func testHeadResponses(t *testing.T, mode testMode) {
1521 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1522 _, err := w.Write([]byte("<html>"))
1523 if err != nil {
1524 t.Errorf("ResponseWriter.Write: %v", err)
1525 }
1526
1527
1528 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1529 if err != nil {
1530 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1531 }
1532 }))
1533 res, err := cst.c.Head(cst.ts.URL)
1534 if err != nil {
1535 t.Error(err)
1536 }
1537 if len(res.TransferEncoding) > 0 {
1538 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1539 }
1540 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1541 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1542 }
1543 if v := res.ContentLength; v != 10 {
1544 t.Errorf("Content-Length: %d; want 10", v)
1545 }
1546 body, err := io.ReadAll(res.Body)
1547 if err != nil {
1548 t.Error(err)
1549 }
1550 if len(body) > 0 {
1551 t.Errorf("got unexpected body %q", string(body))
1552 }
1553 }
1554
1555
1556
1557 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1558 func testHeadReaderFrom(t *testing.T, mode testMode) {
1559
1560 wantBody := strings.Repeat("a", 4096)
1561 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1562 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1563 }))
1564 res, err := cst.c.Head(cst.ts.URL)
1565 if err != nil {
1566 t.Fatal(err)
1567 }
1568 res.Body.Close()
1569 res, err = cst.c.Get(cst.ts.URL)
1570 if err != nil {
1571 t.Fatal(err)
1572 }
1573 gotBody, err := io.ReadAll(res.Body)
1574 res.Body.Close()
1575 if err != nil {
1576 t.Fatal(err)
1577 }
1578 if string(gotBody) != wantBody {
1579 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1580 }
1581 }
1582
1583 func TestTLSHandshakeTimeout(t *testing.T) {
1584 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1585 }
1586 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1587 errLog := new(strings.Builder)
1588 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1589 func(ts *httptest.Server) {
1590 ts.Config.ReadTimeout = 250 * time.Millisecond
1591 ts.Config.ErrorLog = log.New(errLog, "", 0)
1592 },
1593 )
1594 ts := cst.ts
1595
1596 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1597 if err != nil {
1598 t.Fatalf("Dial: %v", err)
1599 }
1600 var buf [1]byte
1601 n, err := conn.Read(buf[:])
1602 if err == nil || n != 0 {
1603 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1604 }
1605 conn.Close()
1606
1607 cst.close()
1608 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1609 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1610 }
1611 }
1612
1613 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1614 func testTLSServer(t *testing.T, mode testMode) {
1615 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1616 if r.TLS != nil {
1617 w.Header().Set("X-TLS-Set", "true")
1618 if r.TLS.HandshakeComplete {
1619 w.Header().Set("X-TLS-HandshakeComplete", "true")
1620 }
1621 }
1622 }), func(ts *httptest.Server) {
1623 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1624 }).ts
1625
1626
1627
1628
1629
1630
1631 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1632 if err != nil {
1633 t.Fatalf("Dial: %v", err)
1634 }
1635 defer idleConn.Close()
1636
1637 if !strings.HasPrefix(ts.URL, "https://") {
1638 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1639 return
1640 }
1641 client := ts.Client()
1642 res, err := client.Get(ts.URL)
1643 if err != nil {
1644 t.Error(err)
1645 return
1646 }
1647 if res == nil {
1648 t.Errorf("got nil Response")
1649 return
1650 }
1651 defer res.Body.Close()
1652 if res.Header.Get("X-TLS-Set") != "true" {
1653 t.Errorf("expected X-TLS-Set response header")
1654 return
1655 }
1656 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1657 t.Errorf("expected X-TLS-HandshakeComplete header")
1658 }
1659 }
1660
// fakeConnectionStateConn is a net.Conn that also has a ConnectionState
// method, without being a *tls.Conn.
type fakeConnectionStateConn struct {
	net.Conn
}

// ConnectionState returns a fixed state in which only ServerName
// ("example.com") is populated.
func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
	var state tls.ConnectionState
	state.ServerName = "example.com"
	return state
}
1670
1671 func TestTLSServerWithoutTLSConn(t *testing.T) {
1672
1673 pr, pw := net.Pipe()
1674 c := make(chan int)
1675 listener := &oneConnListener{&fakeConnectionStateConn{pr}}
1676 server := &Server{
1677 Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
1678 if request.TLS == nil {
1679 t.Fatal("request.TLS is nil, expected not nil")
1680 }
1681 if request.TLS.ServerName != "example.com" {
1682 t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
1683 }
1684 writer.Header().Set("X-TLS-ServerName", "example.com")
1685 }),
1686 }
1687
1688
1689 go func() {
1690 req, _ := NewRequest(MethodGet, "https://example.com", nil)
1691 req.Write(pw)
1692
1693 resp, _ := ReadResponse(bufio.NewReader(pw), req)
1694 if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
1695 t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
1696 }
1697 close(c)
1698 pw.Close()
1699 }()
1700
1701 server.Serve(listener)
1702
1703
1704 <-c
1705 pr.Close()
1706 }
1707
1708 func TestServeTLS(t *testing.T) {
1709 CondSkipHTTP2(t)
1710
1711 defer afterTest(t)
1712 defer SetTestHookServerServe(nil)
1713
1714 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1715 if err != nil {
1716 t.Fatal(err)
1717 }
1718 tlsConf := &tls.Config{
1719 Certificates: []tls.Certificate{cert},
1720 }
1721
1722 ln := newLocalListener(t)
1723 defer ln.Close()
1724 addr := ln.Addr().String()
1725
1726 serving := make(chan bool, 1)
1727 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1728 serving <- true
1729 })
1730 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1731 s := &Server{
1732 Addr: addr,
1733 TLSConfig: tlsConf,
1734 Handler: handler,
1735 }
1736 errc := make(chan error, 1)
1737 go func() { errc <- s.ServeTLS(ln, "", "") }()
1738 select {
1739 case err := <-errc:
1740 t.Fatalf("ServeTLS: %v", err)
1741 case <-serving:
1742 }
1743
1744 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1745 InsecureSkipVerify: true,
1746 NextProtos: []string{"h2", "http/1.1"},
1747 })
1748 if err != nil {
1749 t.Fatal(err)
1750 }
1751 defer c.Close()
1752 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1753 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1754 }
1755 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1756 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1757 }
1758 }
1759
1760
1761 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1762 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1763 }
1764 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1765 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1766 t.Error("unexpected HTTPS request")
1767 }), func(ts *httptest.Server) {
1768 var errBuf bytes.Buffer
1769 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1770 }).ts
1771 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1772 if err != nil {
1773 t.Fatal(err)
1774 }
1775 defer conn.Close()
1776 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1777 slurp, err := io.ReadAll(conn)
1778 if err != nil {
1779 t.Fatal(err)
1780 }
1781 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1782 if !strings.HasPrefix(string(slurp), wantPrefix) {
1783 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1784 }
1785 }
1786
1787
1788 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1789 testAutomaticHTTP2_Serve(t, nil, true)
1790 }
1791
1792 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1793 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1794 }
1795
1796 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1797 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1798 }
1799
1800 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1801 setParallel(t)
1802 defer afterTest(t)
1803 ln := newLocalListener(t)
1804 ln.Close()
1805 var s Server
1806 s.TLSConfig = tlsConf
1807 if err := s.Serve(ln); err == nil {
1808 t.Fatal("expected an error")
1809 }
1810 gotH2 := s.TLSNextProto["h2"] != nil
1811 if gotH2 != wantH2 {
1812 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1813 }
1814 }
1815
1816 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1817 setParallel(t)
1818 defer afterTest(t)
1819 ln := newLocalListener(t)
1820 ln.Close()
1821 var s Server
1822
1823
1824 s.TLSConfig = &tls.Config{
1825 NextProtos: []string{"h2"},
1826 }
1827 if err := s.Serve(ln); err == nil {
1828 t.Fatal("expected an error")
1829 }
1830 on := s.TLSNextProto["h2"] != nil
1831 if !on {
1832 t.Errorf("http2 wasn't automatically enabled")
1833 }
1834 }
1835
1836 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1837 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1838 if err != nil {
1839 t.Fatal(err)
1840 }
1841 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1842 Certificates: []tls.Certificate{cert},
1843 })
1844 }
1845
1846 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1847 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1848 if err != nil {
1849 t.Fatal(err)
1850 }
1851 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1852 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1853 return &cert, nil
1854 },
1855 })
1856 }
1857
1858 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1859 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1860 if err != nil {
1861 t.Fatal(err)
1862 }
1863 conf := &tls.Config{
1864
1865
1866 NextProtos: []string{"h2"},
1867 Certificates: []tls.Certificate{cert},
1868 }
1869 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1870 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1871 return conf, nil
1872 },
1873 })
1874 }
1875
1876 func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
1877 CondSkipHTTP2(t)
1878
1879 defer afterTest(t)
1880 defer SetTestHookServerServe(nil)
1881 var ok bool
1882 var s *Server
1883 const maxTries = 5
1884 var ln net.Listener
1885 Try:
1886 for try := 0; try < maxTries; try++ {
1887 ln = newLocalListener(t)
1888 addr := ln.Addr().String()
1889 ln.Close()
1890 t.Logf("Got %v", addr)
1891 lnc := make(chan net.Listener, 1)
1892 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1893 lnc <- ln
1894 })
1895 s = &Server{
1896 Addr: addr,
1897 TLSConfig: tlsConf,
1898 }
1899 errc := make(chan error, 1)
1900 go func() { errc <- s.ListenAndServeTLS("", "") }()
1901 select {
1902 case err := <-errc:
1903 t.Logf("On try #%v: %v", try+1, err)
1904 continue
1905 case ln = <-lnc:
1906 ok = true
1907 t.Logf("Listening on %v", ln.Addr().String())
1908 break Try
1909 }
1910 }
1911 if !ok {
1912 t.Fatalf("Failed to start up after %d tries", maxTries)
1913 }
1914 defer ln.Close()
1915 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1916 InsecureSkipVerify: true,
1917 NextProtos: []string{"h2", "http/1.1"},
1918 })
1919 if err != nil {
1920 t.Fatal(err)
1921 }
1922 defer c.Close()
1923 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1924 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1925 }
1926 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1927 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1928 }
1929 }
1930
// serverExpectTest describes one raw POST sent by testServerExpect, together
// with the text expected in the first response line.
type serverExpectTest struct {
	contentLength    int    // Content-Length header value and body size (header replaced when chunked)
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the "Expect" request header
	readBody         bool   // sent as ?readbody=; the handler reads the body only if true
	expectedResponse string // substring the first response line must contain
}

// expectTest builds a non-chunked serverExpectTest from its fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
1947
1948 var serverExpectTests = []serverExpectTest{
1949
1950 expectTest(100, "100-continue", true, "100 Continue"),
1951 expectTest(100, "100-cOntInUE", true, "100 Continue"),
1952
1953
1954 expectTest(100, "", true, "200 OK"),
1955
1956
1957
1958 expectTest(100, "100-continue", false, "401 Unauthorized"),
1959
1960 expectTest(100, "", false, "401 Unauthorized"),
1961
1962
1963 expectTest(0, "a-pony", false, "417 Expectation Failed"),
1964
1965
1966 expectTest(0, "100-continue", true, "200 OK"),
1967
1968 expectTest(0, "100-continue", false, "401 Unauthorized"),
1969
1970 {
1971 expectation: "100-continue",
1972 readBody: true,
1973 chunked: true,
1974 expectedResponse: "100 Continue",
1975 },
1976 }
1977
1978
1979
// TestServerExpect exercises the server's handling of the "Expect" request
// header over raw HTTP/1.1 connections, using the cases in serverExpectTests.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Decide from the raw query whether to read the body. r.FormValue is
		// presumably avoided because, for a POST, it would parse (and so
		// read) the body, which must only happen when readbody=true.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends the request described by test on a fresh connection. It
	// checks that the first response line contains test.expectedResponse.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body right after the headers only when acting like a
		// client that did not ask for 100-continue and has a body to send.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred after conn.Close, so it runs first: the writer goroutine
		// is waited for before the connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Write the request concurrently with reading the response below.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Plain body: write straight to conn. io.NopCloser(nil) just
				// provides a Close that returns nil, so both paths share the
				// io.WriteCloser shape.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The server may already have replied and hung up
						// without reading the body, so a write error is
						// acceptable here.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Acceptable: the server can reply and close the connection
				// while the client is still sending a body the handler never
				// reads. Presumably the resulting connection reset can then
				// surface here as a read error before the response is seen.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2075
2076
2077
2078 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2079 setParallel(t)
2080 defer afterTest(t)
2081 conn := new(testConn)
2082 body := strings.Repeat("x", 100<<10)
2083 conn.readBuf.Write([]byte(fmt.Sprintf(
2084 "POST / HTTP/1.1\r\n"+
2085 "Host: test\r\n"+
2086 "Content-Length: %d\r\n"+
2087 "\r\n", len(body))))
2088 conn.readBuf.Write([]byte(body))
2089
2090 done := make(chan bool)
2091
2092 readBufLen := func() int {
2093 conn.readMu.Lock()
2094 defer conn.readMu.Unlock()
2095 return conn.readBuf.Len()
2096 }
2097
2098 ls := &oneConnListener{conn}
2099 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2100 defer close(done)
2101 if bufLen := readBufLen(); bufLen < len(body)/2 {
2102 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2103 }
2104 rw.WriteHeader(200)
2105 rw.(Flusher).Flush()
2106 if g, e := readBufLen(), 0; g != e {
2107 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2108 }
2109 if c := rw.Header().Get("Connection"); c != "" {
2110 t.Errorf(`Connection header = %q; want ""`, c)
2111 }
2112 }))
2113 <-done
2114 }
2115
2116
2117
2118
2119 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2120 setParallel(t)
2121 if testing.Short() && testenv.Builder() == "" {
2122 t.Log("skipping in short mode")
2123 }
2124 conn := new(testConn)
2125 body := strings.Repeat("x", 1<<20)
2126 conn.readBuf.Write([]byte(fmt.Sprintf(
2127 "POST / HTTP/1.1\r\n"+
2128 "Host: test\r\n"+
2129 "Content-Length: %d\r\n"+
2130 "\r\n", len(body))))
2131 conn.readBuf.Write([]byte(body))
2132 conn.closec = make(chan bool, 1)
2133
2134 ls := &oneConnListener{conn}
2135 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2136 if conn.readBuf.Len() < len(body)/2 {
2137 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2138 }
2139 rw.WriteHeader(200)
2140 rw.(Flusher).Flush()
2141 if conn.readBuf.Len() < len(body)/2 {
2142 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2143 }
2144 }))
2145 <-conn.closec
2146
2147 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2148 t.Errorf("Expected a Connection: close header; got response: %s", res)
2149 }
2150 }
2151
// handlerBodyCloseTest describes one testHandlerBodyClose case: the shape of
// the first request, and what the handler's req.Body.Close should lead to.
type handlerBodyCloseTest struct {
	bodySize     int  // bytes in the first request's body
	bodyChunked  bool // chunked Transfer-Encoding instead of Content-Length
	reqConnClose bool // add "Connection: close" (and send no second request)

	wantEOFSearch bool // whether Body.Close is expected to read more input
	wantNextReq   bool // whether the second request is expected to be served
}

// connectionHeader returns the header line for the first request:
// "Connection: close\r\n" when reqConnClose is set, otherwise "".
func (t handlerBodyCloseTest) connectionHeader() string {
	header := ""
	if t.reqConnClose {
		header = "Connection: close\r\n"
	}
	return header
}
2167
// handlerBodyCloseTests lists the cases run by TestHandlerBodyClose. Each
// case's index is printed in failure messages, so entries are keyed by index.
// Bodies are either 20 KB or 1 MB.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// 20 KB body with Content-Length on a keep-alive connection: Close is
	// expected to read the rest of the body, and the next request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Same as case 0, but with a chunked body.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KB body with Content-Length and "Connection: close": Close is
	// expected not to read any further. No second request is sent.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 20 KB chunked body with "Connection: close": unlike case 2, Close is
	// still expected to read further input (the chunked length is not known
	// up front). No second request is sent.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB body with Content-Length on a keep-alive connection: Close is
	// expected not to read any further, and the next request is not expected
	// to be served.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 1 MB chunked body on a keep-alive connection: Close is expected to read
	// some further input, but the next request is not expected to be served.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB chunked body with "Connection: close": Close is expected to read
	// further input. No second request is sent.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB body with Content-Length and "Connection: close": Close is
	// expected not to read any further. No second request is sent.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2252
2253 func TestHandlerBodyClose(t *testing.T) {
2254 setParallel(t)
2255 if testing.Short() && testenv.Builder() == "" {
2256 t.Skip("skipping in -short mode")
2257 }
2258 for i, tt := range handlerBodyCloseTests {
2259 testHandlerBodyClose(t, i, tt)
2260 }
2261 }
2262
2263 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2264 conn := new(testConn)
2265 body := strings.Repeat("x", tt.bodySize)
2266 if tt.bodyChunked {
2267 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2268 "Host: test\r\n" +
2269 tt.connectionHeader() +
2270 "Transfer-Encoding: chunked\r\n" +
2271 "\r\n")
2272 cw := internal.NewChunkedWriter(&conn.readBuf)
2273 io.WriteString(cw, body)
2274 cw.Close()
2275 conn.readBuf.WriteString("\r\n")
2276 } else {
2277 conn.readBuf.Write([]byte(fmt.Sprintf(
2278 "POST / HTTP/1.1\r\n"+
2279 "Host: test\r\n"+
2280 tt.connectionHeader()+
2281 "Content-Length: %d\r\n"+
2282 "\r\n", len(body))))
2283 conn.readBuf.Write([]byte(body))
2284 }
2285 if !tt.reqConnClose {
2286 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2287 }
2288 conn.closec = make(chan bool, 1)
2289
2290 readBufLen := func() int {
2291 conn.readMu.Lock()
2292 defer conn.readMu.Unlock()
2293 return conn.readBuf.Len()
2294 }
2295
2296 ls := &oneConnListener{conn}
2297 var numReqs int
2298 var size0, size1 int
2299 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2300 numReqs++
2301 if numReqs == 1 {
2302 size0 = readBufLen()
2303 req.Body.Close()
2304 size1 = readBufLen()
2305 }
2306 }))
2307 <-conn.closec
2308 if numReqs < 1 || numReqs > 2 {
2309 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2310 }
2311 didSearch := size0 != size1
2312 if didSearch != tt.wantEOFSearch {
2313 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2314 }
2315 if tt.wantNextReq && numReqs != 2 {
2316 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2317 }
2318 }
2319
2320
2321
// testHandlerBodyConsumer is a named strategy for what a handler does with
// its request body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the strategies: ignore the body, close it,
// or read it to EOF and discard the data.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(rc io.ReadCloser) { rc.Close() }},
	{name: "discard", f: func(rc io.ReadCloser) { io.Copy(io.Discard, rc) }},
}
2332
2333 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2334 setParallel(t)
2335 defer afterTest(t)
2336 for _, handler := range testHandlerBodyConsumers {
2337 conn := new(testConn)
2338 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2339 "Host: test\r\n" +
2340 "Transfer-Encoding: chunked\r\n" +
2341 "\r\n" +
2342 "hax\r\n" +
2343 "GET /secret HTTP/1.1\r\n" +
2344 "Host: test\r\n" +
2345 "\r\n")
2346
2347 conn.closec = make(chan bool, 1)
2348 ls := &oneConnListener{conn}
2349 var numReqs int
2350 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2351 numReqs++
2352 if strings.Contains(req.URL.Path, "secret") {
2353 t.Error("Request for /secret encountered, should not have happened.")
2354 }
2355 handler.f(req.Body)
2356 }))
2357 <-conn.closec
2358 if numReqs != 1 {
2359 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2360 }
2361 }
2362 }
2363
2364 func TestInvalidTrailerClosesConnection(t *testing.T) {
2365 setParallel(t)
2366 defer afterTest(t)
2367 for _, handler := range testHandlerBodyConsumers {
2368 conn := new(testConn)
2369 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2370 "Host: test\r\n" +
2371 "Trailer: hack\r\n" +
2372 "Transfer-Encoding: chunked\r\n" +
2373 "\r\n" +
2374 "3\r\n" +
2375 "hax\r\n" +
2376 "0\r\n" +
2377 "I'm not a valid trailer\r\n" +
2378 "GET /secret HTTP/1.1\r\n" +
2379 "Host: test\r\n" +
2380 "\r\n")
2381
2382 conn.closec = make(chan bool, 1)
2383 ln := &oneConnListener{conn}
2384 var numReqs int
2385 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2386 numReqs++
2387 if strings.Contains(req.URL.Path, "secret") {
2388 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2389 }
2390 handler.f(req.Body)
2391 }))
2392 <-conn.closec
2393 if numReqs != 1 {
2394 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2395 }
2396 }
2397 }
2398
2399
2400
2401
// slowTestConn is a scripted connection for server timeout tests. Read serves
// bytes and pauses from script, and the read/write deadlines recorded by the
// Set*Deadline methods make Read and Write fail with syscall.ETIMEDOUT once
// they have passed.
type slowTestConn struct {
	// script is the remaining sequence of read cues: a string is returned as
	// data (possibly across several Reads), a time.Duration is slept through.
	script []any
	// closec receives a value, without blocking, when Close is called.
	closec chan bool

	// mu guards rd and wd; Read also holds it while consuming script.
	mu sync.Mutex
	// rd and wd are the read and write deadlines; the zero Time means none.
	rd, wd time.Time
	// noopConn supplies LocalAddr and RemoteAddr; its deadline setters are
	// shadowed by the methods declared on slowTestConn.
	noopConn
}
2411
2412 func (c *slowTestConn) SetDeadline(t time.Time) error {
2413 c.SetReadDeadline(t)
2414 c.SetWriteDeadline(t)
2415 return nil
2416 }
2417
2418 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2419 c.mu.Lock()
2420 defer c.mu.Unlock()
2421 c.rd = t
2422 return nil
2423 }
2424
2425 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2426 c.mu.Lock()
2427 defer c.mu.Unlock()
2428 c.wd = t
2429 return nil
2430 }
2431
// Read returns the next bytes of the script, first sleeping through any
// time.Duration cues at its head. It returns syscall.ETIMEDOUT once the read
// deadline has passed (including when a pause would run past it), io.EOF when
// the script is exhausted, and panics on a cue of any other type. c.mu is held
// for the entire call, including while sleeping.
func (c *slowTestConn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
restart:
	// Fail immediately if the deadline has already expired.
	if !c.rd.IsZero() && time.Now().After(c.rd) {
		return 0, syscall.ETIMEDOUT
	}
	if len(c.script) == 0 {
		return 0, io.EOF
	}

	switch cue := c.script[0].(type) {
	case time.Duration:
		if !c.rd.IsZero() {
			// The pause would outlast the deadline: sleep only until the
			// deadline, keep the unslept remainder as the head cue, and time out.
			if remaining := time.Until(c.rd); remaining < cue {
				c.script[0] = cue - remaining
				time.Sleep(remaining)
				return 0, syscall.ETIMEDOUT
			}
		}
		c.script = c.script[1:]
		time.Sleep(cue)
		goto restart

	case string:
		n = copy(b, cue)
		// If b was too small, leave the unread tail of the string for the next Read.
		if len(cue) > n {
			c.script[0] = cue[n:]
		} else {
			c.script = c.script[1:]
		}

	default:
		panic("unknown cue in slowTestConn script")
	}

	return
}
2473
2474 func (c *slowTestConn) Close() error {
2475 select {
2476 case c.closec <- true:
2477 default:
2478 }
2479 return nil
2480 }
2481
2482 func (c *slowTestConn) Write(b []byte) (int, error) {
2483 if !c.wd.IsZero() && time.Now().After(c.wd) {
2484 return 0, syscall.ETIMEDOUT
2485 }
2486 return len(b), nil
2487 }
2488
2489 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2490 if testing.Short() {
2491 t.Skip("skipping in -short mode")
2492 }
2493 defer afterTest(t)
2494 for _, handler := range testHandlerBodyConsumers {
2495 conn := &slowTestConn{
2496 script: []any{
2497 "POST /public HTTP/1.1\r\n" +
2498 "Host: test\r\n" +
2499 "Content-Length: 10000\r\n" +
2500 "\r\n",
2501 "foo bar baz",
2502 600 * time.Millisecond,
2503 "GET /secret HTTP/1.1\r\n" +
2504 "Host: test\r\n" +
2505 "\r\n",
2506 },
2507 closec: make(chan bool, 1),
2508 }
2509 ls := &oneConnListener{conn}
2510
2511 var numReqs int
2512 s := Server{
2513 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2514 numReqs++
2515 if strings.Contains(req.URL.Path, "secret") {
2516 t.Error("Request for /secret encountered, should not have happened.")
2517 }
2518 handler.f(req.Body)
2519 }),
2520 ReadTimeout: 400 * time.Millisecond,
2521 }
2522 go s.Serve(ls)
2523 <-conn.closec
2524
2525 if numReqs != 1 {
2526 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2527 }
2528 }
2529 }
2530
2531
// cancelableTimeoutContext wraps a Context and overrides Err so that any error
// from the wrapped context is reported as context.DeadlineExceeded. The
// TimeoutHandler tests below use it to make a handler "time out" on demand:
// they cancel the inner context and expect a 503 response.
type cancelableTimeoutContext struct {
	context.Context
}
2535
2536 func (c cancelableTimeoutContext) Err() error {
2537 if c.Context.Err() != nil {
2538 return context.DeadlineExceeded
2539 }
2540 return nil
2541 }
2542
2543 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2544 func testTimeoutHandler(t *testing.T, mode testMode) {
2545 sendHi := make(chan bool, 1)
2546 writeErrors := make(chan error, 1)
2547 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2548 <-sendHi
2549 _, werr := w.Write([]byte("hi"))
2550 writeErrors <- werr
2551 })
2552 ctx, cancel := context.WithCancel(context.Background())
2553 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2554 cst := newClientServerTest(t, mode, h)
2555
2556
2557 sendHi <- true
2558 res, err := cst.c.Get(cst.ts.URL)
2559 if err != nil {
2560 t.Error(err)
2561 }
2562 if g, e := res.StatusCode, StatusOK; g != e {
2563 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2564 }
2565 body, _ := io.ReadAll(res.Body)
2566 if g, e := string(body), "hi"; g != e {
2567 t.Errorf("got body %q; expected %q", g, e)
2568 }
2569 if g := <-writeErrors; g != nil {
2570 t.Errorf("got unexpected Write error on first request: %v", g)
2571 }
2572
2573
2574 cancel()
2575
2576 res, err = cst.c.Get(cst.ts.URL)
2577 if err != nil {
2578 t.Error(err)
2579 }
2580 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2581 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2582 }
2583 body, _ = io.ReadAll(res.Body)
2584 if !strings.Contains(string(body), "<title>Timeout</title>") {
2585 t.Errorf("expected timeout body; got %q", string(body))
2586 }
2587 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2588 t.Errorf("response content-type = %q; want %q", g, w)
2589 }
2590
2591
2592
2593 sendHi <- true
2594 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2595 t.Errorf("expected Write error of %v; got %v", e, g)
2596 }
2597 }
2598
2599
2600 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2601 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2602 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2603 ms, _ := strconv.Atoi(r.URL.Path[1:])
2604 if ms == 0 {
2605 ms = 1
2606 }
2607 for i := 0; i < ms; i++ {
2608 w.Write([]byte("hi"))
2609 time.Sleep(time.Millisecond)
2610 }
2611 })
2612
2613 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2614
2615 c := ts.Client()
2616
2617 var wg sync.WaitGroup
2618 gate := make(chan bool, 10)
2619 n := 50
2620 if testing.Short() {
2621 n = 10
2622 gate = make(chan bool, 3)
2623 }
2624 for i := 0; i < n; i++ {
2625 gate <- true
2626 wg.Add(1)
2627 go func() {
2628 defer wg.Done()
2629 defer func() { <-gate }()
2630 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2631 if err == nil {
2632 io.Copy(io.Discard, res.Body)
2633 res.Body.Close()
2634 }
2635 }()
2636 }
2637 wg.Wait()
2638 }
2639
2640
2641
2642 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2643 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2644 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2645 w.WriteHeader(204)
2646 })
2647
2648 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2649
2650 var wg sync.WaitGroup
2651 gate := make(chan bool, 50)
2652 n := 500
2653 if testing.Short() {
2654 n = 10
2655 }
2656
2657 c := ts.Client()
2658 for i := 0; i < n; i++ {
2659 gate <- true
2660 wg.Add(1)
2661 go func() {
2662 defer wg.Done()
2663 defer func() { <-gate }()
2664 res, err := c.Get(ts.URL)
2665 if err != nil {
2666
2667
2668 t.Log(err)
2669 return
2670 }
2671 defer res.Body.Close()
2672 io.Copy(io.Discard, res.Body)
2673 }()
2674 }
2675 wg.Wait()
2676 }
2677
2678
2679 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2680 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2681 sendHi := make(chan bool, 1)
2682 writeErrors := make(chan error, 1)
2683 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2684 w.Header().Set("Content-Type", "text/plain")
2685 <-sendHi
2686 _, werr := w.Write([]byte("hi"))
2687 writeErrors <- werr
2688 })
2689 ctx, cancel := context.WithCancel(context.Background())
2690 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2691 cst := newClientServerTest(t, mode, h)
2692
2693
2694 sendHi <- true
2695 res, err := cst.c.Get(cst.ts.URL)
2696 if err != nil {
2697 t.Error(err)
2698 }
2699 if g, e := res.StatusCode, StatusOK; g != e {
2700 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2701 }
2702 body, _ := io.ReadAll(res.Body)
2703 if g, e := string(body), "hi"; g != e {
2704 t.Errorf("got body %q; expected %q", g, e)
2705 }
2706 if g := <-writeErrors; g != nil {
2707 t.Errorf("got unexpected Write error on first request: %v", g)
2708 }
2709
2710
2711 cancel()
2712
2713 res, err = cst.c.Get(cst.ts.URL)
2714 if err != nil {
2715 t.Error(err)
2716 }
2717 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2718 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2719 }
2720 body, _ = io.ReadAll(res.Body)
2721 if !strings.Contains(string(body), "<title>Timeout</title>") {
2722 t.Errorf("expected timeout body; got %q", string(body))
2723 }
2724
2725
2726
2727 sendHi <- true
2728 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2729 t.Errorf("expected Write error of %v; got %v", e, g)
2730 }
2731 }
2732
2733
2734 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2735 run(t, testTimeoutHandlerStartTimerWhenServing)
2736 }
2737 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2738 if testing.Short() {
2739 t.Skip("skipping sleeping test in -short mode")
2740 }
2741 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2742 w.WriteHeader(StatusNoContent)
2743 }
2744 timeout := 300 * time.Millisecond
2745 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2746 defer ts.Close()
2747
2748 c := ts.Client()
2749
2750
2751
2752
2753 time.Sleep(2 * timeout)
2754 res, err := c.Get(ts.URL)
2755 if err != nil {
2756 t.Fatal(err)
2757 }
2758 defer res.Body.Close()
2759 if res.StatusCode != StatusNoContent {
2760 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2761 }
2762 }
2763
2764 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2765 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2766 writeErrors := make(chan error, 1)
2767 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2768 w.Header().Set("Content-Type", "text/plain")
2769 var err error
2770
2771
2772
2773 for i := 0; i < 100; i++ {
2774 _, err = w.Write([]byte("a"))
2775 if err != nil {
2776 break
2777 }
2778 time.Sleep(1 * time.Millisecond)
2779 }
2780 writeErrors <- err
2781 })
2782 ctx, cancel := context.WithCancel(context.Background())
2783 cancel()
2784 h := NewTestTimeoutHandler(sayHi, ctx)
2785 cst := newClientServerTest(t, mode, h)
2786 defer cst.close()
2787
2788 res, err := cst.c.Get(cst.ts.URL)
2789 if err != nil {
2790 t.Error(err)
2791 }
2792 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2793 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2794 }
2795 body, _ := io.ReadAll(res.Body)
2796 if g, e := string(body), ""; g != e {
2797 t.Errorf("got body %q; expected %q", g, e)
2798 }
2799 if g, e := <-writeErrors, context.Canceled; g != e {
2800 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2801 }
2802 }
2803
2804
2805 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2806 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2807 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2808
2809 }
2810 timeout := 300 * time.Millisecond
2811 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2812
2813 c := ts.Client()
2814
2815 res, err := c.Get(ts.URL)
2816 if err != nil {
2817 t.Fatal(err)
2818 }
2819 defer res.Body.Close()
2820 if res.StatusCode != StatusOK {
2821 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2822 }
2823 }
2824
2825
2826 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2827 wrapper := func(h Handler) Handler {
2828 return TimeoutHandler(h, time.Second, "")
2829 }
2830 run(t, func(t *testing.T, mode testMode) {
2831 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2832 }, testNotParallel)
2833 }
2834
// TestRedirectBadPath calls Redirect with an empty target for a request whose
// URL path lacks a leading slash, and checks that the supplied status code is
// still written.
func TestRedirectBadPath(t *testing.T) {
	badURL := &url.URL{
		Scheme: "http",
		Path:   "not-empty-but-no-leading-slash",
	}
	req := &Request{Method: "GET", URL: badURL}

	recorder := httptest.NewRecorder()
	Redirect(recorder, req, "", 304)
	if got := recorder.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2851
2852
// TestRedirect checks the Location header Redirect produces for a range of
// targets, resolved against a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		in   string
		want string
	}{
		// Absolute URLs of any scheme pass through unchanged.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// Scheme-relative URLs and absolute paths are kept as given.
		{"//foobar.com/baz", "//foobar.com/baz"},
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths are resolved against the request path /qux/.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes collapse to one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// URLs embedded in the query string are left alone.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII path characters are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, c := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, c.in, 302)
		if code := rec.Code; code != 302 {
			t.Errorf("Redirect(%q) generated status code %v; want %v", c.in, code, 302)
		}
		if loc := rec.Header().Get("Location"); loc != c.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", c.in, loc, c.want)
		}
	}
}
2897
2898
2899
// TestRedirectContentTypeAndBody checks the Content-Type header and body that
// Redirect writes for various methods:
//   - GET gets an HTML link body.
//   - HEAD gets the HTML content type but no body.
//   - Other methods get neither.
//   - When the caller has already put a Content-Type key in the header map
//     (even an empty or nil value), no body is written.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	tests := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range tests {
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		Redirect(rec, req, "/foo", 302)

		if code := rec.Code; code != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, code, 302)
		}
		if ct := rec.Header().Get("Content-Type"); ct != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, ct, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, tc.wantBody)
		}
	}
}
2943
2944
2945
2946
2947
2948
2949
2950 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2951
2952 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2953 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2954 all, err := io.ReadAll(r.Body)
2955 if err != nil {
2956 t.Fatalf("handler ReadAll: %v", err)
2957 }
2958 if len(all) != 0 {
2959 t.Errorf("handler got %d bytes; expected 0", len(all))
2960 }
2961 rw.Header().Set("Content-Length", "0")
2962 }))
2963
2964 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2965 if err != nil {
2966 t.Fatal(err)
2967 }
2968 req.ContentLength = 0
2969
2970 var resp [5]*Response
2971 for i := range resp {
2972 resp[i], err = cst.c.Do(req)
2973 if err != nil {
2974 t.Fatalf("client post #%d: %v", i, err)
2975 }
2976 }
2977
2978 for i := range resp {
2979 all, err := io.ReadAll(resp[i].Body)
2980 if err != nil {
2981 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2982 }
2983 if len(all) != 0 {
2984 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2985 }
2986 }
2987 }
2988
2989 func TestHandlerPanicNil(t *testing.T) {
2990 run(t, func(t *testing.T, mode testMode) {
2991 testHandlerPanic(t, false, mode, nil, nil)
2992 }, testNotParallel)
2993 }
2994
2995 func TestHandlerPanic(t *testing.T) {
2996 run(t, func(t *testing.T, mode testMode) {
2997 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2998 }, testNotParallel)
2999 }
3000
3001 func TestHandlerPanicWithHijack(t *testing.T) {
3002
3003 run(t, func(t *testing.T, mode testMode) {
3004 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
3005 }, []testMode{http1Mode})
3006 }
3007
// testHandlerPanic serves one request with a handler that panics with
// panicValue, then checks what the server logs.
//
// Options:
//   - withHijack: the handler hijacks the connection first.
//   - wrapper: if non-nil, wraps the panicking handler.
//
// The server's ErrorLog is redirected into a pipe. For a non-nil panicValue
// the test waits until something is written to that log, and reports an error
// if reading the pipe fails. A client-side success is only logged, not treated
// as a failure.
func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
	// Route the server's ErrorLog through a pipe so the goroutine below can
	// observe the first log write.







	pr, pw := io.Pipe()
	defer pw.Close()

	var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
		if withHijack {
			rwc, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Logf("unexpected error: %v", err)
			}
			defer rwc.Close()
		}
		panic(panicValue)
	})
	if wrapper != nil {
		handler = wrapper(handler)
	}
	cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(pw, "", 0)
	})

	// Signal done after the first read from the server's log (or a read error).
	done := make(chan bool, 1)
	go func() {
		buf := make([]byte, 4<<10)
		_, err := pr.Read(buf)
		pr.Close()
		if err != nil && err != io.EOF {
			t.Error(err)
		}
		done <- true
	}()

	_, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		t.Logf("expected an error")
	}

	// With a nil panic value, don't wait for a log line (presumably none is
	// guaranteed in that case).
	if panicValue == nil {
		return
	}

	<-done
}
3060
// terrorWriter is an io.Writer that turns every write into a test failure,
// reporting the written bytes via t.Errorf. It is meant as a log destination
// where any output at all is unexpected.
type terrorWriter struct{ t *testing.T }

// Write fails the test with p as the message and reports p as fully written.
func (w terrorWriter) Write(p []byte) (int, error) {
	w.t.Errorf("%s", p)
	return len(p), nil
}
3067
3068
3069
3070 func TestServerWriteHijackZeroBytes(t *testing.T) {
3071 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3072 }
3073 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3074 done := make(chan struct{})
3075 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3076 defer close(done)
3077 w.(Flusher).Flush()
3078 conn, _, err := w.(Hijacker).Hijack()
3079 if err != nil {
3080 t.Errorf("Hijack: %v", err)
3081 return
3082 }
3083 defer conn.Close()
3084 _, err = w.Write(nil)
3085 if err != ErrHijacked {
3086 t.Errorf("Write error = %v; want ErrHijacked", err)
3087 }
3088 }), func(ts *httptest.Server) {
3089 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3090 }).ts
3091
3092 c := ts.Client()
3093 res, err := c.Get(ts.URL)
3094 if err != nil {
3095 t.Fatal(err)
3096 }
3097 res.Body.Close()
3098 <-done
3099 }
3100
3101 func TestServerNoDate(t *testing.T) {
3102 run(t, func(t *testing.T, mode testMode) {
3103 testServerNoHeader(t, mode, "Date")
3104 })
3105 }
3106
3107 func TestServerContentType(t *testing.T) {
3108 run(t, func(t *testing.T, mode testMode) {
3109 testServerNoHeader(t, mode, "Content-Type")
3110 })
3111 }
3112
3113 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3114 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3115 w.Header()[header] = nil
3116 io.WriteString(w, "<html>foo</html>")
3117 }))
3118 res, err := cst.c.Get(cst.ts.URL)
3119 if err != nil {
3120 t.Fatal(err)
3121 }
3122 res.Body.Close()
3123 if got, ok := res.Header[header]; ok {
3124 t.Fatalf("Expected no %s header; got %q", header, got)
3125 }
3126 }
3127
3128 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3129 func testStripPrefix(t *testing.T, mode testMode) {
3130 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3131 w.Header().Set("X-Path", r.URL.Path)
3132 w.Header().Set("X-RawPath", r.URL.RawPath)
3133 })
3134 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3135
3136 c := ts.Client()
3137
3138 cases := []struct {
3139 reqPath string
3140 path string
3141 rawPath string
3142 }{
3143 {"/foo/bar/qux", "/qux", ""},
3144 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3145 {"/foo%2Fbar/qux", "", ""},
3146 {"/bar", "", ""},
3147 }
3148 for _, tc := range cases {
3149 t.Run(tc.reqPath, func(t *testing.T) {
3150 res, err := c.Get(ts.URL + tc.reqPath)
3151 if err != nil {
3152 t.Fatal(err)
3153 }
3154 res.Body.Close()
3155 if tc.path == "" {
3156 if res.StatusCode != StatusNotFound {
3157 t.Errorf("got %q, want 404 Not Found", res.Status)
3158 }
3159 return
3160 }
3161 if res.StatusCode != StatusOK {
3162 t.Fatalf("got %q, want 200 OK", res.Status)
3163 }
3164 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3165 t.Errorf("got Path %q, want %q", g, w)
3166 }
3167 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3168 t.Errorf("got RawPath %q, want %q", g, w)
3169 }
3170 })
3171 }
3172 }
3173
3174
// TestStripPrefixNotModifyRequest checks that serving through StripPrefix
// leaves the caller's Request URL path untouched.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3183
3184 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3185 func testRequestLimit(t *testing.T, mode testMode) {
3186 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3187 t.Fatalf("didn't expect to get request in Handler")
3188 }), optQuietLog)
3189 req, _ := NewRequest("GET", cst.ts.URL, nil)
3190 var bytesPerHeader = len("header12345: val12345\r\n")
3191 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3192 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3193 }
3194 res, err := cst.c.Do(req)
3195 if res != nil {
3196 defer res.Body.Close()
3197 }
3198 if mode == http2Mode {
3199
3200
3201
3202
3203 if err == nil && res.StatusCode != 431 {
3204 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3205 }
3206 } else {
3207
3208
3209
3210
3211 if err != nil {
3212 t.Fatalf("Do: %v", err)
3213 }
3214 if res.StatusCode != 431 {
3215 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3216 }
3217 }
3218 }
3219
// neverEnding is an io.Reader producing an endless stream of one byte value.
type neverEnding byte

// Read fills all of p with the byte b and never reports an error or EOF.
func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for idx := range p {
		p[idx] = fill
	}
	n = len(p)
	return n, nil
}
3228
// bodyLimitReader is an io.ReadCloser that produces 'a' bytes.
//   - Reads succeed while the running total in count is at most limit; after
//     that every Read fails with "at limit".
//   - Once Close has been called, Read fails with "closed".
//
// mu guards count and the closing of closed.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int
	limit  int
	closed chan struct{}
}

// Read fills p with 'a' and adds len(p) to count, unless the reader is closed
// or count already exceeds limit.
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	select {
	case <-r.closed:
		return 0, errors.New("closed")
	default:
	}
	if r.count > r.limit {
		return 0, errors.New("at limit")
	}

	for idx := range p {
		p[idx] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close marks the reader closed. As with any channel, closing twice panics.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3260
3261 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3262 func testRequestBodyLimit(t *testing.T, mode testMode) {
3263 const limit = 1 << 20
3264 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3265 r.Body = MaxBytesReader(w, r.Body, limit)
3266 n, err := io.Copy(io.Discard, r.Body)
3267 if err == nil {
3268 t.Errorf("expected error from io.Copy")
3269 }
3270 if n != limit {
3271 t.Errorf("io.Copy = %d, want %d", n, limit)
3272 }
3273 mbErr, ok := err.(*MaxBytesError)
3274 if !ok {
3275 t.Errorf("expected MaxBytesError, got %T", err)
3276 }
3277 if mbErr.Limit != limit {
3278 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3279 }
3280 }))
3281
3282 body := &bodyLimitReader{
3283 closed: make(chan struct{}),
3284 limit: limit * 200,
3285 }
3286 req, _ := NewRequest("POST", cst.ts.URL, body)
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297 resp, err := cst.c.Do(req)
3298 if err == nil {
3299 resp.Body.Close()
3300 }
3301
3302
3303 <-body.closed
3304
3305 if body.count > limit*100 {
3306 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3307 limit, body.count)
3308 }
3309 }
3310
3311
3312
3313 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3314 func testClientWriteShutdown(t *testing.T, mode testMode) {
3315 if runtime.GOOS == "plan9" {
3316 t.Skip("skipping test; see https://golang.org/issue/17906")
3317 }
3318 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3319 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3320 if err != nil {
3321 t.Fatalf("Dial: %v", err)
3322 }
3323 err = conn.(*net.TCPConn).CloseWrite()
3324 if err != nil {
3325 t.Fatalf("CloseWrite: %v", err)
3326 }
3327
3328 bs, err := io.ReadAll(conn)
3329 if err != nil {
3330 t.Errorf("ReadAll: %v", err)
3331 }
3332 got := string(bs)
3333 if got != "" {
3334 t.Errorf("read %q from server; want nothing", got)
3335 }
3336 }
3337
3338
3339
3340 func TestServerBufferedChunking(t *testing.T) {
3341 conn := new(testConn)
3342 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3343 conn.closec = make(chan bool, 1)
3344 ls := &oneConnListener{conn}
3345 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3346 rw.(Flusher).Flush()
3347 rw.Write([]byte{'x'})
3348 rw.Write([]byte{'y'})
3349 rw.Write([]byte{'z'})
3350 }))
3351 <-conn.closec
3352 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3353 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3354 conn.writeBuf.Bytes())
3355 }
3356 }
3357
3358
3359
3360
3361
// TestServerGracefulClose runs testServerGracefulClose over HTTP/1 only.
func TestServerGracefulClose(t *testing.T) {
	// Not parallel: each attempt calls SetRSTAvoidanceDelay, which presumably
	// adjusts process-wide server state (TODO confirm).
	run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
}

// testServerGracefulClose has a handler reply 401 without reading a 5MB POST
// body while the client is still writing it. It checks that the client can
// read the whole response up to EOF, without a connection error, and that
// the status line says 401.
//
// runTimeSensitiveTest is given increasing RST-avoidance delays; presumably it
// retries with the next delay when an attempt returns an error (TODO confirm).
func testServerGracefulClose(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		// Build the raw request: headers plus a 5MB body of 'x' bytes.
		const bodySize = 5 << 20
		req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
		for i := 0; i < bodySize; i++ {
			req = append(req, 'x')
		}

		// The handler responds immediately and never reads the body.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, "bye", StatusUnauthorized)
		}))


		defer cst.close()
		ts := cst.ts

		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		// Write the request concurrently; the response is read below while
		// the upload may still be in progress.
		writeErr := make(chan error)
		go func() {
			_, err := conn.Write(req)
			writeErr <- err
		}()
		defer func() {
			conn.Close()
			// Closing conn unblocks a Write that is still in progress; wait
			// for the writer goroutine so it does not outlive this attempt.

			<-writeErr
		}()

		// Read the response line by line until EOF; any other read error
		// fails this attempt.
		br := bufio.NewReader(conn)
		lineNum := 0
		for {
			line, err := br.ReadString('\n')
			if err == io.EOF {
				break
			}
			if err != nil {
				return fmt.Errorf("ReadLine: %v", err)
			}
			lineNum++
			if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
				t.Errorf("Response line = %q; want a 401", line)
			}
		}
		return nil
	})
}
3429
3430 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3431 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3432 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3433 if r.Method != "get" {
3434 t.Errorf(`Got method %q; want "get"`, r.Method)
3435 }
3436 }))
3437 defer cst.close()
3438 req, _ := NewRequest("get", cst.ts.URL, nil)
3439 res, err := cst.c.Do(req)
3440 if err != nil {
3441 t.Error(err)
3442 return
3443 }
3444
3445 res.Body.Close()
3446 }
3447
3448
3449
3450
3451
3452 func TestContentLengthZero(t *testing.T) {
3453 run(t, testContentLengthZero, []testMode{http1Mode})
3454 }
3455 func testContentLengthZero(t *testing.T, mode testMode) {
3456 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3457
3458 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3459 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3460 if err != nil {
3461 t.Fatalf("error dialing: %v", err)
3462 }
3463 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3464 if err != nil {
3465 t.Fatalf("error writing: %v", err)
3466 }
3467 req, _ := NewRequest("GET", "/", nil)
3468 res, err := ReadResponse(bufio.NewReader(conn), req)
3469 if err != nil {
3470 t.Fatalf("error reading response: %v", err)
3471 }
3472 if te := res.TransferEncoding; len(te) > 0 {
3473 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3474 }
3475 if cl := res.ContentLength; cl != 0 {
3476 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3477 }
3478 conn.Close()
3479 }
3480 }
3481
3482 func TestCloseNotifier(t *testing.T) {
3483 run(t, testCloseNotifier, []testMode{http1Mode})
3484 }
3485 func testCloseNotifier(t *testing.T, mode testMode) {
3486 gotReq := make(chan bool, 1)
3487 sawClose := make(chan bool, 1)
3488 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3489 gotReq <- true
3490 cc := rw.(CloseNotifier).CloseNotify()
3491 <-cc
3492 sawClose <- true
3493 })).ts
3494 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3495 if err != nil {
3496 t.Fatalf("error dialing: %v", err)
3497 }
3498 diec := make(chan bool)
3499 go func() {
3500 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3501 if err != nil {
3502 t.Error(err)
3503 return
3504 }
3505 <-diec
3506 conn.Close()
3507 }()
3508 For:
3509 for {
3510 select {
3511 case <-gotReq:
3512 diec <- true
3513 case <-sawClose:
3514 break For
3515 }
3516 }
3517 ts.Close()
3518 }
3519
3520
3521
3522
3523
3524 func TestCloseNotifierPipelined(t *testing.T) {
3525 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3526 }
3527 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3528 gotReq := make(chan bool, 2)
3529 sawClose := make(chan bool, 2)
3530 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3531 gotReq <- true
3532 cc := rw.(CloseNotifier).CloseNotify()
3533 select {
3534 case <-cc:
3535 t.Error("unexpected CloseNotify")
3536 case <-time.After(100 * time.Millisecond):
3537 }
3538 sawClose <- true
3539 })).ts
3540 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3541 if err != nil {
3542 t.Fatalf("error dialing: %v", err)
3543 }
3544 diec := make(chan bool, 1)
3545 defer close(diec)
3546 go func() {
3547 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3548 _, err = io.WriteString(conn, req+req)
3549 if err != nil {
3550 t.Error(err)
3551 return
3552 }
3553 <-diec
3554 conn.Close()
3555 }()
3556 reqs := 0
3557 closes := 0
3558 for {
3559 select {
3560 case <-gotReq:
3561 reqs++
3562 if reqs > 2 {
3563 t.Fatal("too many requests")
3564 }
3565 case <-sawClose:
3566 closes++
3567 if closes > 1 {
3568 return
3569 }
3570 }
3571 }
3572 }
3573
3574 func TestCloseNotifierChanLeak(t *testing.T) {
3575 defer afterTest(t)
3576 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3577 for i := 0; i < 20; i++ {
3578 var output bytes.Buffer
3579 conn := &rwTestConn{
3580 Reader: bytes.NewReader(req),
3581 Writer: &output,
3582 closec: make(chan bool, 1),
3583 }
3584 ln := &oneConnListener{conn: conn}
3585 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3586
3587
3588
3589 _ = rw.(CloseNotifier).CloseNotify()
3590 })
3591 go Serve(ln, handler)
3592 <-conn.closec
3593 }
3594 }
3595
3596
3597
3598
3599
3600
3601
3602
3603
3604
3605 func TestHijackAfterCloseNotifier(t *testing.T) {
3606 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3607 }
3608 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3609 script := make(chan string, 2)
3610 script <- "closenotify"
3611 script <- "hijack"
3612 close(script)
3613 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3614 plan := <-script
3615 switch plan {
3616 default:
3617 panic("bogus plan; too many requests")
3618 case "closenotify":
3619 w.(CloseNotifier).CloseNotify()
3620 w.Header().Set("X-Addr", r.RemoteAddr)
3621 case "hijack":
3622 c, _, err := w.(Hijacker).Hijack()
3623 if err != nil {
3624 t.Errorf("Hijack in Handler: %v", err)
3625 return
3626 }
3627 if _, ok := c.(*net.TCPConn); !ok {
3628
3629
3630 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3631 }
3632 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3633 c.Close()
3634 return
3635 }
3636 })).ts
3637 res1, err := ts.Client().Get(ts.URL)
3638 if err != nil {
3639 log.Fatal(err)
3640 }
3641 res2, err := ts.Client().Get(ts.URL)
3642 if err != nil {
3643 log.Fatal(err)
3644 }
3645 addr1 := res1.Header.Get("X-Addr")
3646 addr2 := res2.Header.Get("X-Addr")
3647 if addr1 == "" || addr1 != addr2 {
3648 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3649 }
3650 }
3651
3652 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3653 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3654 }
3655 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3656 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3657 bodyOkay := make(chan bool, 1)
3658 gotCloseNotify := make(chan bool, 1)
3659 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3660 defer close(bodyOkay)
3661
3662 reqBody := r.Body
3663 r.Body = nil
3664
3665 gone := w.(CloseNotifier).CloseNotify()
3666 slurp, err := io.ReadAll(reqBody)
3667 if err != nil {
3668 t.Errorf("Body read: %v", err)
3669 return
3670 }
3671 if len(slurp) != len(requestBody) {
3672 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3673 return
3674 }
3675 if !bytes.Equal(slurp, requestBody) {
3676 t.Error("Backend read wrong request body.")
3677 return
3678 }
3679 bodyOkay <- true
3680 <-gone
3681 gotCloseNotify <- true
3682 })).ts
3683
3684 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3685 if err != nil {
3686 t.Fatal(err)
3687 }
3688 defer conn.Close()
3689
3690 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3691 len(requestBody), requestBody)
3692 if !<-bodyOkay {
3693
3694 return
3695 }
3696 conn.Close()
3697 <-gotCloseNotify
3698 }
3699
3700 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3701 func testOptions(t *testing.T, mode testMode) {
3702 uric := make(chan string, 2)
3703 mux := NewServeMux()
3704 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3705 uric <- r.RequestURI
3706 })
3707 ts := newClientServerTest(t, mode, mux).ts
3708
3709 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3710 if err != nil {
3711 t.Fatal(err)
3712 }
3713 defer conn.Close()
3714
3715
3716 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3717 if err != nil {
3718 t.Fatal(err)
3719 }
3720 br := bufio.NewReader(conn)
3721 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3722 if err != nil {
3723 t.Fatal(err)
3724 }
3725 if res.StatusCode != 200 {
3726 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3727 }
3728
3729
3730 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3731 if err != nil {
3732 t.Fatal(err)
3733 }
3734 res, err = ReadResponse(br, &Request{Method: "GET"})
3735 if err != nil {
3736 t.Fatal(err)
3737 }
3738 if res.StatusCode != 400 {
3739 t.Errorf("Got non-400 response to GET *: %#v", res)
3740 }
3741
3742 res, err = Get(ts.URL + "/second")
3743 if err != nil {
3744 t.Fatal(err)
3745 }
3746 res.Body.Close()
3747 if got := <-uric; got != "/second" {
3748 t.Errorf("Handler saw request for %q; want /second", got)
3749 }
3750 }
3751
3752 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3753 func testOptionsHandler(t *testing.T, mode testMode) {
3754 rc := make(chan *Request, 1)
3755
3756 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3757 rc <- r
3758 }), func(ts *httptest.Server) {
3759 ts.Config.DisableGeneralOptionsHandler = true
3760 }).ts
3761
3762 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3763 if err != nil {
3764 t.Fatal(err)
3765 }
3766 defer conn.Close()
3767
3768 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3769 if err != nil {
3770 t.Fatal(err)
3771 }
3772
3773 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3774 t.Errorf("Expected OPTIONS * request, got %v", got)
3775 }
3776 }
3777
3778
3779
3780
3781
3782
3783
3784
3785
3786
3787 func TestHeaderToWire(t *testing.T) {
3788 tests := []struct {
3789 name string
3790 handler func(ResponseWriter, *Request)
3791 check func(got, logs string) error
3792 }{
3793 {
3794 name: "write without Header",
3795 handler: func(rw ResponseWriter, r *Request) {
3796 rw.Write([]byte("hello world"))
3797 },
3798 check: func(got, logs string) error {
3799 if !strings.Contains(got, "Content-Length:") {
3800 return errors.New("no content-length")
3801 }
3802 if !strings.Contains(got, "Content-Type: text/plain") {
3803 return errors.New("no content-type")
3804 }
3805 return nil
3806 },
3807 },
3808 {
3809 name: "Header mutation before write",
3810 handler: func(rw ResponseWriter, r *Request) {
3811 h := rw.Header()
3812 h.Set("Content-Type", "some/type")
3813 rw.Write([]byte("hello world"))
3814 h.Set("Too-Late", "bogus")
3815 },
3816 check: func(got, logs string) error {
3817 if !strings.Contains(got, "Content-Length:") {
3818 return errors.New("no content-length")
3819 }
3820 if !strings.Contains(got, "Content-Type: some/type") {
3821 return errors.New("wrong content-type")
3822 }
3823 if strings.Contains(got, "Too-Late") {
3824 return errors.New("don't want too-late header")
3825 }
3826 return nil
3827 },
3828 },
3829 {
3830 name: "write then useless Header mutation",
3831 handler: func(rw ResponseWriter, r *Request) {
3832 rw.Write([]byte("hello world"))
3833 rw.Header().Set("Too-Late", "Write already wrote headers")
3834 },
3835 check: func(got, logs string) error {
3836 if strings.Contains(got, "Too-Late") {
3837 return errors.New("header appeared from after WriteHeader")
3838 }
3839 return nil
3840 },
3841 },
3842 {
3843 name: "flush then write",
3844 handler: func(rw ResponseWriter, r *Request) {
3845 rw.(Flusher).Flush()
3846 rw.Write([]byte("post-flush"))
3847 rw.Header().Set("Too-Late", "Write already wrote headers")
3848 },
3849 check: func(got, logs string) error {
3850 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3851 return errors.New("not chunked")
3852 }
3853 if strings.Contains(got, "Too-Late") {
3854 return errors.New("header appeared from after WriteHeader")
3855 }
3856 return nil
3857 },
3858 },
3859 {
3860 name: "header then flush",
3861 handler: func(rw ResponseWriter, r *Request) {
3862 rw.Header().Set("Content-Type", "some/type")
3863 rw.(Flusher).Flush()
3864 rw.Write([]byte("post-flush"))
3865 rw.Header().Set("Too-Late", "Write already wrote headers")
3866 },
3867 check: func(got, logs string) error {
3868 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3869 return errors.New("not chunked")
3870 }
3871 if strings.Contains(got, "Too-Late") {
3872 return errors.New("header appeared from after WriteHeader")
3873 }
3874 if !strings.Contains(got, "Content-Type: some/type") {
3875 return errors.New("wrong content-type")
3876 }
3877 return nil
3878 },
3879 },
3880 {
3881 name: "sniff-on-first-write content-type",
3882 handler: func(rw ResponseWriter, r *Request) {
3883 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3884 rw.Header().Set("Content-Type", "x/wrong")
3885 },
3886 check: func(got, logs string) error {
3887 if !strings.Contains(got, "Content-Type: text/html") {
3888 return errors.New("wrong content-type; want html")
3889 }
3890 return nil
3891 },
3892 },
3893 {
3894 name: "explicit content-type wins",
3895 handler: func(rw ResponseWriter, r *Request) {
3896 rw.Header().Set("Content-Type", "some/type")
3897 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3898 },
3899 check: func(got, logs string) error {
3900 if !strings.Contains(got, "Content-Type: some/type") {
3901 return errors.New("wrong content-type; want html")
3902 }
3903 return nil
3904 },
3905 },
3906 {
3907 name: "empty handler",
3908 handler: func(rw ResponseWriter, r *Request) {
3909 },
3910 check: func(got, logs string) error {
3911 if !strings.Contains(got, "Content-Length: 0") {
3912 return errors.New("want 0 content-length")
3913 }
3914 return nil
3915 },
3916 },
3917 {
3918 name: "only Header, no write",
3919 handler: func(rw ResponseWriter, r *Request) {
3920 rw.Header().Set("Some-Header", "some-value")
3921 },
3922 check: func(got, logs string) error {
3923 if !strings.Contains(got, "Some-Header") {
3924 return errors.New("didn't get header")
3925 }
3926 return nil
3927 },
3928 },
3929 {
3930 name: "WriteHeader call",
3931 handler: func(rw ResponseWriter, r *Request) {
3932 rw.WriteHeader(404)
3933 rw.Header().Set("Too-Late", "some-value")
3934 },
3935 check: func(got, logs string) error {
3936 if !strings.Contains(got, "404") {
3937 return errors.New("wrong status")
3938 }
3939 if strings.Contains(got, "Too-Late") {
3940 return errors.New("shouldn't have seen Too-Late")
3941 }
3942 return nil
3943 },
3944 },
3945 }
3946 for _, tc := range tests {
3947 ht := newHandlerTest(HandlerFunc(tc.handler))
3948 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3949 logs := ht.logbuf.String()
3950 if err := tc.check(got, logs); err != nil {
3951 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3952 }
3953 }
3954 }
3955
3956 type errorListener struct {
3957 errs []error
3958 }
3959
3960 func (l *errorListener) Accept() (c net.Conn, err error) {
3961 if len(l.errs) == 0 {
3962 return nil, io.EOF
3963 }
3964 err = l.errs[0]
3965 l.errs = l.errs[1:]
3966 return
3967 }
3968
3969 func (l *errorListener) Close() error {
3970 return nil
3971 }
3972
3973 func (l *errorListener) Addr() net.Addr {
3974 return dummyAddr("test-address")
3975 }
3976
3977 func TestAcceptMaxFds(t *testing.T) {
3978 setParallel(t)
3979
3980 ln := &errorListener{[]error{
3981 &net.OpError{
3982 Op: "accept",
3983 Err: syscall.EMFILE,
3984 }}}
3985 server := &Server{
3986 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3987 ErrorLog: log.New(io.Discard, "", 0),
3988 }
3989 err := server.Serve(ln)
3990 if err != io.EOF {
3991 t.Errorf("got error %v, want EOF", err)
3992 }
3993 }
3994
3995 func TestWriteAfterHijack(t *testing.T) {
3996 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3997 var buf strings.Builder
3998 wrotec := make(chan bool, 1)
3999 conn := &rwTestConn{
4000 Reader: bytes.NewReader(req),
4001 Writer: &buf,
4002 closec: make(chan bool, 1),
4003 }
4004 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4005 conn, bufrw, err := rw.(Hijacker).Hijack()
4006 if err != nil {
4007 t.Error(err)
4008 return
4009 }
4010 go func() {
4011 bufrw.Write([]byte("[hijack-to-bufw]"))
4012 bufrw.Flush()
4013 conn.Write([]byte("[hijack-to-conn]"))
4014 conn.Close()
4015 wrotec <- true
4016 }()
4017 })
4018 ln := &oneConnListener{conn: conn}
4019 go Serve(ln, handler)
4020 <-conn.closec
4021 <-wrotec
4022 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4023 t.Errorf("wrote %q; want %q", g, w)
4024 }
4025 }
4026
4027 func TestDoubleHijack(t *testing.T) {
4028 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4029 var buf bytes.Buffer
4030 conn := &rwTestConn{
4031 Reader: bytes.NewReader(req),
4032 Writer: &buf,
4033 closec: make(chan bool, 1),
4034 }
4035 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4036 conn, _, err := rw.(Hijacker).Hijack()
4037 if err != nil {
4038 t.Error(err)
4039 return
4040 }
4041 _, _, err = rw.(Hijacker).Hijack()
4042 if err == nil {
4043 t.Errorf("got err = nil; want err != nil")
4044 }
4045 conn.Close()
4046 })
4047 ln := &oneConnListener{conn: conn}
4048 go Serve(ln, handler)
4049 <-conn.closec
4050 }
4051
4052
4053
4054
4055
4056
4057
4058 func TestHTTP10ConnectionHeader(t *testing.T) {
4059 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4060 }
4061 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4062 mux := NewServeMux()
4063 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4064 ts := newClientServerTest(t, mode, mux).ts
4065
4066
4067 tests := []struct {
4068 req string
4069 expect []string
4070 }{
4071 {
4072 req: "GET / HTTP/1.0\r\n\r\n",
4073 expect: nil,
4074 },
4075 {
4076 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4077 expect: nil,
4078 },
4079 {
4080 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4081 expect: []string{"keep-alive"},
4082 },
4083 }
4084
4085 for _, tt := range tests {
4086 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4087 if err != nil {
4088 t.Fatal("dial err:", err)
4089 }
4090
4091 _, err = fmt.Fprint(conn, tt.req)
4092 if err != nil {
4093 t.Fatal("conn write err:", err)
4094 }
4095
4096 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4097 if err != nil {
4098 t.Fatal("ReadResponse err:", err)
4099 }
4100 conn.Close()
4101 resp.Body.Close()
4102
4103 got := resp.Header["Connection"]
4104 if !slices.Equal(got, tt.expect) {
4105 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4106 }
4107 }
4108 }
4109
4110
4111 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4112 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4113 pr, pw := io.Pipe()
4114 const size = 3 << 20
4115 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4116 rw.Header().Set("Content-Type", "text/plain")
4117 done := make(chan bool)
4118 go func() {
4119 io.Copy(rw, pr)
4120 close(done)
4121 }()
4122 time.Sleep(25 * time.Millisecond)
4123 n, err := io.Copy(io.Discard, req.Body)
4124 if err != nil {
4125 t.Errorf("handler Copy: %v", err)
4126 return
4127 }
4128 if n != size {
4129 t.Errorf("handler Copy = %d; want %d", n, size)
4130 }
4131 pw.Write([]byte("hi"))
4132 pw.Close()
4133 <-done
4134 }))
4135
4136 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4137 if err != nil {
4138 t.Fatal(err)
4139 }
4140 res, err := cst.c.Do(req)
4141 if err != nil {
4142 t.Fatal(err)
4143 }
4144 all, err := io.ReadAll(res.Body)
4145 if err != nil {
4146 t.Fatal(err)
4147 }
4148 res.Body.Close()
4149 if string(all) != "hi" {
4150 t.Errorf("Body = %q; want hi", all)
4151 }
4152 }
4153
4154
4155 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4156 for _, code := range []int{StatusNotModified, StatusNoContent} {
4157 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4158 if r.URL.Path == "/header" {
4159 w.Header().Set("Content-Length", "123")
4160 }
4161 w.WriteHeader(code)
4162 if r.URL.Path == "/more" {
4163 w.Write([]byte("stuff"))
4164 }
4165 }))
4166 for _, req := range []string{
4167 "GET / HTTP/1.0",
4168 "GET /header HTTP/1.0",
4169 "GET /more HTTP/1.0",
4170 "GET / HTTP/1.1\nHost: foo",
4171 "GET /header HTTP/1.1\nHost: foo",
4172 "GET /more HTTP/1.1\nHost: foo",
4173 } {
4174 got := ht.rawResponse(req)
4175 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4176 if !strings.Contains(got, wantStatus) {
4177 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4178 } else if strings.Contains(got, "Content-Length") {
4179 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4180 } else if strings.Contains(got, "stuff") {
4181 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4182 }
4183 }
4184 }
4185 }
4186
4187 func TestContentTypeOkayOn204(t *testing.T) {
4188 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4189 w.Header().Set("Content-Length", "123")
4190 w.Header().Set("Content-Type", "foo/bar")
4191 w.WriteHeader(204)
4192 }))
4193 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4194 if !strings.Contains(got, "Content-Type: foo/bar") {
4195 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4196 }
4197 if strings.Contains(got, "Content-Length: 123") {
4198 t.Errorf("Response = %q; don't want a Content-Length", got)
4199 }
4200 }
4201
4202
4203
4204
4205
4206
4207
// TestTransportAndServerSharedBodyRace sets up client -> proxy -> backend.
// The proxy passes its inbound request body directly as the body of an
// outbound request to the backend. It then reads half of the backend's reply
// and cancels the outbound request. The name indicates it targets data races
// between the Server and the Transport both using that shared body.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// Each attempt uses a different RST-avoidance delay. runTimeSensitiveTest
	// (defined elsewhere) presumably moves on to the next, longer delay when
	// an attempt returns an error. The proxy abandons the backend response
	// half-read, so connection resets are plausible. TODO confirm that is
	// why this test is time-sensitive.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg counts running backend handlers, so the backend is not closed
		// while one is still executing.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Add is called inside the handler, not before the request is
			// sent. NOTE(review): this relies on the handler having started
			// before the deferred wg.Wait below runs. That should hold
			// whenever the proxy got a backend response. Confirm.
			wg.Add(1)
			defer wg.Done()

			// Echo up to bodySize bytes of the request body back as the
			// response. Then stay in the handler until the request context is
			// done, presumably triggered by the proxy's cancellation.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))

		// Close the backend only after all of its handlers have returned.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// The outbound request reuses the inbound req.Body directly.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the echoed body.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			// The backend response body is closed at test cleanup, not here.
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request while it is still in flight. HTTP/2
			// mode uses the Request.Cancel channel; otherwise
			// Transport.CancelRequest is used.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4297
4298
4299
4300
4301 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4302 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4303 }
4304 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4305 if testing.Short() {
4306 t.Skip("skipping in -short mode")
4307 }
4308
4309 readErrCh := make(chan error, 1)
4310 errCh := make(chan error, 2)
4311
4312 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4313 go func(body io.Reader) {
4314 _, err := body.Read(make([]byte, 100))
4315 readErrCh <- err
4316 }(req.Body)
4317 time.Sleep(500 * time.Millisecond)
4318 })).ts
4319
4320 closeConn := make(chan bool)
4321 defer close(closeConn)
4322 go func() {
4323 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4324 if err != nil {
4325 errCh <- err
4326 return
4327 }
4328 defer conn.Close()
4329 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4330 if err != nil {
4331 errCh <- err
4332 return
4333 }
4334
4335
4336 <-closeConn
4337 }()
4338 select {
4339 case err := <-readErrCh:
4340 if err == nil {
4341 t.Error("Read was nil. Expected error.")
4342 }
4343 case err := <-errCh:
4344 t.Error(err)
4345 }
4346 }
4347
4348
4349 func TestResponseWriterWriteString(t *testing.T) {
4350 okc := make(chan bool, 1)
4351 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4352 _, ok := w.(io.StringWriter)
4353 okc <- ok
4354 }))
4355 ht.rawResponse("GET / HTTP/1.0")
4356 select {
4357 case ok := <-okc:
4358 if !ok {
4359 t.Error("ResponseWriter did not implement io.StringWriter")
4360 }
4361 default:
4362 t.Error("handler was never called")
4363 }
4364 }
4365
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }

// testServerConnState issues several request patterns. For each one it checks
// the exact sequence of ConnState transitions the server reports for the
// connection used.
func testServerConnState(t *testing.T, mode testMode) {
	// Per-path behaviors, dispatched by the server handler below.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		// Hijack the connection and write a raw HTTP/1.0 response.
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		// Same as /hijack, then panic after the connection is hijacked.
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog collects the states reported for one connection:
	//   active:   the tracked connection, bound on its StateNew.
	//   got:      states observed so far.
	//   want:     the sequence the current wantLog call expects.
	//   complete: closed once got is at least as long as want, or as soon
	//             as got stops being a prefix of want.
	type stateLog struct {
		active   net.Conn
		got      []ConnState
		want     []ConnState
		complete chan<- struct{}
	}
	// activeLog holds at most one *stateLog and acts as a lock around it.
	// Whoever receives from it owns the log until sending it back. This
	// serializes access between wantLog and the ConnState hook.
	activeLog := make(chan *stateLog, 1)

	// wantLog does the following:
	//   - installs a fresh stateLog expecting want;
	//   - runs doRequests;
	//   - waits for the hook to signal completion;
	//   - reports a mismatch between the recorded and expected sequences.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// sl is not sent back. Any ConnState callback arriving before the
		// next wantLog (or the deferred cleanup below) blocks on <-activeLog
		// until a new stateLog is installed.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		// Discard server log output; /hijack-panic panics on purpose.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			sl := <-activeLog
			// The first StateNew after a fresh log binds the log to this
			// conn. States for any other conn are reported as errors.
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal wantLog once enough states have arrived, or as soon as
			// the sequence diverges from the expected prefix.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	// Install one last empty log so that ConnState callbacks still pending at
	// shutdown can receive a log instead of blocking.
	defer func() {
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet performs a GET with optional header key/value pairs and drains
	// the body. Request and read errors are reported with t.Errorf.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// A keep-alive request, then one whose handler sets Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// A keep-alive request, then one where the client sends Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Hijacked connections end in StateHijacked, with or without a panic.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed without sending anything.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request line.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// A raw keep-alive request; the client closes the idle connection.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4528
4529 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4530 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4531 }
4532 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4533 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4534 }), func(ts *httptest.Server) {
4535 ts.Config.SetKeepAlivesEnabled(false)
4536 }).ts
4537 res, err := ts.Client().Get(ts.URL)
4538 if err != nil {
4539 t.Fatal(err)
4540 }
4541 defer res.Body.Close()
4542 if !res.Close {
4543 t.Errorf("Body.Close == false; want true")
4544 }
4545 }
4546
4547
4548 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4549 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4550 var n int32
4551 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4552 atomic.AddInt32(&n, 1)
4553 }), optQuietLog)
4554 var wg sync.WaitGroup
4555 const reqs = 20
4556 for i := 0; i < reqs; i++ {
4557 wg.Add(1)
4558 go func() {
4559 defer wg.Done()
4560 res, err := cst.c.Get(cst.ts.URL)
4561 if err != nil {
4562
4563
4564 time.Sleep(10 * time.Millisecond)
4565 res, err = cst.c.Get(cst.ts.URL)
4566 if err != nil {
4567 t.Error(err)
4568 return
4569 }
4570 }
4571 defer res.Body.Close()
4572 _, err = io.Copy(io.Discard, res.Body)
4573 if err != nil {
4574 t.Error(err)
4575 return
4576 }
4577 }()
4578 }
4579 wg.Wait()
4580 if got := atomic.LoadInt32(&n); got != reqs {
4581 t.Errorf("handler ran %d times; want %d", got, reqs)
4582 }
4583 }
4584
4585 func TestServerConnStateNew(t *testing.T) {
4586 sawNew := false
4587 srv := &Server{
4588 ConnState: func(c net.Conn, state ConnState) {
4589 if state == StateNew {
4590 sawNew = true
4591 }
4592 },
4593 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4594 }
4595 srv.Serve(&oneConnListener{
4596 conn: &rwTestConn{
4597 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4598 Writer: io.Discard,
4599 },
4600 })
4601 if !sawNew {
4602 t.Error("StateNew not seen")
4603 }
4604 }
4605
4606 type closeWriteTestConn struct {
4607 rwTestConn
4608 didCloseWrite bool
4609 }
4610
4611 func (c *closeWriteTestConn) CloseWrite() error {
4612 c.didCloseWrite = true
4613 return nil
4614 }
4615
4616 func TestCloseWrite(t *testing.T) {
4617 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4618
4619 var srv Server
4620 var testConn closeWriteTestConn
4621 c := ExportServerNewConn(&srv, &testConn)
4622 ExportCloseWriteAndWait(c)
4623 if !testConn.didCloseWrite {
4624 t.Error("didn't see CloseWrite call")
4625 }
4626 }
4627
4628
4629
4630
4631
4632
4633
4634
4635 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4636 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4637 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4638 io.WriteString(w, "Hello, ")
4639 w.(Flusher).Flush()
4640 conn, buf, _ := w.(Hijacker).Hijack()
4641 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4642 if err := buf.Flush(); err != nil {
4643 t.Error(err)
4644 }
4645 if err := conn.Close(); err != nil {
4646 t.Error(err)
4647 }
4648 })).ts
4649 res, err := Get(ts.URL)
4650 if err != nil {
4651 t.Fatal(err)
4652 }
4653 defer res.Body.Close()
4654 all, err := io.ReadAll(res.Body)
4655 if err != nil {
4656 t.Fatal(err)
4657 }
4658 if want := "Hello, world!"; string(all) != want {
4659 t.Errorf("Got %q; want %q", all, want)
4660 }
4661 }
4662
4663
4664
4665
4666
4667
4668
4669 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4670 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4671 }
4672 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4673 if testing.Short() {
4674 t.Skip("skipping in -short mode")
4675 }
4676 const numReq = 3
4677 addrc := make(chan string, numReq)
4678 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4679 addrc <- r.RemoteAddr
4680 time.Sleep(500 * time.Millisecond)
4681 w.(Flusher).Flush()
4682 }), func(ts *httptest.Server) {
4683 ts.Config.WriteTimeout = 250 * time.Millisecond
4684 }).ts
4685
4686 errc := make(chan error, numReq)
4687 go func() {
4688 defer close(errc)
4689 for i := 0; i < numReq; i++ {
4690 res, err := Get(ts.URL)
4691 if res != nil {
4692 res.Body.Close()
4693 }
4694 errc <- err
4695 }
4696 }()
4697
4698 addrSeen := map[string]bool{}
4699 numOkay := 0
4700 for {
4701 select {
4702 case v := <-addrc:
4703 addrSeen[v] = true
4704 case err, ok := <-errc:
4705 if !ok {
4706 if len(addrSeen) != numReq {
4707 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4708 }
4709 if numOkay != 0 {
4710 t.Errorf("got %d successful client requests; want 0", numOkay)
4711 }
4712 return
4713 }
4714 if err == nil {
4715 numOkay++
4716 }
4717 }
4718 }
4719 }
4720
4721
4722
// TestNoContentLengthIfTransferEncoding checks that when a handler sets its own
// Transfer-Encoding header, the response carries neither a Content-Length nor
// a Content-Type header.
func TestNoContentLengthIfTransferEncoding(t *testing.T) {
	run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
}
func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "foo")
		io.WriteString(w, "<html>")
	})
	ts := newClientServerTest(t, mode, handler).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()

	if _, err := io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
		t.Fatal(err)
	}

	// Gather the header lines, stopping at the blank line that ends them.
	var headers strings.Builder
	scanner := bufio.NewScanner(conn)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.TrimSpace(line) == "" {
			break
		}
		headers.WriteString(line)
		headers.WriteByte('\n')
	}
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}

	raw := headers.String()
	if strings.Contains(raw, "Content-Length") {
		t.Errorf("Unexpected Content-Length in response headers: %s", raw)
	}
	if strings.Contains(raw, "Content-Type") {
		t.Errorf("Unexpected Content-Type in response headers: %s", raw)
	}
}
4758
4759
4760
// TestTolerateCRLFBeforeRequestLine checks that stray CRLF pairs between two
// pipelined requests are skipped, so both requests reach the handler.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	const (
		first     = "POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC"
		separator = "\r\n\r\n"
		second    = "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n"
	)
	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader([]byte(first + separator + second)),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	handled := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(ResponseWriter, *Request) {
		handled++
	}))
	<-conn.closec
	if handled != 2 {
		t.Errorf("num requests = %d; want 2", handled)
		t.Logf("Res: %s", out.Bytes())
	}
}
4782
4783 func TestIssue13893_Expect100(t *testing.T) {
4784
4785 req := reqBytes(`PUT /readbody HTTP/1.1
4786 User-Agent: PycURL/7.22.0
4787 Host: 127.0.0.1:9000
4788 Accept: */*
4789 Expect: 100-continue
4790 Content-Length: 10
4791
4792 HelloWorld
4793
4794 `)
4795 var buf bytes.Buffer
4796 conn := &rwTestConn{
4797 Reader: bytes.NewReader(req),
4798 Writer: &buf,
4799 closec: make(chan bool, 1),
4800 }
4801 ln := &oneConnListener{conn: conn}
4802 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4803 if _, ok := r.Header["Expect"]; !ok {
4804 t.Error("Expect header should not be filtered out")
4805 }
4806 }))
4807 <-conn.closec
4808 }
4809
// TestIssue11549_Expect100 sends two pipelined "Expect: 100-continue" PUTs
// followed by a GET. The handler reads the body of the first PUT only. The
// test expects exactly two handler calls and a "Connection: close" header in
// the output.
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var out strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(w ResponseWriter, r *Request) {
		served++
		// Only /readbody consumes its body; /noreadbody leaves it unread.
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	}))
	<-conn.closec

	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
	}
	if resp := out.String(); !strings.Contains(resp, "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", resp)
	}
}
4852
4853
4854
// TestHandlerFinishSkipBigContentLengthRead checks that when a handler ignores
// a request body with a huge declared Content-Length, the server reads no more
// of the connection's input after the handler returns.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
	setParallel(t)
	conn := newTestConn()
	request := "POST / HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Content-Length: 9999999999\r\n" +
		"\r\n" + strings.Repeat("a", 1<<20)
	conn.readBuf.WriteString(request)

	var unreadDuringHandler int
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Record how much input is still unread while the handler runs.
		unreadDuringHandler = conn.readBuf.Len()
		rw.WriteHeader(404)
	}))
	<-conn.closec

	if unreadAfter := conn.readBuf.Len(); unreadAfter != unreadDuringHandler {
		t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", unreadDuringHandler, unreadAfter)
	}
}
4877
// TestHandlerSetsBodyNil checks that a handler setting r.Body to nil does not
// prevent connection reuse. Two sequential requests should report the same
// client address.
func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = nil
		fmt.Fprintf(w, "%v", r.RemoteAddr)
	}))
	// fetch returns the body of one GET, which is the server-observed client address.
	fetch := func() string {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		defer res.Body.Close()
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		return string(body)
	}
	first := fetch()
	second := fetch()
	if first != second {
		t.Errorf("Failed to reuse connections between requests: %v vs %v", first, second)
	}
}
4901
4902
4903
// TestServerValidatesHostHeader writes one request per table entry to a fake
// connection and checks the status code the server responds with. The cases
// cover Host header presence, syntax, and duplication across protocol versions.
func TestServerValidatesHostHeader(t *testing.T) {
	tests := []struct {
		proto string // protocol version, or a full "METHOD target version" line when it lacks the "HTTP/" prefix
		host  string // raw Host header line(s), each ending in CRLF; empty for none
		want  int    // expected response status code
	}{
		{"HTTP/0.9", "", 505},

		// HTTP/1.1 requires a Host header, and its value must be well-formed.
		{"HTTP/1.1", "", 400},
		{"HTTP/1.1", "Host: \r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: foo.com\r\n", 200},
		{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
		{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
		{"HTTP/1.1", "Host: ::1\r\n", 200},
		{"HTTP/1.1", "Host: [::1]\r\n", 200},
		{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
		{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: \x06\r\n", 400},
		{"HTTP/1.1", "Host: \xff\r\n", 400},
		{"HTTP/1.1", "Host: {\r\n", 400},
		{"HTTP/1.1", "Host: }\r\n", 400},
		{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},

		// HTTP/1.0 accepts a missing Host header, but still rejects
		// duplicates and invalid values.
		{"HTTP/1.0", "", 200},
		{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
		{"HTTP/1.0", "Host: \xff\r\n", 400},

		// The HTTP/2 connection-preface request line needs no Host header.
		{"PRI * HTTP/2.0", "", 200},

		// CONNECT with an authority-form target needs no Host header.
		{"CONNECT golang.org:443 HTTP/1.1", "", 200},

		// Any other HTTP/2 or HTTP/3 request line is an unsupported version.
		{"PRI / HTTP/2.0", "", 505},
		{"GET / HTTP/2.0", "", 505},
		{"GET / HTTP/3.0", "", 505},
	}
	for _, tt := range tests {
		conn := newTestConn()
		methodTarget := "GET / "
		if !strings.HasPrefix(tt.proto, "HTTP/") {
			methodTarget = ""
		}
		io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res carries nothing useful here.
			t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
		}
	}
}
4971
// TestServerHandlersCanHandleH2PRI checks that an HTTP/1 server passes the
// "PRI * HTTP/2.0" preface request line to the handler. The handler can then
// hijack the connection, read the rest of the preface, and reply on the raw
// connection.
func TestServerHandlersCanHandleH2PRI(t *testing.T) {
	run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
}
func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
	const upgradeResponse = "upgrade here"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, br, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		if r.Method != "PRI" || r.RequestURI != "*" {
			t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
			return
		}
		// The PRI request is expected to be marked as closing the connection.
		if !r.Close {
			t.Errorf("Request.Close = false; want true")
		}
		// The rest of the client preface follows the request's blank line
		// and must still be readable from the hijacked buffered reader.
		const want = "SM\r\n\r\n"
		buf := make([]byte, len(want))
		n, err := io.ReadFull(br, buf)
		if err != nil || string(buf[:n]) != want {
			t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
			return
		}
		io.WriteString(conn, upgradeResponse)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); err != nil {
		t.Fatal(err)
	}
	slurp, err := io.ReadAll(c)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != upgradeResponse {
		t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
	}
}
5015
5016
5017
// TestServerValidatesHeaders sends one GET request per table entry, with the
// entry's extra header line(s), and checks the status code the server returns.
func TestServerValidatesHeaders(t *testing.T) {
	setParallel(t)
	tests := []struct {
		header string // raw header line(s) appended after "Host: foo"; empty for none
		want   int    // expected response status code
	}{
		{"", 200},
		{"Foo: bar\r\n", 200},
		{"X-Foo: bar\r\n", 200},
		{"Foo: a space\r\n", 200},

		// Invalid field names are rejected; an oversized header yields 431.
		{"A space: foo\r\n", 400},
		{"foo\xffbar: foo\r\n", 400},
		{"foo\x00bar: foo\r\n", 400},
		{"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},

		// Whitespace between the field name and the colon is rejected.
		{"Foo : bar\r\n", 400},
		{"Foo\t: bar\r\n", 400},

		// An empty field name is rejected.
		{": empty key\r\n", 400},

		// A malformed Content-Length is rejected, even alongside
		// Transfer-Encoding: chunked.
		{"Content-Length: notdigits\r\n", 400},
		{"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},

		// Field values may contain SP, HTAB and bytes >= 0x80, but not NUL or DEL.
		{"foo: foo foo\r\n", 200},
		{"foo: foo\tfoo\r\n", 200},
		{"foo: foo\x00foo\r\n", 400},
		{"foo: foo\x7ffoo\r\n", 400},
		{"foo: foo\xfffoo\r\n", 200},
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res carries nothing useful here.
			t.Errorf("For %q, ReadResponse: %v", tt.header, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
		}
	}
}
5075
// TestServerRequestContextCancel_ServeHTTPDone checks that a request's context
// is live while ServeHTTP runs and is canceled once ServeHTTP has returned.
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
	run(t, testServerRequestContextCancel_ServeHTTPDone)
}
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
	// isDone reports, without blocking, whether ctx's Done channel is closed.
	isDone := func(ctx context.Context) bool {
		select {
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}
	ctxc := make(chan context.Context, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ctx := r.Context()
		if isDone(ctx) {
			t.Error("should not be Done in ServeHTTP")
		}
		ctxc <- ctx
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if !isDone(<-ctxc) {
		t.Error("context should be done after ServeHTTP completes")
	}
}
5102
5103
5104
5105
5106
// TestServerRequestContextCancel_ConnClose checks that closing the client
// connection cancels the context of the request being handled on it.
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
	run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
}
func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
	entered := make(chan struct{})
	finished := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		close(entered)
		<-r.Context().Done()
		close(finished)
	})
	ts := newClientServerTest(t, mode, handler).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")

	<-entered
	// Drop the connection while the handler is blocked on its context.
	conn.Close()
	<-finished
}
5128
// TestServerContext_ServerContextKey checks that the request context holds the
// serving *Server under ServerContextKey.
func TestServerContext_ServerContextKey(t *testing.T) {
	run(t, testServerContext_ServerContextKey)
}
func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		v := r.Context().Value(ServerContextKey)
		if _, isServer := v.(*Server); !isServer {
			t.Errorf("context value = %T; want *http.Server", v)
		}
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
5146
// TestServerContext_LocalAddrContextKey checks that the request context holds,
// under LocalAddrContextKey, a net.Addr equal to the listener's address.
func TestServerContext_LocalAddrContextKey(t *testing.T) {
	run(t, testServerContext_LocalAddrContextKey)
}
func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
	ch := make(chan any, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ch <- r.Context().Value(LocalAddrContextKey)
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Close the response body, as the Client API requires, even though a
	// HEAD response has no content.
	res.Body.Close()

	host := cst.ts.Listener.Addr().String()
	got := <-ch
	if addr, ok := got.(net.Addr); !ok {
		t.Errorf("local addr value = %T; want net.Addr", got)
	} else if fmt.Sprint(addr) != host {
		t.Errorf("local addr = %v; want %v", addr, host)
	}
}
5167
5168
// TestHandlerSetTransferEncodingChunked checks that a handler explicitly
// setting "Transfer-Encoding: chunked" gets exactly one such header line in
// the raw response.
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Write([]byte("hello"))
	})
	resp := newHandlerTest(handler).rawResponse("GET / HTTP/1.1\nHost: foo")

	const hdr = "Transfer-Encoding: chunked"
	count := strings.Count(resp, hdr)
	if count != 1 {
		t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, count, resp)
	}
}
5182
5183
// TestHandlerSetTransferEncodingGzip checks that when a handler sets
// "Transfer-Encoding: gzip", the raw response has exactly one gzip header line
// and exactly one chunked header line.
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		zw.Write([]byte("hello"))
		zw.Close()
	})
	resp := newHandlerTest(handler).rawResponse("GET / HTTP/1.1\nHost: foo")

	for _, hdr := range []string{"Transfer-Encoding: gzip", "Transfer-Encoding: chunked"} {
		if n := strings.Count(resp, hdr); n != 1 {
			t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
		}
	}
}
5201
// BenchmarkClientServer measures sequential GET round trips against a test
// server in HTTP/1, HTTPS/1 and HTTP/2 modes.
func BenchmarkClientServer(b *testing.B) {
	run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
}
func benchmarkClientServer(b *testing.B, mode testMode) {
	b.ReportAllocs()

	// Keep server setup out of the measured time.
	b.StopTimer()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.StartTimer()

	client := ts.Client()
	for n := 0; n < b.N; n++ {
		res, err := client.Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		all, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if body := string(all); body != "Hello world.\n" {
			b.Fatal("Got body:", body)
		}
	}

	b.StopTimer()
}
5232
// BenchmarkClientServerParallel measures concurrent GET round trips at two
// parallelism levels (4 and 64) in HTTP/1, HTTPS/1 and HTTP/2 modes.
func BenchmarkClientServerParallel(b *testing.B) {
	for _, parallelism := range []int{4, 64} {
		b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
			perMode := func(b *testing.B, mode testMode) {
				benchmarkClientServerParallel(b, parallelism, mode)
			}
			run(b, perMode, []testMode{http1Mode, https1Mode, http2Mode})
		})
	}
}

// benchmarkClientServerParallel runs GETs from RunParallel goroutines, with
// parallelism as the per-GOMAXPROCS multiplier passed to SetParallelism.
// Request and read errors are logged and skipped rather than failing the run;
// a wrong body panics.
func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
	b.ReportAllocs()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.ResetTimer()
	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			res, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if body := string(all); body != "Hello world.\n" {
				panic("Got body: " + body)
			}
		}
	})
}
5271
5272
5273
5274
5275
5276
5277
5278
5279
5280
5281 func BenchmarkServer(b *testing.B) {
5282 b.ReportAllocs()
5283
5284 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5285 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5286 if err != nil {
5287 panic(err)
5288 }
5289 for i := 0; i < n; i++ {
5290 res, err := Get(url)
5291 if err != nil {
5292 log.Panicf("Get: %v", err)
5293 }
5294 all, err := io.ReadAll(res.Body)
5295 res.Body.Close()
5296 if err != nil {
5297 log.Panicf("ReadAll: %v", err)
5298 }
5299 body := string(all)
5300 if body != "Hello world.\n" {
5301 log.Panicf("Got body: %q", body)
5302 }
5303 }
5304 os.Exit(0)
5305 return
5306 }
5307
5308 var res = []byte("Hello world.\n")
5309 b.StopTimer()
5310 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5311 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5312 rw.Write(res)
5313 }))
5314 defer ts.Close()
5315 b.StartTimer()
5316
5317 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5318 cmd.Env = append([]string{
5319 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5320 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5321 }, os.Environ()...)
5322 out, err := cmd.CombinedOutput()
5323 if err != nil {
5324 b.Errorf("Test failure: %v, with output: %s", err, out)
5325 }
5326 }
5327
5328
// getNoBody issues a GET to urlStr using the default client and closes the
// response body without reading it. On error it returns a nil *Response.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5337
5338
5339
// BenchmarkClient measures client GET performance against a server that runs
// in a child process. The same binary plays both roles: with TEST_BENCH_SERVER
// set, it is the server child; otherwise it starts that child and sends b.N
// requests to it.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child (server) mode: listen on localhost (TEST_BENCH_SERVER_PORT, or
		// port 0 when unset), print the address to stdout for the parent, and
		// serve until a request carries a non-empty "stop" form value.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent (client) mode: re-run this benchmark binary as the server child.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// done receives the child's exit status and is then closed. If the status
	// has already been consumed, the deferred receive below does not block.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()
	// On every exit path, cancel ctx and wait for the child to finish.
	// (testenv.CommandContext presumably kills the child on cancel; confirm.)
	defer func() {
		cancel()
		<-done
	}()

	// The child prints its listen address as the first line of stdout. Probe
	// it once before timing starts.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Only the request loop is timed.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. Its handler calls os.Exit(0) for ?stop=..., so
	// this request's result is deliberately ignored.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5428
// BenchmarkServerFakeConnNoKeepAlive measures serving one HTTP/1.0 request per
// connection. A single in-memory testConn is reset and reused each iteration.
func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.0
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	body := []byte("Hello world!\n")

	conn := newTestConn()
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(body)
	})
	ln := new(oneConnListener)
	for iter := 0; iter < b.N; iter++ {
		// Re-arm the fake connection with a fresh copy of the request.
		conn.readBuf.Reset()
		conn.writeBuf.Reset()
		conn.readBuf.Write(req)
		ln.conn = conn

		Serve(ln, handler)
		<-conn.closec
	}
}
5456
5457
5458 type repeatReader struct {
5459 content []byte
5460 count int
5461 off int
5462 }
5463
5464 func (r *repeatReader) Read(p []byte) (n int, err error) {
5465 if r.count <= 0 {
5466 return 0, io.EOF
5467 }
5468 n = copy(p, r.content[r.off:])
5469 r.off += n
5470 if r.off == len(r.content) {
5471 r.count--
5472 r.off = 0
5473 }
5474 return
5475 }
5476
5477 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5478 b.ReportAllocs()
5479
5480 req := reqBytes(`GET / HTTP/1.1
5481 Host: golang.org
5482 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5483 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5484 Accept-Encoding: gzip,deflate,sdch
5485 Accept-Language: en-US,en;q=0.8
5486 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5487 `)
5488 res := []byte("Hello world!\n")
5489
5490 conn := &rwTestConn{
5491 Reader: &repeatReader{content: req, count: b.N},
5492 Writer: io.Discard,
5493 closec: make(chan bool, 1),
5494 }
5495 handled := 0
5496 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5497 handled++
5498 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5499 rw.Write(res)
5500 })
5501 ln := &oneConnListener{conn: conn}
5502 go Serve(ln, handler)
5503 <-conn.closec
5504 if b.N != handled {
5505 b.Errorf("b.N=%d but handled %d", b.N, handled)
5506 }
5507 }
5508
5509
5510
5511 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5512 b.ReportAllocs()
5513
5514 req := reqBytes(`GET / HTTP/1.1
5515 Host: golang.org
5516 `)
5517 res := []byte("Hello world!\n")
5518
5519 conn := &rwTestConn{
5520 Reader: &repeatReader{content: req, count: b.N},
5521 Writer: io.Discard,
5522 closec: make(chan bool, 1),
5523 }
5524 handled := 0
5525 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5526 handled++
5527 rw.Write(res)
5528 })
5529 ln := &oneConnListener{conn: conn}
5530 go Serve(ln, handler)
5531 <-conn.closec
5532 if b.N != handled {
5533 b.Errorf("b.N=%d but handled %d", b.N, handled)
5534 }
5535 }
5536
5537 const someResponse = "<html>some response</html>"
5538
5539
5540 var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5541
5542
// BenchmarkServerHandlerTypeLen serves response with both Content-Type and
// Content-Length set by the handler.
func BenchmarkServerHandlerTypeLen(b *testing.B) {
	benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "text/html")
		hdr.Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}))
}

// BenchmarkServerHandlerNoLen serves response with only Content-Type set.
func BenchmarkServerHandlerNoLen(b *testing.B) {
	benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "text/html")
		w.Write(response)
	}))
}

// BenchmarkServerHandlerNoType serves response with only Content-Length set.
func BenchmarkServerHandlerNoType(b *testing.B) {
	benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}))
}

// BenchmarkServerHandlerNoHeader serves response without setting any header.
func BenchmarkServerHandlerNoHeader(b *testing.B) {
	benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.Write(response)
	}))
}
5573
5574 func benchmarkHandler(b *testing.B, h Handler) {
5575 b.ReportAllocs()
5576 req := reqBytes(`GET / HTTP/1.1
5577 Host: golang.org
5578 `)
5579 conn := &rwTestConn{
5580 Reader: &repeatReader{content: req, count: b.N},
5581 Writer: io.Discard,
5582 closec: make(chan bool, 1),
5583 }
5584 handled := 0
5585 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5586 handled++
5587 h.ServeHTTP(rw, r)
5588 })
5589 ln := &oneConnListener{conn: conn}
5590 go Serve(ln, handler)
5591 <-conn.closec
5592 if b.N != handled {
5593 b.Errorf("b.N=%d but handled %d", b.N, handled)
5594 }
5595 }
5596
// BenchmarkServerHijack measures serving a request whose handler immediately
// hijacks the connection and closes it. One fake connection is re-armed with
// the request on each iteration.
func BenchmarkServerHijack(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	hijackAndClose := HandlerFunc(func(w ResponseWriter, r *Request) {
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			panic(err)
		}
		c.Close()
	})
	conn := &rwTestConn{
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	for iter := 0; iter < b.N; iter++ {
		conn.Reader = bytes.NewReader(req)
		ln.conn = conn
		Serve(ln, hijackAndClose)
		<-conn.closec
	}
}
5621
// BenchmarkCloseNotifier measures how quickly a handler blocked on the
// (deprecated) CloseNotifier API observes the client closing its connection.
func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
func benchmarkCloseNotifier(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	sawClose := make(chan bool)
	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		notify := rw.(CloseNotifier).CloseNotify()
		<-notify
		sawClose <- true
	})
	ts := newClientServerTest(b, mode, handler).ts
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		conn.Close()
		<-sawClose
	}
	b.StopTimer()
}
5646
5647
5648 func TestConcurrentServerServe(t *testing.T) {
5649 setParallel(t)
5650 for i := 0; i < 100; i++ {
5651 ln1 := &oneConnListener{conn: nil}
5652 ln2 := &oneConnListener{conn: nil}
5653 srv := Server{}
5654 go func() { srv.Serve(ln1) }()
5655 go func() { srv.Serve(ln2) }()
5656 }
5657 }
5658
// TestServerIdleTimeout checks Server.IdleTimeout and ReadHeaderTimeout:
// requests within the idle window share a connection, a request after it uses
// a new connection, and a connection that sends an incomplete header is closed.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	// NOTE(review): runTimeSensitiveTest presumably retries with each longer
	// duration when the function returns an error; confirm against its definition.
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler echoes the client address, which identifies the connection.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// Unlike the Get error above, this is not returned: once the
				// headers have been read, the connection is no longer idle,
				// so neither timeout should apply and a failure is unexpected.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Stay idle past IdleTimeout; the next request should use a new connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send an incomplete request header (no terminating blank line). After
		// ReadHeaderTimeout the server should have closed the connection, so
		// reading a byte must fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5734
// get fetches url with client c and returns the whole response body as a
// string. Any request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	body, readErr := io.ReadAll(res.Body)
	if readErr != nil {
		t.Fatal(readErr)
	}
	return string(body)
}
5747
5748
5749
// TestServerSetKeepAlivesEnabledClosesConns checks that disabling keep-alives
// on a running server closes its idle connections. The effect is observed
// through the client Transport's pool of idle connections.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
}
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	// The handler echoes the client address so the test can tell whether two
	// requests shared a connection.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Inside the literal, get refers to the package-level helper, since the
	// local get is not yet in scope there.
	get := func() string { return get(t, c, ts.URL) }

	a1, a2 := get(), get()
	if a1 == a2 {
		t.Logf("made two requests from a single conn %q (as expected)", a1)
	} else {
		t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
	}

	// Both requests completed, so the client should now hold exactly one
	// idle connection.
	if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
	}

	// Disabling keep-alives should make the server close that idle
	// connection. Poll until the client's idle pool is empty.
	ts.Config.SetKeepAlivesEnabled(false)

	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
			if d > 0 {
				t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
			}
			return false
		}
		return true
	})




}
5795
// TestServerShutdown checks Server.Shutdown while a request is in flight. The
// OnShutdown hook runs, new requests are refused, the in-flight request still
// completes, and Shutdown returns nil.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
func testServerShutdown(t *testing.T, mode testMode) {
	// Declared before the handler, which refers to it; assigned further below.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// On the first request only: snapshot connection states and start
		// Shutdown in the background while this request is still active.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until the OnShutdown hook (registered below) has run. The
			// listener is expected to be closed by then.
			<-gotOnShutdown

			// New requests should now fail. Loop until one does; any success
			// is an error, except in HTTP/2 mode (see go.dev/issue/59038).
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// This in-flight request must complete even though Shutdown starts during it.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// When Shutdown began, exactly one connection should have been active.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5856
// TestServerShutdownStateNew checks that Shutdown does not close a connection
// that has not sent a request right away, but closes it and returns once
// about five seconds have passed. The test runs under synctest.
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// No request is sent on the test connection, so this handler is not
		// expected to run.
	}), func(ts *httptest.Server) {
		// Serve on the fake listener instead of the default one.
		ts.Listener.Close()
		ts.Listener = listener
		// Discard server error logs.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Open a connection but never send a request on it.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// Shutdown is expected to leave the request-less connection open until
	// expectTimeout has elapsed, then close it and return.
	const expectTimeout = 5 * time.Second

	// One nanosecond before the deadline, Shutdown must still be pending and
	// the connection must still be open.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Well past the deadline, Shutdown must have returned successfully and
	// the server must have closed the connection. (time.Sleep is presumably
	// advancing synctest's fake clock here; confirm with runSynctest.)
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
5912
5913
5914 func TestServerCloseDeadlock(t *testing.T) {
5915 var s Server
5916 s.Close()
5917 s.Close()
5918 }
5919
5920
5921
func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }

// testServerKeepAlivesEnabled disables keep-alives on the server and then makes
// two sequential requests. For each one it checks that exactly one connection
// was obtained and that it was neither reused nor taken from the idle pool.
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// ExportSetH2GoawayTimeout is called now; the function it returns
		// (a restore hook) runs when the test returns.
		defer ExportSetH2GoawayTimeout(10 * time.Millisecond)()
	}

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)

	// allIdle reports whether the server considers every connection idle,
	// logging when it is still waiting after a non-zero delay.
	allIdle := func(d time.Duration) bool {
		idle := srv.ExportAllConnsIdle()
		if !idle && d > 0 {
			t.Logf("test server still has active conns after %v", d)
		}
		return idle
	}

	for attempt := 0; attempt < 2; attempt++ {
		waitCondition(t, 10*time.Millisecond, allIdle)

		var (
			gotConns int
			connInfo httptrace.GotConnInfo
		)
		trace := &httptrace.ClientTrace{
			GotConn: func(v httptrace.GotConnInfo) {
				gotConns++
				connInfo = v
			},
		}
		traced := httptrace.WithClientTrace(context.Background(), trace)
		req, err := NewRequestWithContext(traced, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()

		switch {
		case gotConns != 1:
			t.Fatalf("request %v: got %v conns, want 1", attempt, gotConns)
		case connInfo.Reused || connInfo.WasIdle:
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", attempt, connInfo.Reused, connInfo.WasIdle)
		}
	}
}
5968
5969
5970
5971
// TestServerCancelsReadTimeoutWhenIdle checks that Server.ReadTimeout does not
// cut off a request whose handler keeps running past the timeout.
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }

// testServerCancelsReadTimeoutWhenIdle runs a handler that waits 2*timeout and
// then writes "ok". If the request context is canceled first, the handler
// writes the context error instead. The client must receive "ok".
// Errors are returned (not fatal) so runTimeSensitiveTest can retry with a
// longer timeout.
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		}), func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})
		defer cst.close()
		ts := cst.ts

		// The Transport consults Proxy for each attempt at the request. Use it
		// as an attempt counter: a second attempt (a transport-level retry)
		// fails with "too many retries" instead of silently succeeding.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		// NOTE(review): the Proxy hook above is installed on cst.c's
		// Transport; this assumes ts.Client() shares that Transport — confirm.
		c := ts.Client()

		res, err := c.Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		if string(slurp) != "ok" {
			return fmt.Errorf("got: %q, want ok", slurp)
		}
		return nil
	})
}
6020
6021
6022
6023
// TestServerCancelsReadHeaderTimeoutWhenIdle checks that ReadHeaderTimeout
// does not close a keep-alive connection that is merely idle between requests.
func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
	run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
}

// testServerCancelsReadHeaderTimeoutWhenIdle sends one request on a raw
// connection and reads the response. It then idles for 1.5x ReadHeaderTimeout
// and sends a second request on the same connection, expecting it to succeed.
// Failures are returned as errors so runTimeSensitiveTest can retry with a
// longer timeout.
func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = timeout
			// IdleTimeout is explicitly zero, presumably so that only
			// ReadHeaderTimeout could close the idle connection.
			// NOTE(review): confirm against the Server.IdleTimeout docs.
			ts.Config.IdleTimeout = 0
		})
		defer cst.close()
		ts := cst.ts

		// Use a raw connection so both requests are guaranteed to share it.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("dial failed: %v", err)
		}
		br := bufio.NewReader(conn)
		defer conn.Close()

		if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
			return fmt.Errorf("writing first request failed: %v", err)
		}

		if _, err := ReadResponse(br, nil); err != nil {
			return fmt.Errorf("first response (before timeout) failed: %v", err)
		}

		// Stay idle longer than ReadHeaderTimeout; the connection must
		// survive because no request header is being read.
		time.Sleep(timeout * 3 / 2)

		if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
			return fmt.Errorf("writing second request failed: %v", err)
		}

		if _, err := ReadResponse(br, nil); err != nil {
			return fmt.Errorf("second response (after timeout) failed: %v", err)
		}

		return nil
	})
}
6074
6075
6076
// runTimeSensitiveTest calls test once per entry in durations, in order, and
// stops at the first call that returns nil. A failure is fatal when it happens
// with the last duration, or when t has already been marked failed. Any other
// failure is logged and retried with the next duration.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for i, d := range durations {
		err := test(t, d)
		switch {
		case err == nil:
			return
		case i == last, t.Failed():
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6089
6090
6091
func TestServerDuplicateBackgroundRead(t *testing.T) {
	run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
}

// testServerDuplicateBackgroundRead opens several raw connections. On each one
// it pipelines many GET requests without waiting for responses, while a
// companion goroutine reads and discards whatever the server sends back.
func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	numConns, reqsPerConn := 5, 2000
	if testing.Short() {
		numConns, reqsPerConn = 3, 100
	}

	hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
	request := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup

	// pipeline dials one connection and writes reqsPerConn requests to it,
	// stopping early if the test has already failed.
	pipeline := func() {
		defer wg.Done()
		cn, err := net.Dial("tcp", hts.Listener.Addr().String())
		if err != nil {
			t.Error(err)
			return
		}
		defer cn.Close()

		// Read and discard everything the server sends on this connection.
		wg.Add(1)
		go func() {
			defer wg.Done()
			io.Copy(io.Discard, cn)
		}()

		for n := 0; n < reqsPerConn && !t.Failed(); n++ {
			if _, err := cn.Write(request); err != nil {
				t.Error(err)
				return
			}
		}
	}

	for i := 0; i < numConns; i++ {
		wg.Add(1)
		go pipeline()
	}
	wg.Wait()
}
6143
6144
6145
6146
6147
6148
// TestServerHijackGetsBackgroundByte checks that bytes the client sends while
// the handler is running are available through the hijacked connection's
// buffered reader.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
}

// testServerHijackGetsBackgroundByte sends a GET request and waits until the
// handler is running. It then sends "foo" and half-closes the client side.
// The handler hijacks the connection and must be able to Peek "foo". Its
// request context must not be canceled at that point.
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Tell the client the handler is running; only then does it send
		// "foo". The test name suggests the server may be doing a background
		// read on the connection at this point. The hijacked reader must
		// still see those bytes.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		// Non-blocking check that the request context is still live.
		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close so the server sees EOF after "foo".
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-done
}
6201
6202
// TestServerHijackGetsFullBody checks that a large run of bytes sent right
// after the request headers is fully readable from the hijacked connection.
func TestServerHijackGetsFullBody(t *testing.T) {
	run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
}

// testServerHijackGetsFullBody writes a GET request immediately followed by
// 100 KiB of 'x', then half-closes the client side. These bytes are not a
// declared body, because the request has no Content-Length. The handler
// hijacks the connection and must read all of them from the buffered reader.
func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	needle := strings.Repeat("x", 100*1024)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		got := make([]byte, len(needle))
		n, err := io.ReadFull(buf.Reader, got)
		if err != nil || n != len(needle) || string(got) != needle {
			// Report a summary rather than dumping 100 KiB of data.
			t.Errorf("ReadFull = %d bytes, %v (content matches: %v); want %d bytes of 'x', nil",
				n, err, string(got) == needle, len(needle))
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
	buf = append(buf, needle...)
	if _, err := cn.Write(buf); err != nil {
		t.Fatal(err)
	}

	// Half-close so the server sees EOF after the payload.
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-done
}
6245
6246
6247
6248
// TestServerHijackGetsBackgroundByte_big checks that every byte the client
// sends after the request headers is readable from the hijacked connection.
func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
}

// testServerHijackGetsBackgroundByte_big sends a GET request immediately
// followed by size bytes of 'x', then half-closes the client side. These bytes
// are not a declared body, because the request has no Content-Length. The
// handler hijacks the connection and must read exactly those size bytes from
// the buffered reader.
func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	const size = 8 << 10
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		// ReadAll returns once the client's CloseWrite delivers EOF.
		slurp, err := io.ReadAll(buf.Reader)
		if err != nil {
			t.Errorf("ReadAll: %v", err)
		}
		// Locate the first unexpected byte instead of dumping all size bytes.
		bad := slices.IndexFunc(slurp, func(b byte) bool { return b != 'x' })
		switch {
		case len(slurp) != size:
			t.Errorf("read %d; want %d", len(slurp), size)
		case bad >= 0:
			t.Errorf("read %q at offset %d; want %d 'x'", slurp[bad], bad, size)
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
		strings.Repeat("x", size)); err != nil {
		t.Fatal(err)
	}
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}

	<-done
}
6299
6300
// TestServerValidatesMethod checks that the server answers a well-formed method
// with 200 and a method containing an invalid character ('(') with 400.
func TestServerValidatesMethod(t *testing.T) {
	tests := []struct {
		method string
		want   int
	}{
		{"GET", 200},
		{"GE(T", 400},
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")

		ln := &oneConnListener{conn}
		go Serve(ln, serve(200))
		// Wait for the connection to be closed (closec is presumably signaled
		// by the test conn's Close) before reading what the server wrote.
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the error itself; res is nil here.
			t.Errorf("For %s, ReadResponse: %v", tt.method, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
		}
	}
}
6326
6327
6328 type eofListenerNotComparable []int
6329
6330 func (eofListenerNotComparable) Accept() (net.Conn, error) { return nil, io.EOF }
6331 func (eofListenerNotComparable) Addr() net.Addr { return nil }
6332 func (eofListenerNotComparable) Close() error { return nil }
6333
6334
6335 func TestServerListenNotComparableListener(t *testing.T) {
6336 var s Server
6337 s.Serve(make(eofListenerNotComparable, 1))
6338 }
6339
6340
6341 type countCloseListener struct {
6342 net.Listener
6343 closes int32
6344 }
6345
6346 func (p *countCloseListener) Close() error {
6347 var err error
6348 if n := atomic.AddInt32(&p.closes, 1); n == 1 && p.Listener != nil {
6349 err = p.Listener.Close()
6350 }
6351 return err
6352 }
6353
6354
// TestServerCloseListenerOnce checks that shutting down a Server that is
// serving on a listener results in exactly one Close call on that listener.
func TestServerCloseListenerOnce(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	counted := &countCloseListener{Listener: ln}
	srv := &Server{}
	serveDone := make(chan struct{})

	go func() {
		defer close(serveDone)
		srv.Serve(counted)
	}()
	// Give Serve a moment to start using the listener before shutting down.
	time.Sleep(10 * time.Millisecond)
	srv.Shutdown(context.Background())
	// Closing the inner listener directly does not touch the counter.
	ln.Close()
	<-serveDone

	if n := atomic.LoadInt32(&counted.closes); n != 1 {
		t.Errorf("Close calls = %v; want 1", n)
	}
}
6380
6381
6382 func TestServerShutdownThenServe(t *testing.T) {
6383 var srv Server
6384 cl := &countCloseListener{Listener: nil}
6385 srv.Shutdown(context.Background())
6386 got := srv.Serve(cl)
6387 if got != ErrServerClosed {
6388 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6389 }
6390 nclose := atomic.LoadInt32(&cl.closes)
6391 if nclose != 1 {
6392 t.Errorf("Close calls = %v; want 1", nclose)
6393 }
6394 }
6395
6396
// TestStripPortFromHost checks that ServeMux matches host patterns against the
// request host with its port removed. A request for example.com:9000 must be
// routed to the "example.com/" pattern, not "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "uh-oh!")
	})

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6417
func TestServerContexts(t *testing.T) { run(t, testServerContexts) }

// testServerContexts checks that the values added by Server.BaseContext and
// Server.ConnContext both reach the handler's request context. It also checks
// that ConnContext is handed a context derived from BaseContext.
func testServerContexts(t *testing.T, mode testMode) {
	type baseKey struct{}
	type connKey struct{}

	reqCtx := make(chan context.Context, 1)
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		reqCtx <- r.Context()
	})
	setContexts := func(ts *httptest.Server) {
		ts.Config.BaseContext = func(ln net.Listener) context.Context {
			// The listener BaseContext sees must not be of a type named
			// "onceClose" (presumably an internal server wrapper).
			if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
				t.Errorf("unexpected onceClose listener type %T", ln)
			}
			return context.WithValue(context.Background(), baseKey{}, "base")
		}
		ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
			if got, want := ctx.Value(baseKey{}), "base"; got != want {
				t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
			}
			return context.WithValue(ctx, connKey{}, "conn")
		}
	}
	ts := newClientServerTest(t, mode, handler, setContexts).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	got := <-reqCtx
	if v, want := got.Value(baseKey{}), "base"; v != want {
		t.Errorf("base context key = %#v; want %q", v, want)
	}
	if v, want := got.Value(connKey{}), "conn"; v != want {
		t.Errorf("conn context key = %#v; want %q", v, want)
	}
}
6452
6453
func TestConnContextNotModifyingAllContexts(t *testing.T) {
	run(t, testConnContextNotModifyingAllContexts)
}

// testConnContextNotModifyingAllContexts makes two requests and checks that
// ConnContext never receives a context that already carries the value it adds.
// The handler sets "Connection: close", so in HTTP/1 mode each request should
// arrive on a fresh connection. A value added for one connection must
// therefore not leak into the base context used for the next.
func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
	type connKey struct{}
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Connection", "close")
	}), func(ts *httptest.Server) {
		ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
			if got := ctx.Value(connKey{}); got != nil {
				t.Errorf("in ConnContext, unexpected context key = %#v", got)
			}
			return context.WithValue(ctx, connKey{}, "conn")
		}
	}).ts

	for i := 0; i < 2; i++ {
		res, err := ts.Client().Get(ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
}
6485
6486
6487
6488 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6489 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6490 }
6491 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6492 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6493 w.Write([]byte("Hello, World!"))
6494 })).ts
6495
6496 serverURL, err := url.Parse(cst.URL)
6497 if err != nil {
6498 t.Fatalf("Failed to parse server URL: %v", err)
6499 }
6500
6501 unsupportedTEs := []string{
6502 "fugazi",
6503 "foo-bar",
6504 "unknown",
6505 `" chunked"`,
6506 }
6507
6508 for _, badTE := range unsupportedTEs {
6509 http1ReqBody := fmt.Sprintf(""+
6510 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6511 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6512
6513 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6514 if err != nil {
6515 t.Errorf("%q. unexpected error: %v", badTE, err)
6516 continue
6517 }
6518
6519 wantBody := fmt.Sprintf("" +
6520 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6521 "Connection: close\r\n\r\nUnsupported transfer encoding")
6522
6523 if string(gotBody) != wantBody {
6524 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6525 }
6526 }
6527 }
6528
6529
func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }

// testContentEncodingNoSniffing checks when the server sniffs a Content-Type
// for responses that do not set one. Sniffing happens when Content-Encoding
// is absent or set to the empty string. When any non-empty Content-Encoding
// is set, no Content-Type is added.
func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
	const html = "doctype html><p>Hello</p>"

	// gzipped and zlibbed return html compressed with the respective format.
	gzipped := func() []byte {
		var buf bytes.Buffer
		zw := gzip.NewWriter(&buf)
		zw.Write([]byte(html))
		zw.Close()
		return buf.Bytes()
	}
	zlibbed := func() []byte {
		var buf bytes.Buffer
		zw := zlib.NewWriter(&buf)
		zw.Write([]byte(html))
		zw.Close()
		return buf.Bytes()
	}

	type setting struct {
		name string
		body []byte
		// contentEncoding is the Content-Encoding header the handler sets.
		// nil means the header is not set at all, which differs from
		// setting it to "".
		contentEncoding any
		// wantContentType is the Content-Type the client should observe.
		wantContentType string
	}

	settings := []*setting{
		{
			name:            "gzip content-encoding, gzipped",
			contentEncoding: "application/gzip",
			body:            gzipped(),
		},
		{
			name:            "zlib content-encoding, zlibbed",
			contentEncoding: "application/zlib",
			body:            zlibbed(),
		},
		{
			name:            "no content-encoding",
			wantContentType: "application/x-gzip",
			body:            gzipped(),
		},
		{
			name:            "phony content-encoding",
			contentEncoding: "foo/bar",
			body:            []byte(html),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	for _, tt := range settings {
		t.Run(tt.name, func(t *testing.T) {
			cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
				if tt.contentEncoding != nil {
					rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				rw.Write(tt.body)
			}))

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()

			gotCE, wantCE := res.Header.Get("Content-Encoding"), tt.contentEncoding
			if gotCE != wantCE {
				if wantCE != nil {
					t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", gotCE, wantCE)
				} else if gotCE != "" {
					t.Errorf("Unexpected Content-Encoding %q", gotCE)
				}
			}

			if gotCT := res.Header.Get("Content-Type"); gotCT != tt.wantContentType {
				t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", gotCT, tt.wantContentType)
			}
		})
	}
}
6621
6622
6623
// TestTimeoutHandlerSuperfluousLogs checks the "superfluous WriteHeader" log
// lines emitted when a handler wrapped by TimeoutHandler calls WriteHeader
// more than once.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}

// testTimeoutHandlerSuperfluousLogs runs a handler that calls WriteHeader(404)
// four times under TimeoutHandler. It covers two cases: the handler returns
// before the timeout, and the handler is still blocked when the timeout fires.
// In both cases it checks the client-visible response. It also checks that
// exactly three log entries were written, each naming this test's closure and
// the file:line of one of the extra WriteHeader calls.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Record this file's base name and this function's full name; the
	// expected log lines are matched against both.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// exitHandler releases the handler once it has written headers.
			// Closing it when the subtest returns also unblocks the handler
			// in the timeout case.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			// lastLine receives the source line of the runtime.Caller call
			// inside the handler.
			lastLine := make(chan int, 1)

			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				// The first WriteHeader is legitimate; the next three are
				// superfluous and should each be logged. Their expected line
				// numbers are derived from the runtime.Caller line below, so
				// these statements must stay on consecutive lines.
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// In the non-timeout case the handler may return right away.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a timeout long enough that the handler finishes first.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop headers that vary between runs (Date) or that the expected
			// dumps omit (Content-Type) before comparing.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect one log entry per superfluous WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastSpuriousLine is the runtime.Caller line; the three extra
			// WriteHeader calls sit on the three lines before it.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Entry i must reference line firstSpuriousLine+i.
			// NOTE(review): testFuncName and curFileBaseName are interpolated
			// unescaped, so '.' in them matches any character.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6728
6729
6730
6731
// fetchWireResponse dials host over TCP and writes http1ReqBody verbatim. It
// returns every byte the peer sends until the peer closes the connection.
// http1ReqBody is the complete raw request (request line, headers and any
// body), not just the body.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	defer c.Close()

	if _, err = c.Write(http1ReqBody); err != nil {
		return nil, err
	}
	return io.ReadAll(c)
}
6744
// BenchmarkResponseStatusLine measures Export_writeStatusLine for status 200,
// writing into a bufio.Writer over io.Discard, in parallel across goroutines.
// Each goroutine reuses its own writer and 3-byte scratch array.
func BenchmarkResponseStatusLine(b *testing.B) {
	b.ReportAllocs()
	b.RunParallel(func(pb *testing.PB) {
		var scratch [3]byte
		w := bufio.NewWriter(io.Discard)
		for pb.Next() {
			Export_writeStatusLine(w, true, 200, scratch[:])
		}
	})
}
6755
6756 func TestDisableKeepAliveUpgrade(t *testing.T) {
6757 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6758 }
6759 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6760 if testing.Short() {
6761 t.Skip("skipping in short mode")
6762 }
6763
6764 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6765 w.Header().Set("Connection", "Upgrade")
6766 w.Header().Set("Upgrade", "someProto")
6767 w.WriteHeader(StatusSwitchingProtocols)
6768 c, buf, err := w.(Hijacker).Hijack()
6769 if err != nil {
6770 return
6771 }
6772 defer c.Close()
6773
6774
6775
6776 io.Copy(c, buf)
6777 }), func(ts *httptest.Server) {
6778 ts.Config.SetKeepAlivesEnabled(false)
6779 }).ts
6780
6781 cl := s.Client()
6782 cl.Transport.(*Transport).DisableKeepAlives = true
6783
6784 resp, err := cl.Get(s.URL)
6785 if err != nil {
6786 t.Fatalf("failed to perform request: %v", err)
6787 }
6788 defer resp.Body.Close()
6789
6790 if resp.StatusCode != StatusSwitchingProtocols {
6791 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6792 }
6793
6794 rwc, ok := resp.Body.(io.ReadWriteCloser)
6795 if !ok {
6796 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6797 }
6798
6799 _, err = rwc.Write([]byte("hello"))
6800 if err != nil {
6801 t.Fatalf("failed to write to body: %v", err)
6802 }
6803
6804 b := make([]byte, 5)
6805 _, err = io.ReadFull(rwc, b)
6806 if err != nil {
6807 t.Fatalf("failed to read from body: %v", err)
6808 }
6809
6810 if string(b) != "hello" {
6811 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6812 }
6813 }
6814
6815 type tlogWriter struct{ t *testing.T }
6816
6817 func (w tlogWriter) Write(p []byte) (int, error) {
6818 w.t.Log(string(p))
6819 return len(p), nil
6820 }
6821
func TestWriteHeaderSwitchingProtocols(t *testing.T) {
	run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
}

// testWriteHeaderSwitchingProtocols checks three things. A handler can send a
// 101 Switching Protocols response with WriteHeader plus Flush. Later
// WriteHeader/Write calls on the ResponseWriter add nothing to the wire. The
// handler can then hijack the connection and write the new protocol's bytes
// directly.
func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
	const wantBody = "want"
	const wantUpgrade = "someProto"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", wantUpgrade)
		w.WriteHeader(StatusSwitchingProtocols)
		NewResponseController(w).Flush()

		// Once the 101 has been sent, a further WriteHeader must not start a
		// second response, and writing a body must fail.
		w.WriteHeader(200)
		if _, err := w.Write([]byte("x")); err == nil {
			t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
		}

		c, _, err := NewResponseController(w).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		defer c.Close()
		if _, err := c.Write([]byte(wantBody)); err != nil {
			t.Errorf("Write to hijacked body: %v", err)
		}
	}), func(ts *httptest.Server) {
		// Route server error logs to the test log.
		ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
	}).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial: %v", err)
	}
	// Register the close immediately so a failed Write below does not leak
	// the connection.
	defer conn.Close()
	if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")); err != nil {
		t.Fatalf("conn.Write: %v", err)
	}

	r := bufio.NewReader(conn)
	res, err := ReadResponse(r, &Request{Method: "GET"})
	if err != nil {
		t.Fatal("ReadResponse error:", err)
	}
	if res.StatusCode != StatusSwitchingProtocols {
		t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
	}
	if got := res.Header.Get("Upgrade"); got != wantUpgrade {
		t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
	}
	// Everything after the 101 response must be the hijacked payload only.
	body, err := io.ReadAll(r)
	if err != nil {
		t.Error(err)
	}
	if string(body) != wantBody {
		t.Errorf("Response body = %q, want %q", string(body), wantBody)
	}
}
6883
// TestMuxRedirectRelative checks ServeMux's handling of a request line that
// carries an absolute URL with an empty path ("http://example.com"). The mux
// must redirect with a relative Location of "/" and status 301.
func TestMuxRedirectRelative(t *testing.T) {
	setParallel(t)
	req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
	if err != nil {
		// Stop here: req is not usable, and continuing would pass it to
		// ServeHTTP.
		t.Fatalf("%s", err)
	}
	mux := NewServeMux()
	resp := httptest.NewRecorder()
	mux.ServeHTTP(resp, req)
	if got, want := resp.Header().Get("Location"), "/"; got != want {
		t.Errorf("Location header expected %q; got %q", want, got)
	}
	if got, want := resp.Code, StatusMovedPermanently; got != want {
		t.Errorf("Expected response code %d; got %d", want, got)
	}
}
6900
6901
// TestQuerySemicolon checks how semicolons in the URL query are treated, both
// by default and with the handler wrapped in AllowQuerySemicolons. Each case
// gives a query, the "x" value expected without and with
// AllowQuerySemicolons, and whether ParseForm should fail by default.
func TestQuerySemicolon(t *testing.T) {
	t.Cleanup(func() { afterTest(t) })

	cases := []struct {
		query              string
		xNoSemicolons      string
		xWithSemicolons    string
		expectParseFormErr bool
	}{
		{"?a=1;x=bad&x=good", "good", "bad", true},
		{"?a=1;b=bad&x=good", "good", "good", true},
		{"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
		{"?a=1;x=good;x=bad", "", "good", true},
	}

	run(t, func(t *testing.T, mode testMode) {
		for _, tc := range cases {
			t.Run(tc.query+"/allow=false", func(t *testing.T) {
				testQuerySemicolon(t, mode, tc.query, tc.xNoSemicolons, false /* allowSemicolons */, tc.expectParseFormErr)
			})
			// In the allow=true case ParseForm is expected to succeed for
			// every query.
			t.Run(tc.query+"/allow=true", func(t *testing.T) {
				testQuerySemicolon(t, mode, tc.query, tc.xWithSemicolons, true /* allowSemicolons */, false /* expectParseFormErr */)
			})
		}
	})
}
6930
// testQuerySemicolon requests ts.URL+query and checks that the handler sees
// wantX as the "x" query value and echoes it back with status 200. The value
// must be the same from URL.Query and from FormValue. If allowSemicolons is
// set, the handler is wrapped in AllowQuerySemicolons. expectParseFormErr
// says whether ParseForm must report an error that mentions semicolons.
func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
	writeBackX := func(w ResponseWriter, r *Request) {
		x := r.URL.Query().Get("x")
		err := r.ParseForm()
		switch {
		case expectParseFormErr && (err == nil || !strings.Contains(err.Error(), "semicolon")):
			t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
		case !expectParseFormErr && err != nil:
			t.Errorf("expected no error from ParseForm, got %v", err)
		}
		// FormValue must agree with URL.Query whether or not ParseForm failed.
		if got := r.FormValue("x"); x != got {
			t.Errorf("got %q from FormValue, want %q", got, x)
		}
		fmt.Fprintf(w, "%s", x)
	}

	h := Handler(HandlerFunc(writeBackX))
	if allowSemicolons {
		h = AllowQuerySemicolons(h)
	}

	// Capture server error logs instead of printing them.
	logBuf := &strings.Builder{}
	ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(logBuf, "", 0)
	}).ts

	req, err := NewRequest("GET", ts.URL+query, nil)
	if err != nil {
		// Without this check a bad query would surface as a nil-pointer panic in Do.
		t.Fatal(err)
	}
	res, err := ts.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	slurp, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("Status = %d; want = %d", got, want)
	}
	if got, want := string(slurp), wantX; got != want {
		t.Errorf("Body = %q; want = %q", got, want)
	}
}
6973
// TestMaxBytesHandler runs testMaxBytesHandler for every combination of
// handler limit and request body size drawn from the same set of sizes.
// It runs non-parallel (testNotParallel) because testMaxBytesHandler calls
// SetRSTAvoidanceDelay, which presumably adjusts shared state.
func TestMaxBytesHandler(t *testing.T) {
	defer afterTest(t)

	sizes := []int64{100, 1_000, 1_000_000}
	for _, maxSize := range sizes {
		for _, requestSize := range sizes {
			name := fmt.Sprintf("max size %d request size %d", maxSize, requestSize)
			t.Run(name, func(t *testing.T) {
				run(t, func(t *testing.T, mode testMode) {
					testMaxBytesHandler(t, mode, maxSize, requestSize)
				}, testNotParallel)
			})
		}
	}
}
6989
// testMaxBytesHandler posts a requestSize-byte body to an echo handler wrapped
// in MaxBytesHandler(echo, maxSize). It checks that the handler reads at most
// maxSize bytes, and that it sees a read error exactly when the body exceeds
// the limit. It also checks that the client gets back exactly the bytes the
// handler read.
//
// The check is retried through runTimeSensitiveTest; each attempt sets the RST
// avoidance delay to that attempt's duration.
func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		var (
			handlerN   int64
			handlerErr error
		)
		// echo buffers the (possibly truncated) request body, recording how
		// many bytes it read and the read error, then writes it back.
		echo := HandlerFunc(func(w ResponseWriter, r *Request) {
			var buf bytes.Buffer
			handlerN, handlerErr = io.Copy(&buf, r.Body)
			io.Copy(w, &buf)
		})

		cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
		// Close the server when this attempt returns (not only at test
		// cleanup), presumably so its connections do not outlive this
		// attempt's RST-delay setting.
		defer cst.close()
		ts := cst.ts
		c := ts.Client()

		payload := strings.Repeat("a", int(requestSize))

		// Every body handed to the transport (the first one and any GetBody
		// retry) is registered with wg. wgReadCloser presumably calls wg.Done
		// on Close, so the deferred Wait keeps this attempt from returning
		// while the transport still holds a body.
		var wg sync.WaitGroup
		defer wg.Wait()
		getBody := func() (io.ReadCloser, error) {
			wg.Add(1)
			return &wgReadCloser{
				Reader: strings.NewReader(payload),
				wg:     &wg,
			}, nil
		}
		reqBody, _ := getBody()
		req, err := NewRequest("POST", ts.URL, reqBody)
		if err != nil {
			reqBody.Close()
			t.Fatal(err)
		}
		req.ContentLength = int64(len(payload))
		req.GetBody = getBody
		req.Header.Set("Content-Type", "text/plain")

		res, err := c.Do(req)
		if err != nil {
			return fmt.Errorf("unexpected connection error: %w", err)
		}
		var buf strings.Builder
		_, err = io.Copy(&buf, res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("unexpected read error: %w", err)
		}

		if handlerN > maxSize {
			t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
		}
		if requestSize > maxSize && handlerErr == nil {
			t.Error("expected error on handler side; got nil")
		}
		if requestSize <= maxSize {
			if handlerErr != nil {
				t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
			}
			if handlerN != requestSize {
				t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
			}
		}
		// The client must receive exactly what the handler managed to read.
		if buf.Len() != int(handlerN) {
			t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
		}

		return nil
	})
}
7078
// TestEarlyHints checks two things about WriteHeader(StatusEarlyHints).
// Each call must emit an interim 103 response carrying the Link headers set
// so far. The final 200 response must then carry all of them. The raw
// response is expected to begin with exactly that sequence.
func TestEarlyHints(t *testing.T) {
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		h := w.Header()
		h.Add("Link", "</style.css>; rel=preload; as=style")
		h.Add("Link", "</script.js>; rel=preload; as=script")
		w.WriteHeader(StatusEarlyHints)

		h.Add("Link", "</foo.js>; rel=preload; as=script")
		w.WriteHeader(StatusEarlyHints)

		w.Write([]byte("stuff"))
	}))

	got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
	// The expected text stops at the Date header, whose value varies, so only
	// this deterministic prefix is compared.
	expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
	if !strings.HasPrefix(got, expected) {
		t.Errorf("unexpected response; got %q; want prefix %q", got, expected)
	}
}
// TestProcessing checks that WriteHeader(StatusProcessing) emits an interim
// 102 response, with no headers, ahead of the final 200 response.
func TestProcessing(t *testing.T) {
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusProcessing)
		w.Write([]byte("stuff"))
	}))

	got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
	// The expected text stops at the Date header, whose value varies.
	expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
	if !strings.HasPrefix(got, expected) {
		t.Errorf("unexpected response; got %q; want prefix %q", got, expected)
	}
}
7110
func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }

// testParseFormCleanup uploads a multipart file twice the size of the
// ParseMultipartForm memory limit, so the handler receives it as an *os.File.
// The handler replies with that file's name. After the server has been shut
// down, the test checks that the file no longer exists.
func testParseFormCleanup(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/20253")
	}

	const (
		maxMemory = 1024
		key       = "file"
	)

	if runtime.GOOS == "windows" {
		// NOTE(review): skipped per the linked issue; presumably Windows
		// sometimes refuses to remove a file that was just closed.
		t.Skip("https://go.dev/issue/25965")
	}

	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		r.ParseMultipartForm(maxMemory)
		f, _, err := r.FormFile(key)
		if err != nil {
			t.Errorf("r.FormFile(%q) = %v", key, err)
			return
		}
		osFile, isFile := f.(*os.File)
		if !isFile {
			t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
			return
		}
		w.Write([]byte(osFile.Name()))
	})
	cst := newClientServerTest(t, mode, handler)

	var payload bytes.Buffer
	mw := multipart.NewWriter(&payload)
	part, err := mw.CreateFormFile(key, "myfile.txt")
	if err != nil {
		t.Fatal(err)
	}
	if _, err := part.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
		t.Fatal(err)
	}
	if err := mw.Close(); err != nil {
		t.Fatal(err)
	}

	req, err := NewRequest("POST", cst.ts.URL, &payload)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", mw.FormDataContentType())
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	tmpName, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}

	// Shut the server down before checking. This presumably guarantees the
	// request has been fully finished (and its temp files removed) on the
	// server side.
	cst.close()
	name := string(tmpName)
	if _, err := os.Stat(name); !errors.Is(err, os.ErrNotExist) {
		t.Errorf("file %q exists after HTTP handler returned", name)
	}
}
7171
7172 func TestHeadBody(t *testing.T) {
7173 const identityMode = false
7174 const chunkedMode = true
7175 run(t, func(t *testing.T, mode testMode) {
7176 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7177 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7178 })
7179 }
7180
7181 func TestGetBody(t *testing.T) {
7182 const identityMode = false
7183 const chunkedMode = true
7184 run(t, func(t *testing.T, mode testMode) {
7185 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7186 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7187 })
7188 }
7189
// testHeadBody sends a series of requests with the given method. Most have
// no body; one carries "request_body". The handler reports the body it read
// in the X-Request-Body response header. When chunked is true, the non-empty
// body is wrapped in a bufio.Reader so NewRequest cannot learn its length.
func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		b, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("server reading body: %v", err)
			return
		}
		hdr := w.Header()
		hdr.Set("X-Request-Body", string(b))
		hdr.Set("Content-Length", "0")
	}))
	defer cst.close()

	// Bodyless requests are sent on both sides of the one with a body, all
	// through the same client.
	bodies := []string{"", "", "request_body", ""}
	for _, want := range bodies {
		var body io.Reader // stays nil for the bodyless requests
		if want != "" {
			body = strings.NewReader(want)
			if chunked {
				body = bufio.NewReader(body)
			}
		}

		req, err := NewRequest(method, cst.ts.URL, body)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()

		if got, wantCode := res.StatusCode, 200; got != wantCode {
			t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(want), got, wantCode)
		}
		if got := res.Header.Get("X-Request-Body"); got != want {
			t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(want), got, want)
		}
	}
}
7231
7232
7233
// TestDisableContentLength verifies that the server adds Content-Length to a
// small response by default. It also verifies that no Content-Length is sent
// when the handler sets the header map entry to nil.
func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
func testDisableContentLength(t *testing.T, mode testMode) {
	// A present-but-nil header entry: no Content-Length should be sent.
	withoutCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header()["Content-Length"] = nil
		fmt.Fprintf(w, "OK")
	}))

	resp, err := withoutCL.c.Get(withoutCL.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	if v, present := resp.Header["Content-Length"]; present {
		t.Errorf("Unexpected Content-Length: %q", v)
	}
	if err := resp.Body.Close(); err != nil {
		t.Fatal(err)
	}

	// No header manipulation: the server fills in the length of "OK".
	withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "OK")
	}))

	resp, err = withCL.c.Get(withCL.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	if v := resp.Header.Get("Content-Length"); v != "2" {
		t.Errorf("Content-Length: %q; want 2", v)
	}
	if err := resp.Body.Close(); err != nil {
		t.Fatal(err)
	}
}
7267
func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }

// testErrorContentLength checks that a Content-Length the handler set before
// calling Error does not stop the client from reading the complete error body.
func testErrorContentLength(t *testing.T, mode testMode) {
	const errorBody = "an error occurred"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Deliberately wrong for the body Error is about to write.
		w.Header().Set("Content-Length", "1000")
		Error(w, errorBody, 400)
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
	}
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("io.ReadAll(res.Body) = %v", err)
	}
	// The expected body is the message followed by a newline. Report that
	// exact value on failure so the message matches the comparison.
	if want := errorBody + "\n"; string(body) != want {
		t.Fatalf("read body: %q, want %q", body, want)
	}
}
7288
// TestError checks what Error writes to a ResponseWriter whose headers were
// prepared for some other response. It expects that Error:
//   - drops Content-Length;
//   - forces a plain-text Content-Type and X-Content-Type-Options: nosniff;
//   - leaves unrelated headers alone;
//   - writes the given status code and the message plus a newline.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	h := w.Header()
	for _, hdr := range []string{"Content-Length"} {
		if v, ok := h[hdr]; ok {
			t.Errorf("%s: %q, want not present", hdr, v)
		}
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
	// "Other" was set above precisely to check that Error does not clear
	// headers it has no reason to touch.
	if v := h.Get("Other"); v != "foo" {
		t.Errorf("Other: %q, want %q", v, "foo")
	}
	if w.Code != 432 {
		t.Errorf("Code = %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("Body = %q, want %q", got, want)
	}
}
7309
func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
	run(t, testServerReadAfterWriteHeader100Continue)
}

// testServerReadAfterWriteHeader100Continue sends an "Expect: 100-continue"
// request to a handler that writes and flushes its response header before
// reading the request body. It then checks the response body. The test is
// currently skipped; see the linked issue.
func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	want := []byte("body")
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(200)
		NewResponseController(w).Flush()
		io.ReadAll(r.Body)
		w.Write(want)
	}), func(tr *Transport) {
		// Make the client wait (effectively forever) for the server's
		// response instead of giving up and sending the body on a timer.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
	}
	defer res.Body.Close()

	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("io.ReadAll(res.Body) = %v", err)
	}
	if !bytes.Equal(got, want) {
		t.Fatalf("response body = %q, want %q", got, want)
	}
}
7340
// TestServerReadAfterHandlerDone100Continue runs
// testServerReadAfterHandlerDone100Continue via run.
func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
	run(t, testServerReadAfterHandlerDone100Continue)
}

// testServerReadAfterHandlerDone100Continue sends an "Expect: 100-continue"
// request to a handler that returns immediately. The handler leaves behind a
// goroutine that reads the request body only after the client has received
// and closed the response. The test is currently skipped; see the linked
// issue.
func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	// Unbuffered: each send at the end of the test completes only when the
	// handler's goroutine receives it, sequencing that goroutine with the
	// client.
	readyc := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		go func() {
			<-readyc // wait until the client has closed the response
			io.ReadAll(r.Body)
			<-readyc // let the test know the late read has returned
		}()
	}), func(tr *Transport) {
		// Effectively never send the body on a timer; wait for the server.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
	}
	res.Body.Close()
	readyc <- struct{}{} // start the body read after the handler is done
	readyc <- struct{}{} // block until that read has returned
}
7367
// TestServerReadAfterHandlerAbort100Continue runs
// testServerReadAfterHandlerAbort100Continue via run.
func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
	run(t, testServerReadAfterHandlerAbort100Continue)
}

// testServerReadAfterHandlerAbort100Continue is like
// testServerReadAfterHandlerDone100Continue, except that the handler aborts
// by panicking with ErrAbortHandler. Because of that, the client may see an
// error from Do, and the response body is closed only when it does not. The
// test is currently skipped; see the linked issue.
func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	// Unbuffered: each send at the end of the test completes only when the
	// handler's goroutine receives it.
	readyc := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		go func() {
			<-readyc // wait until the client side has finished
			io.ReadAll(r.Body)
			<-readyc // let the test know the late read has returned
		}()
		panic(ErrAbortHandler)
	}), func(tr *Transport) {
		// Effectively never send the body on a timer; wait for the server.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	// An error here is acceptable: the handler aborted the response.
	if err == nil {
		res.Body.Close()
	}
	readyc <- struct{}{} // start the body read after the handler aborted
	readyc <- struct{}{} // block until that read has returned
}
7394
// TestInvalidChunkedBodies sends raw HTTP/1.1 requests whose chunked bodies
// use a bare LF where CRLF is expected. It checks that reading the request
// body in the handler fails rather than succeeding.
func TestInvalidChunkedBodies(t *testing.T) {
	for _, test := range []struct {
		name string
		b    string
	}{{
		name: "bare LF in chunk size",
		b:    "1\na\r\n0\r\n\r\n",
	}, {
		name: "bare LF at body end",
		b:    "1\r\na\r\n0\r\n\n",
	}} {
		t.Run(test.name, func(t *testing.T) {
			// Buffered so the handler never blocks on this send. With an
			// unbuffered channel, a t.Fatal before the receive below would
			// strand the handler, which could presumably hang server
			// shutdown during cleanup.
			reqc := make(chan error, 1)
			ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				got, err := io.ReadAll(r.Body)
				if err == nil {
					t.Logf("read body: %q", got)
				}
				reqc <- err
			})).ts

			serverURL, err := url.Parse(ts.URL)
			if err != nil {
				t.Fatal(err)
			}

			// Write the request by hand on a raw connection so the malformed
			// chunked encoding is sent as-is.
			conn, err := net.Dial("tcp", serverURL.Host)
			if err != nil {
				t.Fatal(err)
			}
			defer conn.Close()

			if _, err := conn.Write([]byte(
				"POST / HTTP/1.1\r\n" +
					"Host: localhost\r\n" +
					"Transfer-Encoding: chunked\r\n" +
					"Connection: close\r\n" +
					"\r\n" +
					test.b)); err != nil {
				t.Fatal(err)
			}
			// Best effort: half-close so the server sees EOF after the body.
			// The error is ignored because the server may already have
			// closed the connection after rejecting the body.
			conn.(*net.TCPConn).CloseWrite()

			if err := <-reqc; err == nil {
				t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
			}
		})
	}
}
7443
7444
// TestServerTLSNextProtos checks that a Server and a Transport configured with
// the same caller-owned tls.Config.NextProtos slice leave that slice
// unmodified after serving a request.
func TestServerTLSNextProtos(t *testing.T) {
	run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
}
func testServerTLSNextProtos(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	leaf, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	roots := x509.NewCertPool()
	roots.AddCert(leaf)

	// Enable only the protocol under test, on both ends.
	protos := new(Protocols)
	if mode == https1Mode {
		protos.SetHTTP1(true)
	} else if mode == http2Mode {
		protos.SetHTTP2(true)
	}

	want := []string{"http/1.1", "h2", "other"}
	// Server and Transport share this copy; it must still equal want at the end.
	nextProtos := slices.Clone(want)

	// The Server is built by hand rather than with httptest, presumably
	// because httptest would substitute its own tls.Config for this one.
	srv := &Server{
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			NextProtos:   nextProtos,
		},
		Handler:   HandlerFunc(func(w ResponseWriter, req *Request) {}),
		Protocols: protos,
	}
	tr := &Transport{
		TLSClientConfig: &tls.Config{
			RootCAs:    roots,
			NextProtos: nextProtos,
		},
		Protocols: protos,
	}

	ln := newLocalListener(t)
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- srv.ServeTLS(ln, "", "")
	}()
	t.Cleanup(func() {
		srv.Close()
		<-serveErr
	})

	client := &Client{Transport: tr}
	resp, err := client.Get("https://" + ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()

	if !slices.Equal(nextProtos, want) {
		t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, want)
	}
}
7511
View as plain text