Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "crypto/x509"
17 "encoding/json"
18 "errors"
19 "fmt"
20 "internal/synctest"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "time"
47 )
48
// dummyAddr is a net.Addr whose network name and string form are both the
// underlying string value.
type dummyAddr string

// oneConnListener is a net.Listener that hands out one pre-set connection and
// then reports io.EOF from every later Accept.
type oneConnListener struct {
	conn net.Conn
}

// Accept returns the stored connection exactly once. Afterwards, or if no
// connection was stored, it returns (nil, io.EOF).
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close does nothing and always succeeds.
func (l *oneConnListener) Close() error { return nil }

// Addr reports a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network returns the address string itself.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address string itself.
func (a dummyAddr) String() string { return string(a) }
80
81 type noopConn struct{}
82
83 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
84 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
85 func (noopConn) SetDeadline(t time.Time) error { return nil }
86 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
87 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
88
89 type rwTestConn struct {
90 io.Reader
91 io.Writer
92 noopConn
93
94 closeFunc func() error
95 closec chan bool
96 }
97
98 func (c *rwTestConn) Close() error {
99 if c.closeFunc != nil {
100 return c.closeFunc()
101 }
102 select {
103 case c.closec <- true:
104 default:
105 }
106 return nil
107 }
108
109 type testConn struct {
110 readMu sync.Mutex
111 readBuf bytes.Buffer
112 writeBuf bytes.Buffer
113 closec chan bool
114 noopConn
115 }
116
117 func newTestConn() *testConn {
118 return &testConn{closec: make(chan bool, 1)}
119 }
120
121 func (c *testConn) Read(b []byte) (int, error) {
122 c.readMu.Lock()
123 defer c.readMu.Unlock()
124 return c.readBuf.Read(b)
125 }
126
127 func (c *testConn) Write(b []byte) (int, error) {
128 return c.writeBuf.Write(b)
129 }
130
131 func (c *testConn) Close() error {
132 select {
133 case c.closec <- true:
134 default:
135 }
136 return nil
137 }
138
139
140
// reqBytes converts a human-written request into wire format. Surrounding
// whitespace is trimmed, every LF becomes CRLF, and the blank line that ends
// the header block ("\r\n\r\n") is appended.
func reqBytes(req string) []byte {
	lines := strings.Split(strings.TrimSpace(req), "\n")
	wire := strings.Join(lines, "\r\n") + "\r\n\r\n"
	return []byte(wire)
}
144
145 type handlerTest struct {
146 logbuf bytes.Buffer
147 handler Handler
148 }
149
150 func newHandlerTest(h Handler) handlerTest {
151 return handlerTest{handler: h}
152 }
153
154 func (ht *handlerTest) rawResponse(req string) string {
155 reqb := reqBytes(req)
156 var output strings.Builder
157 conn := &rwTestConn{
158 Reader: bytes.NewReader(reqb),
159 Writer: &output,
160 closec: make(chan bool, 1),
161 }
162 ln := &oneConnListener{conn: conn}
163 srv := &Server{
164 ErrorLog: log.New(&ht.logbuf, "", 0),
165 Handler: ht.handler,
166 }
167 go srv.Serve(ln)
168 <-conn.closec
169 return output.String()
170 }
171
172 func TestConsumingBodyOnNextConn(t *testing.T) {
173 t.Parallel()
174 defer afterTest(t)
175 conn := new(testConn)
176 for i := 0; i < 2; i++ {
177 conn.readBuf.Write([]byte(
178 "POST / HTTP/1.1\r\n" +
179 "Host: test\r\n" +
180 "Content-Length: 11\r\n" +
181 "\r\n" +
182 "foo=1&bar=1"))
183 }
184
185 reqNum := 0
186 ch := make(chan *Request)
187 servech := make(chan error)
188 listener := &oneConnListener{conn}
189 handler := func(res ResponseWriter, req *Request) {
190 reqNum++
191 ch <- req
192 }
193
194 go func() {
195 servech <- Serve(listener, HandlerFunc(handler))
196 }()
197
198 var req *Request
199 req = <-ch
200 if req == nil {
201 t.Fatal("Got nil first request.")
202 }
203 if req.Method != "POST" {
204 t.Errorf("For request #1's method, got %q; expected %q",
205 req.Method, "POST")
206 }
207
208 req = <-ch
209 if req == nil {
210 t.Fatal("Got nil first request.")
211 }
212 if req.Method != "POST" {
213 t.Errorf("For request #2's method, got %q; expected %q",
214 req.Method, "POST")
215 }
216
217 if serveerr := <-servech; serveerr != io.EOF {
218 t.Errorf("Serve returned %q; expected EOF", serveerr)
219 }
220 }
221
// stringHandler answers every request by setting the "Result" response header
// to its own value. It writes no body.
type stringHandler string

func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	result := string(s)
	w.Header().Set("Result", result)
}
227
// handlers lists the mux patterns registered by testHostHandlers. Each pattern
// is paired with the value its stringHandler reports in the "Result" header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}
237
// vtests are the URLs that testHostHandlers fetches. For a 200 response,
// expected is the handler's "Result" header. For a 301 redirect, it is the
// Location header.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Paths lacking the trailing slash of a subtree pattern get redirected.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
253
254 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
255 func testHostHandlers(t *testing.T, mode testMode) {
256 mux := NewServeMux()
257 for _, h := range handlers {
258 mux.Handle(h.pattern, stringHandler(h.msg))
259 }
260 ts := newClientServerTest(t, mode, mux).ts
261
262 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
263 if err != nil {
264 t.Fatal(err)
265 }
266 defer conn.Close()
267 cc := httputil.NewClientConn(conn, nil)
268 for _, vt := range vtests {
269 var r *Response
270 var req Request
271 if req.URL, err = url.Parse(vt.url); err != nil {
272 t.Errorf("cannot parse url: %v", err)
273 continue
274 }
275 if err := cc.Write(&req); err != nil {
276 t.Errorf("writing request: %v", err)
277 continue
278 }
279 r, err := cc.Read(&req)
280 if err != nil {
281 t.Errorf("reading response: %v", err)
282 continue
283 }
284 switch r.StatusCode {
285 case StatusOK:
286 s := r.Header.Get("Result")
287 if s != vt.expected {
288 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
289 }
290 case StatusMovedPermanently:
291 s := r.Header.Get("Location")
292 if s != vt.expected {
293 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
294 }
295 default:
296 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
297 }
298 }
299 }
300
301 var serveMuxRegister = []struct {
302 pattern string
303 h Handler
304 }{
305 {"/dir/", serve(200)},
306 {"/search", serve(201)},
307 {"codesearch.google.com/search", serve(202)},
308 {"codesearch.google.com/", serve(203)},
309 {"example.com/", HandlerFunc(checkQueryStringHandler)},
310 }
311
312
// serve returns a handler that replies with the given status code and writes
// no body.
func serve(code int) HandlerFunc {
	h := func(w ResponseWriter, _ *Request) { w.WriteHeader(code) }
	return h
}
318
319
320
321
// checkQueryStringHandler expects the raw query string to restate the
// request's own host and path. It rebuilds the request URL as
// "http://<Host><path>" with the query dropped. If that equals
// "http://" + RawQuery, it replies 200; otherwise it replies 500.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	rebuilt := *r.URL
	rebuilt.Scheme, rebuilt.Host, rebuilt.RawQuery = "http", r.Host, ""

	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
333
// serveMuxTests gives, for a ServeMux populated from serveMuxRegister, the
// status code and matched pattern expected for each request. Each
// (method, host, path) combination appears only once.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// For CONNECT the trailing-slash redirect (/dir -> /dir/) still applies,
	// but paths are not cleaned. "/dir/.." and "/dir/./file" match "/dir/"
	// as-is, and "/../search" matches nothing.
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
369
370 func TestServeMuxHandler(t *testing.T) {
371 setParallel(t)
372 mux := NewServeMux()
373 for _, e := range serveMuxRegister {
374 mux.Handle(e.pattern, e.h)
375 }
376
377 for _, tt := range serveMuxTests {
378 r := &Request{
379 Method: tt.method,
380 Host: tt.host,
381 URL: &url.URL{
382 Path: tt.path,
383 },
384 }
385 h, pattern := mux.Handler(r)
386 rr := httptest.NewRecorder()
387 h.ServeHTTP(rr, r)
388 if pattern != tt.pattern || rr.Code != tt.code {
389 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
390 }
391 }
392 }
393
394
395 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
396 setParallel(t)
397 defer func() {
398 if err := recover(); err == nil {
399 t.Error("expected call to mux.HandleFunc to panic")
400 }
401 }()
402 mux := NewServeMux()
403 mux.HandleFunc("/", nil)
404 }
405
// serveMuxTests2 holds requests for TestServeMuxHandlerRedirects. code is the
// final status expected. redirOk says whether that status may be reached
// through a 301 redirect first.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
417
418
419
420 func TestServeMuxHandlerRedirects(t *testing.T) {
421 setParallel(t)
422 mux := NewServeMux()
423 for _, e := range serveMuxRegister {
424 mux.Handle(e.pattern, e.h)
425 }
426
427 for _, tt := range serveMuxTests2 {
428 tries := 1
429 turl := tt.url
430 for {
431 u, e := url.Parse(turl)
432 if e != nil {
433 t.Fatal(e)
434 }
435 r := &Request{
436 Method: tt.method,
437 Host: tt.host,
438 URL: u,
439 }
440 h, _ := mux.Handler(r)
441 rr := httptest.NewRecorder()
442 h.ServeHTTP(rr, r)
443 if rr.Code != 301 {
444 if rr.Code != tt.code {
445 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
446 }
447 break
448 }
449 if !tt.redirOk {
450 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
451 break
452 }
453 turl = rr.HeaderMap.Get("Location")
454 tries--
455 }
456 if tries < 0 {
457 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
458 }
459 }
460 }
461
462
463 func TestMuxRedirectLeadingSlashes(t *testing.T) {
464 setParallel(t)
465 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
466 for _, path := range paths {
467 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
468 if err != nil {
469 t.Errorf("%s", err)
470 }
471 mux := NewServeMux()
472 resp := httptest.NewRecorder()
473
474 mux.ServeHTTP(resp, req)
475
476 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
477 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
478 return
479 }
480
481 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
482 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
483 return
484 }
485 }
486 }
487
488
489
490
491
492 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
493 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
494 }
495 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
496 writeBackQuery := func(w ResponseWriter, r *Request) {
497 fmt.Fprintf(w, "%s", r.URL.RawQuery)
498 }
499
500 mux := NewServeMux()
501 mux.HandleFunc("/testOne", writeBackQuery)
502 mux.HandleFunc("/testTwo/", writeBackQuery)
503 mux.HandleFunc("/testThree", writeBackQuery)
504 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
505 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
506 })
507
508 ts := newClientServerTest(t, mode, mux).ts
509
510 tests := [...]struct {
511 path string
512 method string
513 want string
514 statusOk bool
515 }{
516 0: {"/testOne?this=that", "GET", "this=that", true},
517 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
518 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
519 3: {"/testTwo?", "GET", "", true},
520 4: {"/testThree?foo", "GET", "foo", true},
521 5: {"/testThree/?foo", "GET", "foo:bar", true},
522 6: {"/testThree?foo", "CONNECT", "foo", true},
523 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
524
525
526 8: {"/testOne/foo/..?foo", "GET", "foo", true},
527 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
528 }
529
530 for i, tt := range tests {
531 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
532 res, err := ts.Client().Do(req)
533 if err != nil {
534 continue
535 }
536 slurp, _ := io.ReadAll(res.Body)
537 res.Body.Close()
538 if !tt.statusOk {
539 if got, want := res.StatusCode, 404; got != want {
540 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
541 }
542 }
543 if got, want := string(slurp), tt.want; got != want {
544 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
545 }
546 }
547 }
548
549 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
550 setParallel(t)
551
552 mux := NewServeMux()
553 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
554 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
555 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
556 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
557 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
558 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
559
560 tests := []struct {
561 method string
562 url string
563 code int
564 loc string
565 want string
566 }{
567 {"GET", "http://example.com/", 404, "", ""},
568 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
569 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
570 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
571 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
572 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
573 {"CONNECT", "http://example.com/", 404, "", ""},
574 {"CONNECT", "http://example.com:3000/", 404, "", ""},
575 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
576 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
577 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
578 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
579 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
580 }
581
582 for i, tt := range tests {
583 req, _ := NewRequest(tt.method, tt.url, nil)
584 w := httptest.NewRecorder()
585 mux.ServeHTTP(w, req)
586
587 if got, want := w.Code, tt.code; got != want {
588 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
589 }
590
591 if tt.code == 301 {
592 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
593 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
594 }
595 } else {
596 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
597 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
598 }
599 }
600 }
601 }
602
603
604
605
// TestMuxNoSlashRedirectWithTrailingSlash checks that registering "/{x}/"
// does not make "/" redirect anywhere; "/" must simply get a 404.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)

	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
618
619
620
621
// TestMuxNoSlash405WithTrailingSlash checks that a method-qualified "GET /{x}/"
// pattern does not produce a redirect or a 405 for "/"; the request must
// get a plain 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)

	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
634
635 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
636 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
637 mux := NewServeMux()
638 newClientServerTest(t, mode, mux)
639 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
640 }
641
642 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
643 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
644 func benchmarkServeMux(b *testing.B, runHandler bool) {
645 type test struct {
646 path string
647 code int
648 req *Request
649 }
650
651
652 var tests []test
653 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
654 for _, e := range endpoints {
655 for i := 200; i < 230; i++ {
656 p := fmt.Sprintf("/%s/%d/", e, i)
657 tests = append(tests, test{
658 path: p,
659 code: i,
660 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
661 })
662 }
663 }
664 mux := NewServeMux()
665 for _, tt := range tests {
666 mux.Handle(tt.path, serve(tt.code))
667 }
668
669 rw := httptest.NewRecorder()
670 b.ReportAllocs()
671 b.ResetTimer()
672 for i := 0; i < b.N; i++ {
673 for _, tt := range tests {
674 *rw = httptest.ResponseRecorder{}
675 h, pattern := mux.Handler(tt.req)
676 if runHandler {
677 h.ServeHTTP(rw, tt.req)
678 if pattern != tt.path || rw.Code != tt.code {
679 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
680 }
681 }
682 }
683 }
684 }
685
686 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
687 func testServerTimeouts(t *testing.T, mode testMode) {
688 runTimeSensitiveTest(t, []time.Duration{
689 10 * time.Millisecond,
690 50 * time.Millisecond,
691 100 * time.Millisecond,
692 500 * time.Millisecond,
693 1 * time.Second,
694 }, func(t *testing.T, timeout time.Duration) error {
695 return testServerTimeoutsWithTimeout(t, timeout, mode)
696 })
697 }
698
699 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
700 var reqNum atomic.Int32
701 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
702 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
703 }), func(ts *httptest.Server) {
704 ts.Config.ReadTimeout = timeout
705 ts.Config.WriteTimeout = timeout
706 })
707 defer cst.close()
708 ts := cst.ts
709
710
711 c := ts.Client()
712 r, err := c.Get(ts.URL)
713 if err != nil {
714 return fmt.Errorf("http Get #1: %v", err)
715 }
716 got, err := io.ReadAll(r.Body)
717 expected := "req=1"
718 if string(got) != expected || err != nil {
719 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
720 string(got), err, expected)
721 }
722
723
724 t1 := time.Now()
725 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
726 if err != nil {
727 return fmt.Errorf("Dial: %v", err)
728 }
729 buf := make([]byte, 1)
730 n, err := conn.Read(buf)
731 conn.Close()
732 latency := time.Since(t1)
733 if n != 0 || err != io.EOF {
734 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
735 }
736 minLatency := timeout / 5 * 4
737 if latency < minLatency {
738 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
739 }
740
741
742
743
744 r, err = c.Get(ts.URL)
745 if err != nil {
746 return fmt.Errorf("http Get #2: %v", err)
747 }
748 got, err = io.ReadAll(r.Body)
749 r.Body.Close()
750 expected = "req=2"
751 if string(got) != expected || err != nil {
752 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
753 }
754
755 if !testing.Short() {
756 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
757 if err != nil {
758 return fmt.Errorf("long Dial: %v", err)
759 }
760 defer conn.Close()
761 go io.Copy(io.Discard, conn)
762 for i := 0; i < 5; i++ {
763 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
764 if err != nil {
765 return fmt.Errorf("on write %d: %v", i, err)
766 }
767 time.Sleep(timeout / 2)
768 }
769 }
770 return nil
771 }
772
773 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
774 func testServerReadTimeout(t *testing.T, mode testMode) {
775 respBody := "response body"
776 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
777 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
778 _, err := io.Copy(io.Discard, req.Body)
779 if !errors.Is(err, os.ErrDeadlineExceeded) {
780 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
781 }
782 res.Write([]byte(respBody))
783 }), func(ts *httptest.Server) {
784 ts.Config.ReadHeaderTimeout = -1
785 ts.Config.ReadTimeout = timeout
786 t.Logf("Server.Config.ReadTimeout = %v", timeout)
787 })
788
789 var retries atomic.Int32
790 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
791 if retries.Add(1) != 1 {
792 return nil, errors.New("too many retries")
793 }
794 return nil, nil
795 }
796
797 pr, pw := io.Pipe()
798 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
799 if err != nil {
800 t.Logf("Get error, retrying: %v", err)
801 cst.close()
802 continue
803 }
804 defer res.Body.Close()
805 got, err := io.ReadAll(res.Body)
806 if string(got) != respBody || err != nil {
807 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
808 }
809 pw.Close()
810 break
811 }
812 }
813
814 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
815 func testServerNoReadTimeout(t *testing.T, mode testMode) {
816 reqBody := "Hello, Gophers!"
817 resBody := "Hi, Gophers!"
818 for _, timeout := range []time.Duration{0, -1} {
819 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
820 ctl := NewResponseController(res)
821 ctl.EnableFullDuplex()
822 res.WriteHeader(StatusOK)
823
824
825 if err := ctl.Flush(); err != nil {
826 t.Errorf("server flush response: %v", err)
827 return
828 }
829 got, err := io.ReadAll(req.Body)
830 if string(got) != reqBody || err != nil {
831 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
832 }
833 res.Write([]byte(resBody))
834 }), func(ts *httptest.Server) {
835 ts.Config.ReadTimeout = timeout
836 t.Logf("Server.Config.ReadTimeout = %d", timeout)
837 })
838
839 pr, pw := io.Pipe()
840 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
841 if err != nil {
842 t.Fatal(err)
843 }
844 defer res.Body.Close()
845
846
847 time.Sleep(10 * time.Millisecond)
848 pw.Write([]byte(reqBody))
849 pw.Close()
850
851 got, err := io.ReadAll(res.Body)
852 if string(got) != resBody || err != nil {
853 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
854 }
855 }
856 }
857
858 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
859 func testServerWriteTimeout(t *testing.T, mode testMode) {
860 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
861 errc := make(chan error, 2)
862 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
863 errc <- nil
864 _, err := io.Copy(res, neverEnding('a'))
865 errc <- err
866 }), func(ts *httptest.Server) {
867 ts.Config.WriteTimeout = timeout
868 t.Logf("Server.Config.WriteTimeout = %v", timeout)
869 })
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888 var retries atomic.Int32
889 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
890 if retries.Add(1) != 1 {
891 return nil, errors.New("too many retries")
892 }
893 return nil, nil
894 }
895
896 res, err := cst.c.Get(cst.ts.URL)
897 if err != nil {
898
899 t.Logf("Get error, retrying: %v", err)
900 cst.close()
901 continue
902 }
903 defer res.Body.Close()
904 _, err = io.Copy(io.Discard, res.Body)
905 if err == nil {
906 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
907 }
908 select {
909 case <-errc:
910 err = <-errc
911 if !errors.Is(err, os.ErrDeadlineExceeded) {
912 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
913 }
914 return
915 default:
916
917 t.Logf("handler didn't run, retrying")
918 cst.close()
919 }
920 }
921 }
922
923 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
924 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
925 for _, timeout := range []time.Duration{0, -1} {
926 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
927 _, err := io.Copy(res, neverEnding('a'))
928 t.Logf("server write response: %v", err)
929 }), func(ts *httptest.Server) {
930 ts.Config.WriteTimeout = timeout
931 t.Logf("Server.Config.WriteTimeout = %d", timeout)
932 })
933
934 res, err := cst.c.Get(cst.ts.URL)
935 if err != nil {
936 t.Fatal(err)
937 }
938 defer res.Body.Close()
939 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
940 if n != 1<<20 || err != nil {
941 t.Errorf("client read response body: %d, %v", n, err)
942 }
943
944
945 res.Body.Close()
946 cst.ts.Config.Shutdown(context.Background())
947 }
948 }
949
950
951 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
952 run(t, testWriteDeadlineExtendedOnNewRequest)
953 }
954 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
955 if testing.Short() {
956 t.Skip("skipping in short mode")
957 }
958 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
959 func(ts *httptest.Server) {
960 ts.Config.WriteTimeout = 250 * time.Millisecond
961 },
962 ).ts
963
964 c := ts.Client()
965
966 for i := 1; i <= 3; i++ {
967 req, err := NewRequest("GET", ts.URL, nil)
968 if err != nil {
969 t.Fatal(err)
970 }
971
972 r, err := c.Do(req)
973 if err != nil {
974 t.Fatalf("http2 Get #%d: %v", i, err)
975 }
976 r.Body.Close()
977 time.Sleep(ts.Config.WriteTimeout / 2)
978 }
979 }
980
981
982
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, and returns as
// soon as one call succeeds. Each failure is logged together with the next
// timeout to try. If every call fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := range tries {
		timeout := tries[i]
		err := testFunc(timeout)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeout, err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
997
998
999 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1000 if testing.Short() {
1001 t.Skip("skipping in short mode")
1002 }
1003 setParallel(t)
1004 run(t, func(t *testing.T, mode testMode) {
1005 tryTimeouts(t, func(timeout time.Duration) error {
1006 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1007 })
1008 })
1009 }
1010
1011 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1012 firstRequest := make(chan bool, 1)
1013 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1014 select {
1015 case firstRequest <- true:
1016
1017 default:
1018
1019 time.Sleep(timeout)
1020 }
1021 }), func(ts *httptest.Server) {
1022 ts.Config.WriteTimeout = timeout / 2
1023 })
1024 defer cst.close()
1025 ts := cst.ts
1026
1027 c := ts.Client()
1028
1029 req, err := NewRequest("GET", ts.URL, nil)
1030 if err != nil {
1031 return fmt.Errorf("NewRequest: %v", err)
1032 }
1033 r, err := c.Do(req)
1034 if err != nil {
1035 return fmt.Errorf("Get #1: %v", err)
1036 }
1037 r.Body.Close()
1038
1039 req, err = NewRequest("GET", ts.URL, nil)
1040 if err != nil {
1041 return fmt.Errorf("NewRequest: %v", err)
1042 }
1043 r, err = c.Do(req)
1044 if err == nil {
1045 r.Body.Close()
1046 return fmt.Errorf("Get #2 expected error, got nil")
1047 }
1048 if mode == http2Mode {
1049 expected := "stream ID 3; INTERNAL_ERROR"
1050 if !strings.Contains(err.Error(), expected) {
1051 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1052 }
1053 }
1054 return nil
1055 }
1056
1057
1058 func TestNoWriteDeadline(t *testing.T) {
1059 if testing.Short() {
1060 t.Skip("skipping in short mode")
1061 }
1062 setParallel(t)
1063 defer afterTest(t)
1064 run(t, func(t *testing.T, mode testMode) {
1065 tryTimeouts(t, func(timeout time.Duration) error {
1066 return testNoWriteDeadline(t, mode, timeout)
1067 })
1068 })
1069 }
1070
1071 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1072 firstRequest := make(chan bool, 1)
1073 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1074 select {
1075 case firstRequest <- true:
1076
1077 default:
1078
1079 time.Sleep(timeout)
1080 }
1081 }))
1082 defer cst.close()
1083 ts := cst.ts
1084
1085 c := ts.Client()
1086
1087 for i := 0; i < 2; i++ {
1088 req, err := NewRequest("GET", ts.URL, nil)
1089 if err != nil {
1090 return fmt.Errorf("NewRequest: %v", err)
1091 }
1092 r, err := c.Do(req)
1093 if err != nil {
1094 return fmt.Errorf("Get #%d: %v", i, err)
1095 }
1096 r.Body.Close()
1097 }
1098 return nil
1099 }
1100
1101
1102
1103
1104 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1105 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1106 var (
1107 mu sync.RWMutex
1108 conn net.Conn
1109 )
1110 var afterTimeoutErrc = make(chan error, 1)
1111 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1112 buf := make([]byte, 512<<10)
1113 _, err := w.Write(buf)
1114 if err != nil {
1115 t.Errorf("handler Write error: %v", err)
1116 return
1117 }
1118 mu.RLock()
1119 defer mu.RUnlock()
1120 if conn == nil {
1121 t.Error("no established connection found")
1122 return
1123 }
1124 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1125 _, err = w.Write(buf)
1126 afterTimeoutErrc <- err
1127 }), func(ts *httptest.Server) {
1128 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1129 }).ts
1130
1131 c := ts.Client()
1132
1133 err := func() error {
1134 res, err := c.Get(ts.URL)
1135 if err != nil {
1136 return err
1137 }
1138 _, err = io.Copy(io.Discard, res.Body)
1139 res.Body.Close()
1140 return err
1141 }()
1142 if err == nil {
1143 t.Errorf("expected an error copying body from Get request")
1144 }
1145
1146 if err := <-afterTimeoutErrc; err == nil {
1147 t.Error("expected write error after timeout")
1148 }
1149 }
1150
1151
// trackLastConnListener wraps a net.Listener and stores every successfully
// accepted connection in *last, writing it under mu.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // guards *last
	last *net.Conn     // overwritten on each successful Accept
}

// Accept delegates to the wrapped listener. On success it records the new
// connection before returning it; on failure *last is left untouched.
func (l trackLastConnListener) Accept() (c net.Conn, err error) {
	c, err = l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1168
1169
1170 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1171 func testIdentityResponse(t *testing.T, mode testMode) {
1172 if mode == http2Mode {
1173 t.Skip("https://go.dev/issue/56019")
1174 }
1175
1176 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1177 rw.Header().Set("Content-Length", "3")
1178 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1179 switch {
1180 case req.FormValue("overwrite") == "1":
1181 _, err := rw.Write([]byte("foo TOO LONG"))
1182 if err != ErrContentLength {
1183 t.Errorf("expected ErrContentLength; got %v", err)
1184 }
1185 case req.FormValue("underwrite") == "1":
1186 rw.Header().Set("Content-Length", "500")
1187 rw.Write([]byte("too short"))
1188 default:
1189 rw.Write([]byte("foo"))
1190 }
1191 })
1192
1193 ts := newClientServerTest(t, mode, handler).ts
1194 c := ts.Client()
1195
1196
1197
1198
1199
1200 for _, te := range []string{"", "identity"} {
1201 url := ts.URL + "/?te=" + te
1202 res, err := c.Get(url)
1203 if err != nil {
1204 t.Fatalf("error with Get of %s: %v", url, err)
1205 }
1206 if cl, expected := res.ContentLength, int64(3); cl != expected {
1207 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1208 }
1209 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1210 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1211 }
1212 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1213 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1214 url, expected, tl, res.TransferEncoding)
1215 }
1216 res.Body.Close()
1217 }
1218
1219
1220 url := ts.URL + "/?overwrite=1"
1221 res, err := c.Get(url)
1222 if err != nil {
1223 t.Fatalf("error with Get of %s: %v", url, err)
1224 }
1225 res.Body.Close()
1226
1227 if mode != http1Mode {
1228 return
1229 }
1230
1231
1232
1233 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1234 if err != nil {
1235 t.Fatalf("error dialing: %v", err)
1236 }
1237 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1238 if err != nil {
1239 t.Fatalf("error writing: %v", err)
1240 }
1241
1242
1243 got, _ := io.ReadAll(conn)
1244 expectedSuffix := "\r\n\r\ntoo short"
1245 if !strings.HasSuffix(string(got), expectedSuffix) {
1246 t.Errorf("Expected output to end with %q; got response body %q",
1247 expectedSuffix, string(got))
1248 }
1249 }
1250
1251 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1252 setParallel(t)
1253 s := newClientServerTest(t, http1Mode, h).ts
1254
1255 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1256 if err != nil {
1257 t.Fatal("dial error:", err)
1258 }
1259 defer conn.Close()
1260
1261 _, err = fmt.Fprint(conn, req)
1262 if err != nil {
1263 t.Fatal("print error:", err)
1264 }
1265
1266 r := bufio.NewReader(conn)
1267 res, err := ReadResponse(r, &Request{Method: "GET"})
1268 if err != nil {
1269 t.Fatal("ReadResponse error:", err)
1270 }
1271
1272 _, err = io.ReadAll(r)
1273 if err != nil {
1274 t.Fatal("read error:", err)
1275 }
1276
1277 if !res.Close {
1278 t.Errorf("Response.Close = false; want true")
1279 }
1280 }
1281
1282 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1283 setParallel(t)
1284 ts := newClientServerTest(t, http1Mode, handler).ts
1285 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1286 if err != nil {
1287 t.Fatal(err)
1288 }
1289 defer conn.Close()
1290 br := bufio.NewReader(conn)
1291 for i := 0; i < 2; i++ {
1292 if _, err := io.WriteString(conn, req); err != nil {
1293 t.Fatal(err)
1294 }
1295 res, err := ReadResponse(br, nil)
1296 if err != nil {
1297 t.Fatalf("res %d: %v", i+1, err)
1298 }
1299 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1300 t.Fatalf("res %d body copy: %v", i+1, err)
1301 }
1302 res.Body.Close()
1303 }
1304 }
1305
1306
1307 func TestServeHTTP10Close(t *testing.T) {
1308 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1309 ServeFile(w, r, "testdata/file")
1310 }))
1311 }
1312
1313
1314 func TestClientCanClose(t *testing.T) {
1315 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1316
1317 }))
1318 }
1319
1320
1321
1322 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1323 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1324 w.Header().Set("Connection", "close")
1325 }))
1326 }
1327
1328 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1329 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1330 w.Header().Set("Connection", "close")
1331 }))
1332 }
1333
1334 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1335 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1336
1337
1338 }))
1339 }
1340
// send204 replies with status 204 No Content and writes no body.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 replies with status 304 Not Modified and writes no body.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1343
1344
1345 func TestHTTP10KeepAlive204Response(t *testing.T) {
1346 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1347 }
1348
1349 func TestHTTP11KeepAlive204Response(t *testing.T) {
1350 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1351 }
1352
1353 func TestHTTP10KeepAlive304Response(t *testing.T) {
1354 testTCPConnectionStaysOpen(t,
1355 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1356 HandlerFunc(send304))
1357 }
1358
1359
1360 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1361 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1362 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1363 w.(Flusher).Flush()
1364 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1365 }))
1366 type data struct {
1367 Addr string
1368 }
1369 var addrs [2]data
1370 for i := range addrs {
1371 res, err := cst.c.Get(cst.ts.URL)
1372 if err != nil {
1373 t.Fatal(err)
1374 }
1375 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1376 t.Fatal(err)
1377 }
1378 if addrs[i].Addr == "" {
1379 t.Fatal("no address")
1380 }
1381 res.Body.Close()
1382 }
1383 if addrs[0] != addrs[1] {
1384 t.Fatalf("connection not reused")
1385 }
1386 }
1387
1388 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1389 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1390 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1391 fmt.Fprintf(w, "%s", r.RemoteAddr)
1392 }))
1393
1394 res, err := cst.c.Get(cst.ts.URL)
1395 if err != nil {
1396 t.Fatalf("Get error: %v", err)
1397 }
1398 body, err := io.ReadAll(res.Body)
1399 if err != nil {
1400 t.Fatalf("ReadAll error: %v", err)
1401 }
1402 ip := string(body)
1403 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1404 t.Fatalf("Expected local addr; got %q", ip)
1405 }
1406 }
1407
// blockingRemoteAddrListener wraps a net.Listener so every accepted
// connection is returned as a *blockingRemoteAddrConn. Each wrapped
// connection is also sent on conns before Accept returns, so Accept waits
// until that send completes. This lets the test control what, and when,
// RemoteAddr reports.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn
}

// Accept accepts from the wrapped listener, wraps the connection with a
// one-slot addrs channel, publishes it on l.conns, and returns it. Errors
// from the wrapped listener are returned with a nil connection.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr does not return
// until an address is delivered on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr
}

// RemoteAddr blocks until a value arrives on c.addrs and returns it. Each
// call consumes one value.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1434
1435
// TestServerAllowsBlockingRemoteAddr checks that a connection whose
// RemoteAddr call blocks does not stop the server from serving a second
// connection. Request 2 is answered while connection 1's address is still
// withheld, and each response reports the address delivered to its own
// connection.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	conns := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    conns,
		}
	}).ts

	c := ts.Client()
	// Disable keep-alives so each fetch dials its own connection and gets
	// its own blockingRemoteAddrConn.
	c.Transport.(*Transport).DisableKeepAlives = true

	// fetch GETs ts.URL and sends the body on response, or "" on any error.
	fetch := func(num int, response chan<- string) {
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Start request 1.
	response1c := make(chan string, 1)
	go fetch(1, response1c)

	// Wait for the server to accept connection 1. Its RemoteAddr stays
	// blocked until an address is sent on conn1's addrs channel below
	// (assuming the server queries RemoteAddr while serving it).
	conn1 := <-conns

	// Start request 2 on a separate connection.
	response2c := make(chan string, 1)
	go fetch(2, response2c)
	conn2 := <-conns

	// Release only connection 2.
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("12.12.12.12"), Port: 12}

	// Response 2 must arrive even though connection 1 is still held back.
	response2 := <-response2c
	if g, e := response2, "RA:12.12.12.12:12"; g != e {
		t.Fatalf("response 2 addr = %q; want %q", g, e)
	}

	// Now release connection 1.
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("21.21.21.21"), Port: 21}

	// Request 1 can now complete, reporting its own address.
	response1 := <-response1c
	if g, e := response1, "RA:21.21.21.21:21"; g != e {
		t.Fatalf("response 1 addr = %q; want %q", g, e)
	}
}
1503
1504
1505
1506 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1507 func testHeadResponses(t *testing.T, mode testMode) {
1508 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1509 _, err := w.Write([]byte("<html>"))
1510 if err != nil {
1511 t.Errorf("ResponseWriter.Write: %v", err)
1512 }
1513
1514
1515 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1516 if err != nil {
1517 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1518 }
1519 }))
1520 res, err := cst.c.Head(cst.ts.URL)
1521 if err != nil {
1522 t.Error(err)
1523 }
1524 if len(res.TransferEncoding) > 0 {
1525 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1526 }
1527 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1528 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1529 }
1530 if v := res.ContentLength; v != 10 {
1531 t.Errorf("Content-Length: %d; want 10", v)
1532 }
1533 body, err := io.ReadAll(res.Body)
1534 if err != nil {
1535 t.Error(err)
1536 }
1537 if len(body) > 0 {
1538 t.Errorf("got unexpected body %q", string(body))
1539 }
1540 }
1541
1542
1543
1544 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1545 func testHeadReaderFrom(t *testing.T, mode testMode) {
1546
1547 wantBody := strings.Repeat("a", 4096)
1548 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1549 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1550 }))
1551 res, err := cst.c.Head(cst.ts.URL)
1552 if err != nil {
1553 t.Fatal(err)
1554 }
1555 res.Body.Close()
1556 res, err = cst.c.Get(cst.ts.URL)
1557 if err != nil {
1558 t.Fatal(err)
1559 }
1560 gotBody, err := io.ReadAll(res.Body)
1561 res.Body.Close()
1562 if err != nil {
1563 t.Fatal(err)
1564 }
1565 if string(gotBody) != wantBody {
1566 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1567 }
1568 }
1569
1570 func TestTLSHandshakeTimeout(t *testing.T) {
1571 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1572 }
1573 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1574 errLog := new(strings.Builder)
1575 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1576 func(ts *httptest.Server) {
1577 ts.Config.ReadTimeout = 250 * time.Millisecond
1578 ts.Config.ErrorLog = log.New(errLog, "", 0)
1579 },
1580 )
1581 ts := cst.ts
1582
1583 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1584 if err != nil {
1585 t.Fatalf("Dial: %v", err)
1586 }
1587 var buf [1]byte
1588 n, err := conn.Read(buf[:])
1589 if err == nil || n != 0 {
1590 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1591 }
1592 conn.Close()
1593
1594 cst.close()
1595 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1596 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1597 }
1598 }
1599
1600 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1601 func testTLSServer(t *testing.T, mode testMode) {
1602 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1603 if r.TLS != nil {
1604 w.Header().Set("X-TLS-Set", "true")
1605 if r.TLS.HandshakeComplete {
1606 w.Header().Set("X-TLS-HandshakeComplete", "true")
1607 }
1608 }
1609 }), func(ts *httptest.Server) {
1610 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1611 }).ts
1612
1613
1614
1615
1616
1617
1618 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1619 if err != nil {
1620 t.Fatalf("Dial: %v", err)
1621 }
1622 defer idleConn.Close()
1623
1624 if !strings.HasPrefix(ts.URL, "https://") {
1625 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1626 return
1627 }
1628 client := ts.Client()
1629 res, err := client.Get(ts.URL)
1630 if err != nil {
1631 t.Error(err)
1632 return
1633 }
1634 if res == nil {
1635 t.Errorf("got nil Response")
1636 return
1637 }
1638 defer res.Body.Close()
1639 if res.Header.Get("X-TLS-Set") != "true" {
1640 t.Errorf("expected X-TLS-Set response header")
1641 return
1642 }
1643 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1644 t.Errorf("expected X-TLS-HandshakeComplete header")
1645 }
1646 }
1647
1648 func TestServeTLS(t *testing.T) {
1649 CondSkipHTTP2(t)
1650
1651 defer afterTest(t)
1652 defer SetTestHookServerServe(nil)
1653
1654 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1655 if err != nil {
1656 t.Fatal(err)
1657 }
1658 tlsConf := &tls.Config{
1659 Certificates: []tls.Certificate{cert},
1660 }
1661
1662 ln := newLocalListener(t)
1663 defer ln.Close()
1664 addr := ln.Addr().String()
1665
1666 serving := make(chan bool, 1)
1667 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1668 serving <- true
1669 })
1670 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1671 s := &Server{
1672 Addr: addr,
1673 TLSConfig: tlsConf,
1674 Handler: handler,
1675 }
1676 errc := make(chan error, 1)
1677 go func() { errc <- s.ServeTLS(ln, "", "") }()
1678 select {
1679 case err := <-errc:
1680 t.Fatalf("ServeTLS: %v", err)
1681 case <-serving:
1682 }
1683
1684 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1685 InsecureSkipVerify: true,
1686 NextProtos: []string{"h2", "http/1.1"},
1687 })
1688 if err != nil {
1689 t.Fatal(err)
1690 }
1691 defer c.Close()
1692 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1693 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1694 }
1695 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1696 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1697 }
1698 }
1699
1700
1701 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1702 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1703 }
1704 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1705 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1706 t.Error("unexpected HTTPS request")
1707 }), func(ts *httptest.Server) {
1708 var errBuf bytes.Buffer
1709 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1710 }).ts
1711 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1712 if err != nil {
1713 t.Fatal(err)
1714 }
1715 defer conn.Close()
1716 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1717 slurp, err := io.ReadAll(conn)
1718 if err != nil {
1719 t.Fatal(err)
1720 }
1721 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1722 if !strings.HasPrefix(string(slurp), wantPrefix) {
1723 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1724 }
1725 }
1726
1727
1728 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1729 testAutomaticHTTP2_Serve(t, nil, true)
1730 }
1731
1732 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1733 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1734 }
1735
1736 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1737 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1738 }
1739
1740 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1741 setParallel(t)
1742 defer afterTest(t)
1743 ln := newLocalListener(t)
1744 ln.Close()
1745 var s Server
1746 s.TLSConfig = tlsConf
1747 if err := s.Serve(ln); err == nil {
1748 t.Fatal("expected an error")
1749 }
1750 gotH2 := s.TLSNextProto["h2"] != nil
1751 if gotH2 != wantH2 {
1752 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1753 }
1754 }
1755
1756 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1757 setParallel(t)
1758 defer afterTest(t)
1759 ln := newLocalListener(t)
1760 ln.Close()
1761 var s Server
1762
1763
1764 s.TLSConfig = &tls.Config{
1765 NextProtos: []string{"h2"},
1766 }
1767 if err := s.Serve(ln); err == nil {
1768 t.Fatal("expected an error")
1769 }
1770 on := s.TLSNextProto["h2"] != nil
1771 if !on {
1772 t.Errorf("http2 wasn't automatically enabled")
1773 }
1774 }
1775
1776 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1777 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1778 if err != nil {
1779 t.Fatal(err)
1780 }
1781 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1782 Certificates: []tls.Certificate{cert},
1783 })
1784 }
1785
1786 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1787 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1788 if err != nil {
1789 t.Fatal(err)
1790 }
1791 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1792 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1793 return &cert, nil
1794 },
1795 })
1796 }
1797
1798 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1799 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1800 if err != nil {
1801 t.Fatal(err)
1802 }
1803 conf := &tls.Config{
1804
1805
1806 NextProtos: []string{"h2"},
1807 Certificates: []tls.Certificate{cert},
1808 }
1809 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1810 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1811 return conf, nil
1812 },
1813 })
1814 }
1815
// testAutomaticHTTP2_ListenAndServe starts Server.ListenAndServeTLS with
// tlsConf on a local address. certFile and keyFile are empty, so
// certificates must come from tlsConf. It then dials with TLS offering "h2"
// and "http/1.1" and checks that "h2" is negotiated.
//
// The address comes from opening and immediately closing a listener, so
// something else may grab it before ListenAndServeTLS binds it. Startup is
// therefore retried up to maxTries times.
//
// The readiness signal comes from SetTestHookServerServe, which appears to
// be package-level state, so this helper does not call setParallel.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)
	// Not parallel: uses the test hook above.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
Try:
	for try := 0; try < maxTries; try++ {
		// Reserve a free local address, then release it for the server to bind.
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook hands back the listener the server is about to serve on,
		// which doubles as the "server is up" signal.
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			// Startup failed (e.g. the address was taken); try another one.
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	// Close the server's own listener when done (presumably what makes
	// ListenAndServeTLS return).
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1870
// serverExpectTest describes one Expect-header scenario run by
// testServerExpect. Field order is significant: failures print the value
// with %#v.
type serverExpectTest struct {
	contentLength    int    // Content-Length header value and body size (only the body size when chunked)
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect header
	readBody         bool   // sent as ?readbody=; the handler reads the body only when true
	expectedResponse string // substring expected in the first response line
}

// expectTest builds a non-chunked serverExpectTest from its arguments.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
1887
// serverExpectTests lists the Expect-header scenarios run by
// testServerExpect. Its handler reads the body and writes "Hi" when
// readbody=true; otherwise it replies 401 without reading the body.
var serverExpectTests = []serverExpectTest{
	// Normal 100-continue; the Expect value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No expectation: the body is sent up front and the handler replies 200.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler rejects the request without reading the
	// body, so the first line is the 401 rather than "100 Continue".
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Same rejection without 100-continue.
	expectTest(100, "", false, "401 Unauthorized"),

	// Unknown expectations are refused.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue requested with an empty body, which the handler reads.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue requested with an empty body that the handler doesn't read.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// 100-continue with Transfer-Encoding: chunked. contentLength is 0, so
	// the client goroutine sends no body bytes; only the first response
	// line is checked.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1917
1918
1919
// TestServerExpect runs each serverExpectTests scenario over its own raw
// HTTP/1.1 connection and checks the first response line against
// expectedResponse.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Decide from the raw query instead of r.FormValue, presumably
		// because FormValue may read a POST body, and the body must only be
		// read when the scenario asks for it.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends one scenario's request and checks the first response line.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body immediately only when acting like a client that
		// does not wait for "100 Continue" (and there is a body to send).
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Write the request concurrently so the read below can observe a
		// response sent before the whole request has been written.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Plain body: write straight to conn. NopCloser(nil) only
				// supplies a no-op Close.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler did not want the body, so the server
						// may already have replied and hung up; a write error
						// is tolerated.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// Tolerated TCP race: if the server replies and closes while
				// the client is still sending a body it will not read, the
				// client may see a reset before it can read the response.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2015
2016
2017
2018 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2019 setParallel(t)
2020 defer afterTest(t)
2021 conn := new(testConn)
2022 body := strings.Repeat("x", 100<<10)
2023 conn.readBuf.Write([]byte(fmt.Sprintf(
2024 "POST / HTTP/1.1\r\n"+
2025 "Host: test\r\n"+
2026 "Content-Length: %d\r\n"+
2027 "\r\n", len(body))))
2028 conn.readBuf.Write([]byte(body))
2029
2030 done := make(chan bool)
2031
2032 readBufLen := func() int {
2033 conn.readMu.Lock()
2034 defer conn.readMu.Unlock()
2035 return conn.readBuf.Len()
2036 }
2037
2038 ls := &oneConnListener{conn}
2039 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2040 defer close(done)
2041 if bufLen := readBufLen(); bufLen < len(body)/2 {
2042 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2043 }
2044 rw.WriteHeader(200)
2045 rw.(Flusher).Flush()
2046 if g, e := readBufLen(), 0; g != e {
2047 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2048 }
2049 if c := rw.Header().Get("Connection"); c != "" {
2050 t.Errorf(`Connection header = %q; want ""`, c)
2051 }
2052 }))
2053 <-done
2054 }
2055
2056
2057
2058
2059 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2060 setParallel(t)
2061 if testing.Short() && testenv.Builder() == "" {
2062 t.Log("skipping in short mode")
2063 }
2064 conn := new(testConn)
2065 body := strings.Repeat("x", 1<<20)
2066 conn.readBuf.Write([]byte(fmt.Sprintf(
2067 "POST / HTTP/1.1\r\n"+
2068 "Host: test\r\n"+
2069 "Content-Length: %d\r\n"+
2070 "\r\n", len(body))))
2071 conn.readBuf.Write([]byte(body))
2072 conn.closec = make(chan bool, 1)
2073
2074 ls := &oneConnListener{conn}
2075 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2076 if conn.readBuf.Len() < len(body)/2 {
2077 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2078 }
2079 rw.WriteHeader(200)
2080 rw.(Flusher).Flush()
2081 if conn.readBuf.Len() < len(body)/2 {
2082 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2083 }
2084 }))
2085 <-conn.closec
2086
2087 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2088 t.Errorf("Expected a Connection: close header; got response: %s", res)
2089 }
2090 }
2091
// handlerBodyCloseTest describes one testHandlerBodyClose scenario: how the
// request body is sent, and what closing it unread should do.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes sent
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // add "Connection: close"; when false, a second GET is pipelined

	wantEOFSearch bool // Body.Close should consume buffered input (read buffer shrinks)
	wantNextReq   bool // the pipelined second request should reach the handler
}

// connectionHeader returns the raw "Connection: close\r\n" header line when
// the scenario requests it, and "" otherwise.
func (t handlerBodyCloseTest) connectionHeader() string {
	if !t.reqConnClose {
		return ""
	}
	return "Connection: close\r\n"
}
2107
// handlerBodyCloseTests is the scenario table for testHandlerBodyClose.
// Small bodies are 20 KB and large ones 1 MB.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small body with Content-Length: closing reads past it, and the next
	// request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body: closing reads past it, and the next request is
	// served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small body with Content-Length plus "Connection: close": no later
	// request can follow, so closing is not expected to read anything.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body plus "Connection: close": closing still reads
	// ahead (presumably for possible trailers), though no next request is
	// served.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large body with Content-Length: the declared size is large, so
	// closing gives up without reading.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body: closing reads some input before giving up, and
	// the next request is not served.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body plus "Connection: close": closing still reads
	// (presumably searching for trailers).
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large body with Content-Length plus "Connection: close": closing
	// performs no reads.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2192
2193 func TestHandlerBodyClose(t *testing.T) {
2194 setParallel(t)
2195 if testing.Short() && testenv.Builder() == "" {
2196 t.Skip("skipping in -short mode")
2197 }
2198 for i, tt := range handlerBodyCloseTests {
2199 testHandlerBodyClose(t, i, tt)
2200 }
2201 }
2202
2203 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2204 conn := new(testConn)
2205 body := strings.Repeat("x", tt.bodySize)
2206 if tt.bodyChunked {
2207 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2208 "Host: test\r\n" +
2209 tt.connectionHeader() +
2210 "Transfer-Encoding: chunked\r\n" +
2211 "\r\n")
2212 cw := internal.NewChunkedWriter(&conn.readBuf)
2213 io.WriteString(cw, body)
2214 cw.Close()
2215 conn.readBuf.WriteString("\r\n")
2216 } else {
2217 conn.readBuf.Write([]byte(fmt.Sprintf(
2218 "POST / HTTP/1.1\r\n"+
2219 "Host: test\r\n"+
2220 tt.connectionHeader()+
2221 "Content-Length: %d\r\n"+
2222 "\r\n", len(body))))
2223 conn.readBuf.Write([]byte(body))
2224 }
2225 if !tt.reqConnClose {
2226 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2227 }
2228 conn.closec = make(chan bool, 1)
2229
2230 readBufLen := func() int {
2231 conn.readMu.Lock()
2232 defer conn.readMu.Unlock()
2233 return conn.readBuf.Len()
2234 }
2235
2236 ls := &oneConnListener{conn}
2237 var numReqs int
2238 var size0, size1 int
2239 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2240 numReqs++
2241 if numReqs == 1 {
2242 size0 = readBufLen()
2243 req.Body.Close()
2244 size1 = readBufLen()
2245 }
2246 }))
2247 <-conn.closec
2248 if numReqs < 1 || numReqs > 2 {
2249 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2250 }
2251 didSearch := size0 != size1
2252 if didSearch != tt.wantEOFSearch {
2253 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2254 }
2255 if tt.wantNextReq && numReqs != 2 {
2256 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2257 }
2258 }
2259
2260
2261
// testHandlerBodyConsumer is a named strategy a handler applies to a
// request body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the strategies: leave the body untouched,
// Close it, or drain it into io.Discard.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(rc io.ReadCloser) { rc.Close() }},
	{name: "discard", f: func(rc io.ReadCloser) { io.Copy(io.Discard, rc) }},
}
2272
2273 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2274 setParallel(t)
2275 defer afterTest(t)
2276 for _, handler := range testHandlerBodyConsumers {
2277 conn := new(testConn)
2278 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2279 "Host: test\r\n" +
2280 "Transfer-Encoding: chunked\r\n" +
2281 "\r\n" +
2282 "hax\r\n" +
2283 "GET /secret HTTP/1.1\r\n" +
2284 "Host: test\r\n" +
2285 "\r\n")
2286
2287 conn.closec = make(chan bool, 1)
2288 ls := &oneConnListener{conn}
2289 var numReqs int
2290 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2291 numReqs++
2292 if strings.Contains(req.URL.Path, "secret") {
2293 t.Error("Request for /secret encountered, should not have happened.")
2294 }
2295 handler.f(req.Body)
2296 }))
2297 <-conn.closec
2298 if numReqs != 1 {
2299 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2300 }
2301 }
2302 }
2303
2304 func TestInvalidTrailerClosesConnection(t *testing.T) {
2305 setParallel(t)
2306 defer afterTest(t)
2307 for _, handler := range testHandlerBodyConsumers {
2308 conn := new(testConn)
2309 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2310 "Host: test\r\n" +
2311 "Trailer: hack\r\n" +
2312 "Transfer-Encoding: chunked\r\n" +
2313 "\r\n" +
2314 "3\r\n" +
2315 "hax\r\n" +
2316 "0\r\n" +
2317 "I'm not a valid trailer\r\n" +
2318 "GET /secret HTTP/1.1\r\n" +
2319 "Host: test\r\n" +
2320 "\r\n")
2321
2322 conn.closec = make(chan bool, 1)
2323 ln := &oneConnListener{conn}
2324 var numReqs int
2325 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2326 numReqs++
2327 if strings.Contains(req.URL.Path, "secret") {
2328 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2329 }
2330 handler.f(req.Body)
2331 }))
2332 <-conn.closec
2333 if numReqs != 1 {
2334 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2335 }
2336 }
2337 }
2338
2339
2340
2341
2342 type slowTestConn struct {
2343
2344 script []any
2345 closec chan bool
2346
2347 mu sync.Mutex
2348 rd, wd time.Time
2349 noopConn
2350 }
2351
2352 func (c *slowTestConn) SetDeadline(t time.Time) error {
2353 c.SetReadDeadline(t)
2354 c.SetWriteDeadline(t)
2355 return nil
2356 }
2357
2358 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2359 c.mu.Lock()
2360 defer c.mu.Unlock()
2361 c.rd = t
2362 return nil
2363 }
2364
2365 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2366 c.mu.Lock()
2367 defer c.mu.Unlock()
2368 c.wd = t
2369 return nil
2370 }
2371
2372 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2373 c.mu.Lock()
2374 defer c.mu.Unlock()
2375 restart:
2376 if !c.rd.IsZero() && time.Now().After(c.rd) {
2377 return 0, syscall.ETIMEDOUT
2378 }
2379 if len(c.script) == 0 {
2380 return 0, io.EOF
2381 }
2382
2383 switch cue := c.script[0].(type) {
2384 case time.Duration:
2385 if !c.rd.IsZero() {
2386
2387
2388 if remaining := time.Until(c.rd); remaining < cue {
2389 c.script[0] = cue - remaining
2390 time.Sleep(remaining)
2391 return 0, syscall.ETIMEDOUT
2392 }
2393 }
2394 c.script = c.script[1:]
2395 time.Sleep(cue)
2396 goto restart
2397
2398 case string:
2399 n = copy(b, cue)
2400
2401 if len(cue) > n {
2402 c.script[0] = cue[n:]
2403 } else {
2404 c.script = c.script[1:]
2405 }
2406
2407 default:
2408 panic("unknown cue in slowTestConn script")
2409 }
2410
2411 return
2412 }
2413
2414 func (c *slowTestConn) Close() error {
2415 select {
2416 case c.closec <- true:
2417 default:
2418 }
2419 return nil
2420 }
2421
2422 func (c *slowTestConn) Write(b []byte) (int, error) {
2423 if !c.wd.IsZero() && time.Now().After(c.wd) {
2424 return 0, syscall.ETIMEDOUT
2425 }
2426 return len(b), nil
2427 }
2428
2429 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2430 if testing.Short() {
2431 t.Skip("skipping in -short mode")
2432 }
2433 defer afterTest(t)
2434 for _, handler := range testHandlerBodyConsumers {
2435 conn := &slowTestConn{
2436 script: []any{
2437 "POST /public HTTP/1.1\r\n" +
2438 "Host: test\r\n" +
2439 "Content-Length: 10000\r\n" +
2440 "\r\n",
2441 "foo bar baz",
2442 600 * time.Millisecond,
2443 "GET /secret HTTP/1.1\r\n" +
2444 "Host: test\r\n" +
2445 "\r\n",
2446 },
2447 closec: make(chan bool, 1),
2448 }
2449 ls := &oneConnListener{conn}
2450
2451 var numReqs int
2452 s := Server{
2453 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2454 numReqs++
2455 if strings.Contains(req.URL.Path, "secret") {
2456 t.Error("Request for /secret encountered, should not have happened.")
2457 }
2458 handler.f(req.Body)
2459 }),
2460 ReadTimeout: 400 * time.Millisecond,
2461 }
2462 go s.Serve(ls)
2463 <-conn.closec
2464
2465 if numReqs != 1 {
2466 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2467 }
2468 }
2469 }
2470
2471
// cancelableTimeoutContext wraps a context so that, once the wrapped
// context reports any error, Err reports context.DeadlineExceeded instead.
// Tests use it to trigger a "timeout" deterministically by calling cancel.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the embedded context has no error. Otherwise it
// returns context.DeadlineExceeded, whatever the underlying cause.
func (c cancelableTimeoutContext) Err() error {
	if err := c.Context.Err(); err == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2482
2483 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2484 func testTimeoutHandler(t *testing.T, mode testMode) {
2485 sendHi := make(chan bool, 1)
2486 writeErrors := make(chan error, 1)
2487 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2488 <-sendHi
2489 _, werr := w.Write([]byte("hi"))
2490 writeErrors <- werr
2491 })
2492 ctx, cancel := context.WithCancel(context.Background())
2493 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2494 cst := newClientServerTest(t, mode, h)
2495
2496
2497 sendHi <- true
2498 res, err := cst.c.Get(cst.ts.URL)
2499 if err != nil {
2500 t.Error(err)
2501 }
2502 if g, e := res.StatusCode, StatusOK; g != e {
2503 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2504 }
2505 body, _ := io.ReadAll(res.Body)
2506 if g, e := string(body), "hi"; g != e {
2507 t.Errorf("got body %q; expected %q", g, e)
2508 }
2509 if g := <-writeErrors; g != nil {
2510 t.Errorf("got unexpected Write error on first request: %v", g)
2511 }
2512
2513
2514 cancel()
2515
2516 res, err = cst.c.Get(cst.ts.URL)
2517 if err != nil {
2518 t.Error(err)
2519 }
2520 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2521 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2522 }
2523 body, _ = io.ReadAll(res.Body)
2524 if !strings.Contains(string(body), "<title>Timeout</title>") {
2525 t.Errorf("expected timeout body; got %q", string(body))
2526 }
2527 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2528 t.Errorf("response content-type = %q; want %q", g, w)
2529 }
2530
2531
2532
2533 sendHi <- true
2534 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2535 t.Errorf("expected Write error of %v; got %v", e, g)
2536 }
2537 }
2538
2539
2540 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2541 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2542 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2543 ms, _ := strconv.Atoi(r.URL.Path[1:])
2544 if ms == 0 {
2545 ms = 1
2546 }
2547 for i := 0; i < ms; i++ {
2548 w.Write([]byte("hi"))
2549 time.Sleep(time.Millisecond)
2550 }
2551 })
2552
2553 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2554
2555 c := ts.Client()
2556
2557 var wg sync.WaitGroup
2558 gate := make(chan bool, 10)
2559 n := 50
2560 if testing.Short() {
2561 n = 10
2562 gate = make(chan bool, 3)
2563 }
2564 for i := 0; i < n; i++ {
2565 gate <- true
2566 wg.Add(1)
2567 go func() {
2568 defer wg.Done()
2569 defer func() { <-gate }()
2570 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2571 if err == nil {
2572 io.Copy(io.Discard, res.Body)
2573 res.Body.Close()
2574 }
2575 }()
2576 }
2577 wg.Wait()
2578 }
2579
2580
2581
2582 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2583 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2584 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2585 w.WriteHeader(204)
2586 })
2587
2588 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2589
2590 var wg sync.WaitGroup
2591 gate := make(chan bool, 50)
2592 n := 500
2593 if testing.Short() {
2594 n = 10
2595 }
2596
2597 c := ts.Client()
2598 for i := 0; i < n; i++ {
2599 gate <- true
2600 wg.Add(1)
2601 go func() {
2602 defer wg.Done()
2603 defer func() { <-gate }()
2604 res, err := c.Get(ts.URL)
2605 if err != nil {
2606
2607
2608 t.Log(err)
2609 return
2610 }
2611 defer res.Body.Close()
2612 io.Copy(io.Discard, res.Body)
2613 }()
2614 }
2615 wg.Wait()
2616 }
2617
2618
2619 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2620 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2621 sendHi := make(chan bool, 1)
2622 writeErrors := make(chan error, 1)
2623 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2624 w.Header().Set("Content-Type", "text/plain")
2625 <-sendHi
2626 _, werr := w.Write([]byte("hi"))
2627 writeErrors <- werr
2628 })
2629 ctx, cancel := context.WithCancel(context.Background())
2630 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2631 cst := newClientServerTest(t, mode, h)
2632
2633
2634 sendHi <- true
2635 res, err := cst.c.Get(cst.ts.URL)
2636 if err != nil {
2637 t.Error(err)
2638 }
2639 if g, e := res.StatusCode, StatusOK; g != e {
2640 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2641 }
2642 body, _ := io.ReadAll(res.Body)
2643 if g, e := string(body), "hi"; g != e {
2644 t.Errorf("got body %q; expected %q", g, e)
2645 }
2646 if g := <-writeErrors; g != nil {
2647 t.Errorf("got unexpected Write error on first request: %v", g)
2648 }
2649
2650
2651 cancel()
2652
2653 res, err = cst.c.Get(cst.ts.URL)
2654 if err != nil {
2655 t.Error(err)
2656 }
2657 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2658 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2659 }
2660 body, _ = io.ReadAll(res.Body)
2661 if !strings.Contains(string(body), "<title>Timeout</title>") {
2662 t.Errorf("expected timeout body; got %q", string(body))
2663 }
2664
2665
2666
2667 sendHi <- true
2668 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2669 t.Errorf("expected Write error of %v; got %v", e, g)
2670 }
2671 }
2672
2673
2674 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2675 run(t, testTimeoutHandlerStartTimerWhenServing)
2676 }
2677 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2678 if testing.Short() {
2679 t.Skip("skipping sleeping test in -short mode")
2680 }
2681 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2682 w.WriteHeader(StatusNoContent)
2683 }
2684 timeout := 300 * time.Millisecond
2685 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2686 defer ts.Close()
2687
2688 c := ts.Client()
2689
2690
2691
2692
2693 time.Sleep(2 * timeout)
2694 res, err := c.Get(ts.URL)
2695 if err != nil {
2696 t.Fatal(err)
2697 }
2698 defer res.Body.Close()
2699 if res.StatusCode != StatusNoContent {
2700 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2701 }
2702 }
2703
2704 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2705 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2706 writeErrors := make(chan error, 1)
2707 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2708 w.Header().Set("Content-Type", "text/plain")
2709 var err error
2710
2711
2712
2713 for i := 0; i < 100; i++ {
2714 _, err = w.Write([]byte("a"))
2715 if err != nil {
2716 break
2717 }
2718 time.Sleep(1 * time.Millisecond)
2719 }
2720 writeErrors <- err
2721 })
2722 ctx, cancel := context.WithCancel(context.Background())
2723 cancel()
2724 h := NewTestTimeoutHandler(sayHi, ctx)
2725 cst := newClientServerTest(t, mode, h)
2726 defer cst.close()
2727
2728 res, err := cst.c.Get(cst.ts.URL)
2729 if err != nil {
2730 t.Error(err)
2731 }
2732 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2733 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2734 }
2735 body, _ := io.ReadAll(res.Body)
2736 if g, e := string(body), ""; g != e {
2737 t.Errorf("got body %q; expected %q", g, e)
2738 }
2739 if g, e := <-writeErrors, context.Canceled; g != e {
2740 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2741 }
2742 }
2743
2744
2745 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2746 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2747 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2748
2749 }
2750 timeout := 300 * time.Millisecond
2751 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2752
2753 c := ts.Client()
2754
2755 res, err := c.Get(ts.URL)
2756 if err != nil {
2757 t.Fatal(err)
2758 }
2759 defer res.Body.Close()
2760 if res.StatusCode != StatusOK {
2761 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2762 }
2763 }
2764
2765
2766 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2767 wrapper := func(h Handler) Handler {
2768 return TimeoutHandler(h, time.Second, "")
2769 }
2770 run(t, func(t *testing.T, mode testMode) {
2771 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2772 }, testNotParallel)
2773 }
2774
// TestRedirectBadPath calls Redirect for a request whose URL path is
// non-empty but lacks a leading slash. It checks that the requested status
// code is still written.
func TestRedirectBadPath(t *testing.T) {
	const code = 304
	req := &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", code)
	if rec.Code != code {
		t.Errorf("Code = %d; want 304", rec.Code)
	}
}
2791
2792
// TestRedirect checks the status and Location header Redirect produces for
// a variety of targets, relative to a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		in, want string
	}{
		// Absolute URLs, whatever the scheme, are passed through.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// Scheme-relative URLs are kept as is.
		{"//foobar.com/baz", "//foobar.com/baz"},
		// Absolute paths are kept.
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths are resolved against the request path /qux/.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// Three leading slashes collapse to one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// URLs inside the query string are left untouched.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII path bytes are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	const code = 302
	for _, c := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, c.in, code)
		if rec.Code != code {
			t.Errorf("Redirect(%q) generated status code %v; want %v", c.in, rec.Code, code)
		}
		if loc := rec.Header().Get("Location"); loc != c.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", c.in, loc, c.want)
		}
	}
}
2837
2838
2839
// TestRedirectContentTypeAndBody checks the Content-Type and body Redirect
// writes for each method. It also checks what happens when a Content-Type
// entry (possibly empty or nil) is already present in the response headers.
func TestRedirectContentTypeAndBody(t *testing.T) {
	// ctHeader, when non-nil, is installed verbatim as the recorder's
	// Content-Type values before calling Redirect.
	type ctHeader struct {
		Values []string
	}

	const htmlCT = "text/html; charset=utf-8"
	var tests = []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		// GET and HEAD get an HTML content type; only GET gets a body.
		{MethodGet, nil, htmlCT, "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, htmlCT, ""},
		// Other methods get neither.
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		// A preset Content-Type entry suppresses the body.
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tt := range tests {
		req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)
		rec := httptest.NewRecorder()
		if ct := tt.ct; ct != nil {
			rec.Header()["Content-Type"] = ct.Values
		}
		Redirect(rec, req, "/foo", 302)

		if rec.Code != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, rec.Code, 302)
		}
		if got := rec.Header().Get("Content-Type"); got != tt.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, tt.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tt.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, tt.wantBody)
		}
	}
}
2883
2884
2885
2886
2887
2888
2889
2890 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2891
2892 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2893 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2894 all, err := io.ReadAll(r.Body)
2895 if err != nil {
2896 t.Fatalf("handler ReadAll: %v", err)
2897 }
2898 if len(all) != 0 {
2899 t.Errorf("handler got %d bytes; expected 0", len(all))
2900 }
2901 rw.Header().Set("Content-Length", "0")
2902 }))
2903
2904 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2905 if err != nil {
2906 t.Fatal(err)
2907 }
2908 req.ContentLength = 0
2909
2910 var resp [5]*Response
2911 for i := range resp {
2912 resp[i], err = cst.c.Do(req)
2913 if err != nil {
2914 t.Fatalf("client post #%d: %v", i, err)
2915 }
2916 }
2917
2918 for i := range resp {
2919 all, err := io.ReadAll(resp[i].Body)
2920 if err != nil {
2921 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2922 }
2923 if len(all) != 0 {
2924 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2925 }
2926 }
2927 }
2928
2929 func TestHandlerPanicNil(t *testing.T) {
2930 run(t, func(t *testing.T, mode testMode) {
2931 testHandlerPanic(t, false, mode, nil, nil)
2932 }, testNotParallel)
2933 }
2934
2935 func TestHandlerPanic(t *testing.T) {
2936 run(t, func(t *testing.T, mode testMode) {
2937 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2938 }, testNotParallel)
2939 }
2940
2941 func TestHandlerPanicWithHijack(t *testing.T) {
2942
2943 run(t, func(t *testing.T, mode testMode) {
2944 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2945 }, []testMode{http1Mode})
2946 }
2947
2948 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
2949
2950
2951
2952
2953
2954
2955
2956
2957 pr, pw := io.Pipe()
2958 defer pw.Close()
2959
2960 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2961 if withHijack {
2962 rwc, _, err := w.(Hijacker).Hijack()
2963 if err != nil {
2964 t.Logf("unexpected error: %v", err)
2965 }
2966 defer rwc.Close()
2967 }
2968 panic(panicValue)
2969 })
2970 if wrapper != nil {
2971 handler = wrapper(handler)
2972 }
2973 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
2974 ts.Config.ErrorLog = log.New(pw, "", 0)
2975 })
2976
2977
2978 done := make(chan bool, 1)
2979 go func() {
2980 buf := make([]byte, 4<<10)
2981 _, err := pr.Read(buf)
2982 pr.Close()
2983 if err != nil && err != io.EOF {
2984 t.Error(err)
2985 }
2986 done <- true
2987 }()
2988
2989 _, err := cst.c.Get(cst.ts.URL)
2990 if err == nil {
2991 t.Logf("expected an error")
2992 }
2993
2994 if panicValue == nil {
2995 return
2996 }
2997
2998 <-done
2999 }
3000
// terrorWriter is an io.Writer that turns every write into a test failure,
// calling t.Errorf with the written bytes as the message. Installed as a
// server's ErrorLog output, it makes any logged message fail the test.
type terrorWriter struct{ t *testing.T }

// Write reports p as a test error and claims the whole slice was written.
func (w terrorWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.t.Errorf("%s", p)
	return n, nil
}
3007
3008
3009
3010 func TestServerWriteHijackZeroBytes(t *testing.T) {
3011 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3012 }
3013 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3014 done := make(chan struct{})
3015 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3016 defer close(done)
3017 w.(Flusher).Flush()
3018 conn, _, err := w.(Hijacker).Hijack()
3019 if err != nil {
3020 t.Errorf("Hijack: %v", err)
3021 return
3022 }
3023 defer conn.Close()
3024 _, err = w.Write(nil)
3025 if err != ErrHijacked {
3026 t.Errorf("Write error = %v; want ErrHijacked", err)
3027 }
3028 }), func(ts *httptest.Server) {
3029 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3030 }).ts
3031
3032 c := ts.Client()
3033 res, err := c.Get(ts.URL)
3034 if err != nil {
3035 t.Fatal(err)
3036 }
3037 res.Body.Close()
3038 <-done
3039 }
3040
3041 func TestServerNoDate(t *testing.T) {
3042 run(t, func(t *testing.T, mode testMode) {
3043 testServerNoHeader(t, mode, "Date")
3044 })
3045 }
3046
3047 func TestServerContentType(t *testing.T) {
3048 run(t, func(t *testing.T, mode testMode) {
3049 testServerNoHeader(t, mode, "Content-Type")
3050 })
3051 }
3052
3053 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3054 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3055 w.Header()[header] = nil
3056 io.WriteString(w, "<html>foo</html>")
3057 }))
3058 res, err := cst.c.Get(cst.ts.URL)
3059 if err != nil {
3060 t.Fatal(err)
3061 }
3062 res.Body.Close()
3063 if got, ok := res.Header[header]; ok {
3064 t.Fatalf("Expected no %s header; got %q", header, got)
3065 }
3066 }
3067
3068 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3069 func testStripPrefix(t *testing.T, mode testMode) {
3070 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3071 w.Header().Set("X-Path", r.URL.Path)
3072 w.Header().Set("X-RawPath", r.URL.RawPath)
3073 })
3074 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3075
3076 c := ts.Client()
3077
3078 cases := []struct {
3079 reqPath string
3080 path string
3081 rawPath string
3082 }{
3083 {"/foo/bar/qux", "/qux", ""},
3084 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3085 {"/foo%2Fbar/qux", "", ""},
3086 {"/bar", "", ""},
3087 }
3088 for _, tc := range cases {
3089 t.Run(tc.reqPath, func(t *testing.T) {
3090 res, err := c.Get(ts.URL + tc.reqPath)
3091 if err != nil {
3092 t.Fatal(err)
3093 }
3094 res.Body.Close()
3095 if tc.path == "" {
3096 if res.StatusCode != StatusNotFound {
3097 t.Errorf("got %q, want 404 Not Found", res.Status)
3098 }
3099 return
3100 }
3101 if res.StatusCode != StatusOK {
3102 t.Fatalf("got %q, want 200 OK", res.Status)
3103 }
3104 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3105 t.Errorf("got Path %q, want %q", g, w)
3106 }
3107 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3108 t.Errorf("got RawPath %q, want %q", g, w)
3109 }
3110 })
3111 }
3112 }
3113
3114
// TestStripPrefixNotModifyRequest checks that serving a request through
// StripPrefix leaves the caller's Request.URL.Path unchanged.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	const orig = "/foo/bar"
	req := httptest.NewRequest("GET", orig, nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if req.URL.Path != orig {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3123
3124 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3125 func testRequestLimit(t *testing.T, mode testMode) {
3126 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3127 t.Fatalf("didn't expect to get request in Handler")
3128 }), optQuietLog)
3129 req, _ := NewRequest("GET", cst.ts.URL, nil)
3130 var bytesPerHeader = len("header12345: val12345\r\n")
3131 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3132 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3133 }
3134 res, err := cst.c.Do(req)
3135 if res != nil {
3136 defer res.Body.Close()
3137 }
3138 if mode == http2Mode {
3139
3140
3141
3142
3143 if err == nil && res.StatusCode != 431 {
3144 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3145 }
3146 } else {
3147
3148
3149
3150
3151 if err != nil {
3152 t.Fatalf("Do: %v", err)
3153 }
3154 if res.StatusCode != 431 {
3155 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3156 }
3157 }
3158 }
3159
// neverEnding is an io.Reader that never runs dry. Every Read fills the
// whole buffer with the receiver's byte value and returns a nil error.
type neverEnding byte

// Read fills p entirely with byte(b) and reports len(p) bytes read.
func (b neverEnding) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Seed the first byte, then keep doubling the filled prefix with copy.
	p[0] = byte(b)
	for filled := 1; filled < len(p); filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return len(p), nil
}
3168
// bodyLimitReader is an endless body of 'a' bytes that counts how much has
// been read. Read fails after Close, or once more than limit bytes have
// been handed out. Read and Close serialize on mu.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int           // bytes returned so far; guarded by mu
	limit  int           // Read fails once count exceeds this
	closed chan struct{} // closed by Close
}

// Read fills p with 'a' and adds len(p) to count. After Close it fails
// with "closed"; if count already exceeds limit, it fails with "at limit".
// The limit is checked before the read, so count may overshoot it by up to
// one buffer.
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	var stop error
	select {
	case <-r.closed:
		stop = errors.New("closed")
	default:
		if r.count > r.limit {
			stop = errors.New("at limit")
		}
	}
	if stop != nil {
		return 0, stop
	}

	for i := range p {
		p[i] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close marks the reader closed so later Reads fail. The channel is closed
// unconditionally, so calling Close twice panics.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3200
3201 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3202 func testRequestBodyLimit(t *testing.T, mode testMode) {
3203 const limit = 1 << 20
3204 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3205 r.Body = MaxBytesReader(w, r.Body, limit)
3206 n, err := io.Copy(io.Discard, r.Body)
3207 if err == nil {
3208 t.Errorf("expected error from io.Copy")
3209 }
3210 if n != limit {
3211 t.Errorf("io.Copy = %d, want %d", n, limit)
3212 }
3213 mbErr, ok := err.(*MaxBytesError)
3214 if !ok {
3215 t.Errorf("expected MaxBytesError, got %T", err)
3216 }
3217 if mbErr.Limit != limit {
3218 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3219 }
3220 }))
3221
3222 body := &bodyLimitReader{
3223 closed: make(chan struct{}),
3224 limit: limit * 200,
3225 }
3226 req, _ := NewRequest("POST", cst.ts.URL, body)
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237 resp, err := cst.c.Do(req)
3238 if err == nil {
3239 resp.Body.Close()
3240 }
3241
3242
3243 <-body.closed
3244
3245 if body.count > limit*100 {
3246 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3247 limit, body.count)
3248 }
3249 }
3250
3251
3252
3253 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3254 func testClientWriteShutdown(t *testing.T, mode testMode) {
3255 if runtime.GOOS == "plan9" {
3256 t.Skip("skipping test; see https://golang.org/issue/17906")
3257 }
3258 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3259 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3260 if err != nil {
3261 t.Fatalf("Dial: %v", err)
3262 }
3263 err = conn.(*net.TCPConn).CloseWrite()
3264 if err != nil {
3265 t.Fatalf("CloseWrite: %v", err)
3266 }
3267
3268 bs, err := io.ReadAll(conn)
3269 if err != nil {
3270 t.Errorf("ReadAll: %v", err)
3271 }
3272 got := string(bs)
3273 if got != "" {
3274 t.Errorf("read %q from server; want nothing", got)
3275 }
3276 }
3277
3278
3279
3280 func TestServerBufferedChunking(t *testing.T) {
3281 conn := new(testConn)
3282 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3283 conn.closec = make(chan bool, 1)
3284 ls := &oneConnListener{conn}
3285 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3286 rw.(Flusher).Flush()
3287 rw.Write([]byte{'x'})
3288 rw.Write([]byte{'y'})
3289 rw.Write([]byte{'z'})
3290 }))
3291 <-conn.closec
3292 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3293 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3294 conn.writeBuf.Bytes())
3295 }
3296 }
3297
3298
3299
3300
3301
3302 func TestServerGracefulClose(t *testing.T) {
3303
3304 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3305 }
3306 func testServerGracefulClose(t *testing.T, mode testMode) {
3307 runTimeSensitiveTest(t, []time.Duration{
3308 1 * time.Millisecond,
3309 5 * time.Millisecond,
3310 10 * time.Millisecond,
3311 50 * time.Millisecond,
3312 100 * time.Millisecond,
3313 500 * time.Millisecond,
3314 time.Second,
3315 5 * time.Second,
3316 }, func(t *testing.T, timeout time.Duration) error {
3317 SetRSTAvoidanceDelay(t, timeout)
3318 t.Logf("set RST avoidance delay to %v", timeout)
3319
3320 const bodySize = 5 << 20
3321 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3322 for i := 0; i < bodySize; i++ {
3323 req = append(req, 'x')
3324 }
3325
3326 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3327 Error(w, "bye", StatusUnauthorized)
3328 }))
3329
3330
3331 defer cst.close()
3332 ts := cst.ts
3333
3334 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3335 if err != nil {
3336 return err
3337 }
3338 writeErr := make(chan error)
3339 go func() {
3340 _, err := conn.Write(req)
3341 writeErr <- err
3342 }()
3343 defer func() {
3344 conn.Close()
3345
3346
3347
3348 <-writeErr
3349 }()
3350
3351 br := bufio.NewReader(conn)
3352 lineNum := 0
3353 for {
3354 line, err := br.ReadString('\n')
3355 if err == io.EOF {
3356 break
3357 }
3358 if err != nil {
3359 return fmt.Errorf("ReadLine: %v", err)
3360 }
3361 lineNum++
3362 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3363 t.Errorf("Response line = %q; want a 401", line)
3364 }
3365 }
3366 return nil
3367 })
3368 }
3369
3370 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3371 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3372 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3373 if r.Method != "get" {
3374 t.Errorf(`Got method %q; want "get"`, r.Method)
3375 }
3376 }))
3377 defer cst.close()
3378 req, _ := NewRequest("get", cst.ts.URL, nil)
3379 res, err := cst.c.Do(req)
3380 if err != nil {
3381 t.Error(err)
3382 return
3383 }
3384
3385 res.Body.Close()
3386 }
3387
3388
3389
3390
3391
3392 func TestContentLengthZero(t *testing.T) {
3393 run(t, testContentLengthZero, []testMode{http1Mode})
3394 }
3395 func testContentLengthZero(t *testing.T, mode testMode) {
3396 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3397
3398 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3399 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3400 if err != nil {
3401 t.Fatalf("error dialing: %v", err)
3402 }
3403 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3404 if err != nil {
3405 t.Fatalf("error writing: %v", err)
3406 }
3407 req, _ := NewRequest("GET", "/", nil)
3408 res, err := ReadResponse(bufio.NewReader(conn), req)
3409 if err != nil {
3410 t.Fatalf("error reading response: %v", err)
3411 }
3412 if te := res.TransferEncoding; len(te) > 0 {
3413 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3414 }
3415 if cl := res.ContentLength; cl != 0 {
3416 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3417 }
3418 conn.Close()
3419 }
3420 }
3421
3422 func TestCloseNotifier(t *testing.T) {
3423 run(t, testCloseNotifier, []testMode{http1Mode})
3424 }
3425 func testCloseNotifier(t *testing.T, mode testMode) {
3426 gotReq := make(chan bool, 1)
3427 sawClose := make(chan bool, 1)
3428 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3429 gotReq <- true
3430 cc := rw.(CloseNotifier).CloseNotify()
3431 <-cc
3432 sawClose <- true
3433 })).ts
3434 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3435 if err != nil {
3436 t.Fatalf("error dialing: %v", err)
3437 }
3438 diec := make(chan bool)
3439 go func() {
3440 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3441 if err != nil {
3442 t.Error(err)
3443 return
3444 }
3445 <-diec
3446 conn.Close()
3447 }()
3448 For:
3449 for {
3450 select {
3451 case <-gotReq:
3452 diec <- true
3453 case <-sawClose:
3454 break For
3455 }
3456 }
3457 ts.Close()
3458 }
3459
3460
3461
3462
3463
3464 func TestCloseNotifierPipelined(t *testing.T) {
3465 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3466 }
3467 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3468 gotReq := make(chan bool, 2)
3469 sawClose := make(chan bool, 2)
3470 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3471 gotReq <- true
3472 cc := rw.(CloseNotifier).CloseNotify()
3473 select {
3474 case <-cc:
3475 t.Error("unexpected CloseNotify")
3476 case <-time.After(100 * time.Millisecond):
3477 }
3478 sawClose <- true
3479 })).ts
3480 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3481 if err != nil {
3482 t.Fatalf("error dialing: %v", err)
3483 }
3484 diec := make(chan bool, 1)
3485 defer close(diec)
3486 go func() {
3487 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3488 _, err = io.WriteString(conn, req+req)
3489 if err != nil {
3490 t.Error(err)
3491 return
3492 }
3493 <-diec
3494 conn.Close()
3495 }()
3496 reqs := 0
3497 closes := 0
3498 for {
3499 select {
3500 case <-gotReq:
3501 reqs++
3502 if reqs > 2 {
3503 t.Fatal("too many requests")
3504 }
3505 case <-sawClose:
3506 closes++
3507 if closes > 1 {
3508 return
3509 }
3510 }
3511 }
3512 }
3513
3514 func TestCloseNotifierChanLeak(t *testing.T) {
3515 defer afterTest(t)
3516 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3517 for i := 0; i < 20; i++ {
3518 var output bytes.Buffer
3519 conn := &rwTestConn{
3520 Reader: bytes.NewReader(req),
3521 Writer: &output,
3522 closec: make(chan bool, 1),
3523 }
3524 ln := &oneConnListener{conn: conn}
3525 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3526
3527
3528
3529 _ = rw.(CloseNotifier).CloseNotify()
3530 })
3531 go Serve(ln, handler)
3532 <-conn.closec
3533 }
3534 }
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545 func TestHijackAfterCloseNotifier(t *testing.T) {
3546 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3547 }
3548 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3549 script := make(chan string, 2)
3550 script <- "closenotify"
3551 script <- "hijack"
3552 close(script)
3553 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3554 plan := <-script
3555 switch plan {
3556 default:
3557 panic("bogus plan; too many requests")
3558 case "closenotify":
3559 w.(CloseNotifier).CloseNotify()
3560 w.Header().Set("X-Addr", r.RemoteAddr)
3561 case "hijack":
3562 c, _, err := w.(Hijacker).Hijack()
3563 if err != nil {
3564 t.Errorf("Hijack in Handler: %v", err)
3565 return
3566 }
3567 if _, ok := c.(*net.TCPConn); !ok {
3568
3569
3570 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3571 }
3572 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3573 c.Close()
3574 return
3575 }
3576 })).ts
3577 res1, err := ts.Client().Get(ts.URL)
3578 if err != nil {
3579 log.Fatal(err)
3580 }
3581 res2, err := ts.Client().Get(ts.URL)
3582 if err != nil {
3583 log.Fatal(err)
3584 }
3585 addr1 := res1.Header.Get("X-Addr")
3586 addr2 := res2.Header.Get("X-Addr")
3587 if addr1 == "" || addr1 != addr2 {
3588 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3589 }
3590 }
3591
3592 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3593 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3594 }
3595 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3596 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3597 bodyOkay := make(chan bool, 1)
3598 gotCloseNotify := make(chan bool, 1)
3599 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3600 defer close(bodyOkay)
3601
3602 reqBody := r.Body
3603 r.Body = nil
3604
3605 gone := w.(CloseNotifier).CloseNotify()
3606 slurp, err := io.ReadAll(reqBody)
3607 if err != nil {
3608 t.Errorf("Body read: %v", err)
3609 return
3610 }
3611 if len(slurp) != len(requestBody) {
3612 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3613 return
3614 }
3615 if !bytes.Equal(slurp, requestBody) {
3616 t.Error("Backend read wrong request body.")
3617 return
3618 }
3619 bodyOkay <- true
3620 <-gone
3621 gotCloseNotify <- true
3622 })).ts
3623
3624 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3625 if err != nil {
3626 t.Fatal(err)
3627 }
3628 defer conn.Close()
3629
3630 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3631 len(requestBody), requestBody)
3632 if !<-bodyOkay {
3633
3634 return
3635 }
3636 conn.Close()
3637 <-gotCloseNotify
3638 }
3639
3640 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3641 func testOptions(t *testing.T, mode testMode) {
3642 uric := make(chan string, 2)
3643 mux := NewServeMux()
3644 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3645 uric <- r.RequestURI
3646 })
3647 ts := newClientServerTest(t, mode, mux).ts
3648
3649 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3650 if err != nil {
3651 t.Fatal(err)
3652 }
3653 defer conn.Close()
3654
3655
3656 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3657 if err != nil {
3658 t.Fatal(err)
3659 }
3660 br := bufio.NewReader(conn)
3661 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3662 if err != nil {
3663 t.Fatal(err)
3664 }
3665 if res.StatusCode != 200 {
3666 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3667 }
3668
3669
3670 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3671 if err != nil {
3672 t.Fatal(err)
3673 }
3674 res, err = ReadResponse(br, &Request{Method: "GET"})
3675 if err != nil {
3676 t.Fatal(err)
3677 }
3678 if res.StatusCode != 400 {
3679 t.Errorf("Got non-400 response to GET *: %#v", res)
3680 }
3681
3682 res, err = Get(ts.URL + "/second")
3683 if err != nil {
3684 t.Fatal(err)
3685 }
3686 res.Body.Close()
3687 if got := <-uric; got != "/second" {
3688 t.Errorf("Handler saw request for %q; want /second", got)
3689 }
3690 }
3691
3692 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3693 func testOptionsHandler(t *testing.T, mode testMode) {
3694 rc := make(chan *Request, 1)
3695
3696 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3697 rc <- r
3698 }), func(ts *httptest.Server) {
3699 ts.Config.DisableGeneralOptionsHandler = true
3700 }).ts
3701
3702 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3703 if err != nil {
3704 t.Fatal(err)
3705 }
3706 defer conn.Close()
3707
3708 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3709 if err != nil {
3710 t.Fatal(err)
3711 }
3712
3713 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3714 t.Errorf("Expected OPTIONS * request, got %v", got)
3715 }
3716 }
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727 func TestHeaderToWire(t *testing.T) {
3728 tests := []struct {
3729 name string
3730 handler func(ResponseWriter, *Request)
3731 check func(got, logs string) error
3732 }{
3733 {
3734 name: "write without Header",
3735 handler: func(rw ResponseWriter, r *Request) {
3736 rw.Write([]byte("hello world"))
3737 },
3738 check: func(got, logs string) error {
3739 if !strings.Contains(got, "Content-Length:") {
3740 return errors.New("no content-length")
3741 }
3742 if !strings.Contains(got, "Content-Type: text/plain") {
3743 return errors.New("no content-type")
3744 }
3745 return nil
3746 },
3747 },
3748 {
3749 name: "Header mutation before write",
3750 handler: func(rw ResponseWriter, r *Request) {
3751 h := rw.Header()
3752 h.Set("Content-Type", "some/type")
3753 rw.Write([]byte("hello world"))
3754 h.Set("Too-Late", "bogus")
3755 },
3756 check: func(got, logs string) error {
3757 if !strings.Contains(got, "Content-Length:") {
3758 return errors.New("no content-length")
3759 }
3760 if !strings.Contains(got, "Content-Type: some/type") {
3761 return errors.New("wrong content-type")
3762 }
3763 if strings.Contains(got, "Too-Late") {
3764 return errors.New("don't want too-late header")
3765 }
3766 return nil
3767 },
3768 },
3769 {
3770 name: "write then useless Header mutation",
3771 handler: func(rw ResponseWriter, r *Request) {
3772 rw.Write([]byte("hello world"))
3773 rw.Header().Set("Too-Late", "Write already wrote headers")
3774 },
3775 check: func(got, logs string) error {
3776 if strings.Contains(got, "Too-Late") {
3777 return errors.New("header appeared from after WriteHeader")
3778 }
3779 return nil
3780 },
3781 },
3782 {
3783 name: "flush then write",
3784 handler: func(rw ResponseWriter, r *Request) {
3785 rw.(Flusher).Flush()
3786 rw.Write([]byte("post-flush"))
3787 rw.Header().Set("Too-Late", "Write already wrote headers")
3788 },
3789 check: func(got, logs string) error {
3790 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3791 return errors.New("not chunked")
3792 }
3793 if strings.Contains(got, "Too-Late") {
3794 return errors.New("header appeared from after WriteHeader")
3795 }
3796 return nil
3797 },
3798 },
3799 {
3800 name: "header then flush",
3801 handler: func(rw ResponseWriter, r *Request) {
3802 rw.Header().Set("Content-Type", "some/type")
3803 rw.(Flusher).Flush()
3804 rw.Write([]byte("post-flush"))
3805 rw.Header().Set("Too-Late", "Write already wrote headers")
3806 },
3807 check: func(got, logs string) error {
3808 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3809 return errors.New("not chunked")
3810 }
3811 if strings.Contains(got, "Too-Late") {
3812 return errors.New("header appeared from after WriteHeader")
3813 }
3814 if !strings.Contains(got, "Content-Type: some/type") {
3815 return errors.New("wrong content-type")
3816 }
3817 return nil
3818 },
3819 },
3820 {
3821 name: "sniff-on-first-write content-type",
3822 handler: func(rw ResponseWriter, r *Request) {
3823 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3824 rw.Header().Set("Content-Type", "x/wrong")
3825 },
3826 check: func(got, logs string) error {
3827 if !strings.Contains(got, "Content-Type: text/html") {
3828 return errors.New("wrong content-type; want html")
3829 }
3830 return nil
3831 },
3832 },
3833 {
3834 name: "explicit content-type wins",
3835 handler: func(rw ResponseWriter, r *Request) {
3836 rw.Header().Set("Content-Type", "some/type")
3837 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3838 },
3839 check: func(got, logs string) error {
3840 if !strings.Contains(got, "Content-Type: some/type") {
3841 return errors.New("wrong content-type; want html")
3842 }
3843 return nil
3844 },
3845 },
3846 {
3847 name: "empty handler",
3848 handler: func(rw ResponseWriter, r *Request) {
3849 },
3850 check: func(got, logs string) error {
3851 if !strings.Contains(got, "Content-Length: 0") {
3852 return errors.New("want 0 content-length")
3853 }
3854 return nil
3855 },
3856 },
3857 {
3858 name: "only Header, no write",
3859 handler: func(rw ResponseWriter, r *Request) {
3860 rw.Header().Set("Some-Header", "some-value")
3861 },
3862 check: func(got, logs string) error {
3863 if !strings.Contains(got, "Some-Header") {
3864 return errors.New("didn't get header")
3865 }
3866 return nil
3867 },
3868 },
3869 {
3870 name: "WriteHeader call",
3871 handler: func(rw ResponseWriter, r *Request) {
3872 rw.WriteHeader(404)
3873 rw.Header().Set("Too-Late", "some-value")
3874 },
3875 check: func(got, logs string) error {
3876 if !strings.Contains(got, "404") {
3877 return errors.New("wrong status")
3878 }
3879 if strings.Contains(got, "Too-Late") {
3880 return errors.New("shouldn't have seen Too-Late")
3881 }
3882 return nil
3883 },
3884 },
3885 }
3886 for _, tc := range tests {
3887 ht := newHandlerTest(HandlerFunc(tc.handler))
3888 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3889 logs := ht.logbuf.String()
3890 if err := tc.check(got, logs); err != nil {
3891 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3892 }
3893 }
3894 }
3895
3896 type errorListener struct {
3897 errs []error
3898 }
3899
3900 func (l *errorListener) Accept() (c net.Conn, err error) {
3901 if len(l.errs) == 0 {
3902 return nil, io.EOF
3903 }
3904 err = l.errs[0]
3905 l.errs = l.errs[1:]
3906 return
3907 }
3908
3909 func (l *errorListener) Close() error {
3910 return nil
3911 }
3912
3913 func (l *errorListener) Addr() net.Addr {
3914 return dummyAddr("test-address")
3915 }
3916
3917 func TestAcceptMaxFds(t *testing.T) {
3918 setParallel(t)
3919
3920 ln := &errorListener{[]error{
3921 &net.OpError{
3922 Op: "accept",
3923 Err: syscall.EMFILE,
3924 }}}
3925 server := &Server{
3926 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3927 ErrorLog: log.New(io.Discard, "", 0),
3928 }
3929 err := server.Serve(ln)
3930 if err != io.EOF {
3931 t.Errorf("got error %v, want EOF", err)
3932 }
3933 }
3934
3935 func TestWriteAfterHijack(t *testing.T) {
3936 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3937 var buf strings.Builder
3938 wrotec := make(chan bool, 1)
3939 conn := &rwTestConn{
3940 Reader: bytes.NewReader(req),
3941 Writer: &buf,
3942 closec: make(chan bool, 1),
3943 }
3944 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3945 conn, bufrw, err := rw.(Hijacker).Hijack()
3946 if err != nil {
3947 t.Error(err)
3948 return
3949 }
3950 go func() {
3951 bufrw.Write([]byte("[hijack-to-bufw]"))
3952 bufrw.Flush()
3953 conn.Write([]byte("[hijack-to-conn]"))
3954 conn.Close()
3955 wrotec <- true
3956 }()
3957 })
3958 ln := &oneConnListener{conn: conn}
3959 go Serve(ln, handler)
3960 <-conn.closec
3961 <-wrotec
3962 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3963 t.Errorf("wrote %q; want %q", g, w)
3964 }
3965 }
3966
3967 func TestDoubleHijack(t *testing.T) {
3968 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3969 var buf bytes.Buffer
3970 conn := &rwTestConn{
3971 Reader: bytes.NewReader(req),
3972 Writer: &buf,
3973 closec: make(chan bool, 1),
3974 }
3975 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3976 conn, _, err := rw.(Hijacker).Hijack()
3977 if err != nil {
3978 t.Error(err)
3979 return
3980 }
3981 _, _, err = rw.(Hijacker).Hijack()
3982 if err == nil {
3983 t.Errorf("got err = nil; want err != nil")
3984 }
3985 conn.Close()
3986 })
3987 ln := &oneConnListener{conn: conn}
3988 go Serve(ln, handler)
3989 <-conn.closec
3990 }
3991
3992
3993
3994
3995
3996
3997
3998 func TestHTTP10ConnectionHeader(t *testing.T) {
3999 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4000 }
4001 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4002 mux := NewServeMux()
4003 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4004 ts := newClientServerTest(t, mode, mux).ts
4005
4006
4007 tests := []struct {
4008 req string
4009 expect []string
4010 }{
4011 {
4012 req: "GET / HTTP/1.0\r\n\r\n",
4013 expect: nil,
4014 },
4015 {
4016 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4017 expect: nil,
4018 },
4019 {
4020 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4021 expect: []string{"keep-alive"},
4022 },
4023 }
4024
4025 for _, tt := range tests {
4026 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4027 if err != nil {
4028 t.Fatal("dial err:", err)
4029 }
4030
4031 _, err = fmt.Fprint(conn, tt.req)
4032 if err != nil {
4033 t.Fatal("conn write err:", err)
4034 }
4035
4036 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4037 if err != nil {
4038 t.Fatal("ReadResponse err:", err)
4039 }
4040 conn.Close()
4041 resp.Body.Close()
4042
4043 got := resp.Header["Connection"]
4044 if !slices.Equal(got, tt.expect) {
4045 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4046 }
4047 }
4048 }
4049
4050
// TestServerReaderFromOrder runs testServerReaderFromOrder with run's default
// set of modes.
func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }

// testServerReaderFromOrder has the handler start an io.Copy from a pipe
// into the ResponseWriter before it reads the request body.
//
// io.Copy uses the destination's ReadFrom method when one exists, so this
// presumably exercises the ResponseWriter's io.ReaderFrom path.
// NOTE(review): confirm which ResponseWriters implement io.ReaderFrom in
// each mode.
//
// The copy cannot finish until the handler has drained the whole 3 MiB
// request body and written "hi" into the pipe. The client must then
// receive exactly "hi".
func testServerReaderFromOrder(t *testing.T, mode testMode) {
	pr, pw := io.Pipe()
	const size = 3 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Type", "text/plain")
		done := make(chan bool)
		go func() {
			io.Copy(rw, pr)
			close(done)
		}()
		// Give the copy goroutine a head start, presumably so it is already
		// blocked on the pipe before the request body is read.
		time.Sleep(25 * time.Millisecond)
		n, err := io.Copy(io.Discard, req.Body)
		if err != nil {
			t.Errorf("handler Copy: %v", err)
			return
		}
		if n != size {
			t.Errorf("handler Copy = %d; want %d", n, size)
		}
		// Only now feed the pipe, end it, and wait for the copy to finish.
		pw.Write([]byte("hi"))
		pw.Close()
		<-done
	}))

	req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if string(all) != "hi" {
		t.Errorf("Body = %q; want hi", all)
	}
}
4093
4094
4095 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4096 for _, code := range []int{StatusNotModified, StatusNoContent} {
4097 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4098 if r.URL.Path == "/header" {
4099 w.Header().Set("Content-Length", "123")
4100 }
4101 w.WriteHeader(code)
4102 if r.URL.Path == "/more" {
4103 w.Write([]byte("stuff"))
4104 }
4105 }))
4106 for _, req := range []string{
4107 "GET / HTTP/1.0",
4108 "GET /header HTTP/1.0",
4109 "GET /more HTTP/1.0",
4110 "GET / HTTP/1.1\nHost: foo",
4111 "GET /header HTTP/1.1\nHost: foo",
4112 "GET /more HTTP/1.1\nHost: foo",
4113 } {
4114 got := ht.rawResponse(req)
4115 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4116 if !strings.Contains(got, wantStatus) {
4117 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4118 } else if strings.Contains(got, "Content-Length") {
4119 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4120 } else if strings.Contains(got, "stuff") {
4121 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4122 }
4123 }
4124 }
4125 }
4126
4127 func TestContentTypeOkayOn204(t *testing.T) {
4128 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4129 w.Header().Set("Content-Length", "123")
4130 w.Header().Set("Content-Type", "foo/bar")
4131 w.WriteHeader(204)
4132 }))
4133 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4134 if !strings.Contains(got, "Content-Type: foo/bar") {
4135 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4136 }
4137 if strings.Contains(got, "Content-Length: 123") {
4138 t.Errorf("Response = %q; don't want a Content-Length", got)
4139 }
4140 }
4141
4142
4143
4144
4145
4146
4147
// TestTransportAndServerSharedBodyRace runs a proxy that forwards its
// incoming request body to a backend and then cancels the backend request
// partway through. The same body is therefore used by the proxy's Server (as
// req.Body) and its Transport (as req2's body). NOTE(review): presumably
// this is meant to catch data races on that shared body under -race.
// It runs with testNotParallel.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// Each attempt sets the RST-avoidance delay to one of these durations.
	// NOTE(review): runTimeSensitiveTest presumably retries with the next,
	// longer duration when an attempt returns an error — confirm in its
	// definition.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg lets the deferred cleanup below wait for the backend handler
		// to return before the backend server is closed.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// wg.Add runs first in the handler, before it writes anything.
			// On the success path the proxy's Do cannot return a response
			// before this point. So the Add happens before the deferred
			// wg.Wait, which runs only after the original request has
			// completed.
			wg.Add(1)
			defer wg.Done()

			// Echo up to bodySize bytes of the (proxied) request body, then
			// stay in the handler until the request's context is done.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))
		// Wait for the backend handler before shutting the backend down.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Forward the incoming body as-is: this is the body shared
			// between the proxy's Server and its Transport.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the backend's response.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			// The backend response body is closed at test cleanup rather
			// than here, so it is still open when the request is cancelled.
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the in-flight backend request. HTTP/2 mode uses the
			// req2.Cancel channel; other modes use Transport.CancelRequest.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4237
4238
4239
4240
4241 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4242 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4243 }
4244 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4245 if testing.Short() {
4246 t.Skip("skipping in -short mode")
4247 }
4248
4249 readErrCh := make(chan error, 1)
4250 errCh := make(chan error, 2)
4251
4252 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4253 go func(body io.Reader) {
4254 _, err := body.Read(make([]byte, 100))
4255 readErrCh <- err
4256 }(req.Body)
4257 time.Sleep(500 * time.Millisecond)
4258 })).ts
4259
4260 closeConn := make(chan bool)
4261 defer close(closeConn)
4262 go func() {
4263 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4264 if err != nil {
4265 errCh <- err
4266 return
4267 }
4268 defer conn.Close()
4269 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4270 if err != nil {
4271 errCh <- err
4272 return
4273 }
4274
4275
4276 <-closeConn
4277 }()
4278 select {
4279 case err := <-readErrCh:
4280 if err == nil {
4281 t.Error("Read was nil. Expected error.")
4282 }
4283 case err := <-errCh:
4284 t.Error(err)
4285 }
4286 }
4287
4288
4289 func TestResponseWriterWriteString(t *testing.T) {
4290 okc := make(chan bool, 1)
4291 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4292 _, ok := w.(io.StringWriter)
4293 okc <- ok
4294 }))
4295 ht.rawResponse("GET / HTTP/1.0")
4296 select {
4297 case ok := <-okc:
4298 if !ok {
4299 t.Error("ResponseWriter did not implement io.StringWriter")
4300 }
4301 default:
4302 t.Error("handler was never called")
4303 }
4304 }
4305
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }

// testServerConnState drives connections through various lifecycles and
// checks the exact sequence of ConnState values reported for each one:
// keep-alive then close, hijack, hijack followed by a handler panic, a bare
// dial/close, a malformed request, and a raw request on a manual connection.
func testServerConnState(t *testing.T, mode testMode) {
	// handler maps request paths to the behavior under test.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the ConnState values seen for one connection.
	type stateLog struct {
		// active is the connection being tracked: the first conn seen in
		// StateNew.
		active net.Conn
		got    []ConnState
		want   []ConnState
		// complete, if non-nil, is closed once got has reached want's length
		// or is no longer a prefix of want.
		complete chan<- struct{}
	}
	// activeLog holds at most one *stateLog. Receiving takes exclusive
	// ownership of it and sending releases it, so the channel acts as a lock
	// between the ConnState hook and wantLog.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog, runs doRequests, and waits until the
	// hook signals completion. It then takes the log, leaving the channel
	// empty, and compares the recorded states with want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// activeLog is left empty here, so any further ConnState callbacks
		// block until the next wantLog (or the deferred cleanup) refills it.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			sl := <-activeLog
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// The length check short-circuits before slicing want, so the
			// slice expression never goes out of range.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	// Refill activeLog before closing the server so that ConnState
	// callbacks during shutdown do not block on the empty channel.
	defer func() {
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet fetches url and reads the whole body. headers holds
	// alternating header name/value pairs to add to the request.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection that never sends a request.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request line.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One valid request on a raw connection, then the client closes it.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4468
4469 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4470 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4471 }
4472 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4473 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4474 }), func(ts *httptest.Server) {
4475 ts.Config.SetKeepAlivesEnabled(false)
4476 }).ts
4477 res, err := ts.Client().Get(ts.URL)
4478 if err != nil {
4479 t.Fatal(err)
4480 }
4481 defer res.Body.Close()
4482 if !res.Close {
4483 t.Errorf("Body.Close == false; want true")
4484 }
4485 }
4486
4487
4488 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4489 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4490 var n int32
4491 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4492 atomic.AddInt32(&n, 1)
4493 }), optQuietLog)
4494 var wg sync.WaitGroup
4495 const reqs = 20
4496 for i := 0; i < reqs; i++ {
4497 wg.Add(1)
4498 go func() {
4499 defer wg.Done()
4500 res, err := cst.c.Get(cst.ts.URL)
4501 if err != nil {
4502
4503
4504 time.Sleep(10 * time.Millisecond)
4505 res, err = cst.c.Get(cst.ts.URL)
4506 if err != nil {
4507 t.Error(err)
4508 return
4509 }
4510 }
4511 defer res.Body.Close()
4512 _, err = io.Copy(io.Discard, res.Body)
4513 if err != nil {
4514 t.Error(err)
4515 return
4516 }
4517 }()
4518 }
4519 wg.Wait()
4520 if got := atomic.LoadInt32(&n); got != reqs {
4521 t.Errorf("handler ran %d times; want %d", got, reqs)
4522 }
4523 }
4524
4525 func TestServerConnStateNew(t *testing.T) {
4526 sawNew := false
4527 srv := &Server{
4528 ConnState: func(c net.Conn, state ConnState) {
4529 if state == StateNew {
4530 sawNew = true
4531 }
4532 },
4533 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4534 }
4535 srv.Serve(&oneConnListener{
4536 conn: &rwTestConn{
4537 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4538 Writer: io.Discard,
4539 },
4540 })
4541 if !sawNew {
4542 t.Error("StateNew not seen")
4543 }
4544 }
4545
4546 type closeWriteTestConn struct {
4547 rwTestConn
4548 didCloseWrite bool
4549 }
4550
4551 func (c *closeWriteTestConn) CloseWrite() error {
4552 c.didCloseWrite = true
4553 return nil
4554 }
4555
4556 func TestCloseWrite(t *testing.T) {
4557 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4558
4559 var srv Server
4560 var testConn closeWriteTestConn
4561 c := ExportServerNewConn(&srv, &testConn)
4562 ExportCloseWriteAndWait(c)
4563 if !testConn.didCloseWrite {
4564 t.Error("didn't see CloseWrite call")
4565 }
4566 }
4567
4568
4569
4570
4571
4572
4573
4574
4575 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4576 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4577 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4578 io.WriteString(w, "Hello, ")
4579 w.(Flusher).Flush()
4580 conn, buf, _ := w.(Hijacker).Hijack()
4581 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4582 if err := buf.Flush(); err != nil {
4583 t.Error(err)
4584 }
4585 if err := conn.Close(); err != nil {
4586 t.Error(err)
4587 }
4588 })).ts
4589 res, err := Get(ts.URL)
4590 if err != nil {
4591 t.Fatal(err)
4592 }
4593 defer res.Body.Close()
4594 all, err := io.ReadAll(res.Body)
4595 if err != nil {
4596 t.Fatal(err)
4597 }
4598 if want := "Hello, world!"; string(all) != want {
4599 t.Errorf("Got %q; want %q", all, want)
4600 }
4601 }
4602
4603
4604
4605
4606
4607
4608
4609 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4610 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4611 }
4612 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4613 if testing.Short() {
4614 t.Skip("skipping in -short mode")
4615 }
4616 const numReq = 3
4617 addrc := make(chan string, numReq)
4618 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4619 addrc <- r.RemoteAddr
4620 time.Sleep(500 * time.Millisecond)
4621 w.(Flusher).Flush()
4622 }), func(ts *httptest.Server) {
4623 ts.Config.WriteTimeout = 250 * time.Millisecond
4624 }).ts
4625
4626 errc := make(chan error, numReq)
4627 go func() {
4628 defer close(errc)
4629 for i := 0; i < numReq; i++ {
4630 res, err := Get(ts.URL)
4631 if res != nil {
4632 res.Body.Close()
4633 }
4634 errc <- err
4635 }
4636 }()
4637
4638 addrSeen := map[string]bool{}
4639 numOkay := 0
4640 for {
4641 select {
4642 case v := <-addrc:
4643 addrSeen[v] = true
4644 case err, ok := <-errc:
4645 if !ok {
4646 if len(addrSeen) != numReq {
4647 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4648 }
4649 if numOkay != 0 {
4650 t.Errorf("got %d successful client requests; want 0", numOkay)
4651 }
4652 return
4653 }
4654 if err == nil {
4655 numOkay++
4656 }
4657 }
4658 }
4659 }
4660
4661
4662
4663 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4664 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4665 }
4666 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4667 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4668 w.Header().Set("Transfer-Encoding", "foo")
4669 io.WriteString(w, "<html>")
4670 })).ts
4671 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4672 if err != nil {
4673 t.Fatalf("Dial: %v", err)
4674 }
4675 defer c.Close()
4676 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4677 t.Fatal(err)
4678 }
4679 bs := bufio.NewScanner(c)
4680 var got strings.Builder
4681 for bs.Scan() {
4682 if strings.TrimSpace(bs.Text()) == "" {
4683 break
4684 }
4685 got.WriteString(bs.Text())
4686 got.WriteByte('\n')
4687 }
4688 if err := bs.Err(); err != nil {
4689 t.Fatal(err)
4690 }
4691 if strings.Contains(got.String(), "Content-Length") {
4692 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4693 }
4694 if strings.Contains(got.String(), "Content-Type") {
4695 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4696 }
4697 }
4698
4699
4700
4701 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4702 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4703 "\r\n\r\n" +
4704 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4705 var buf bytes.Buffer
4706 conn := &rwTestConn{
4707 Reader: bytes.NewReader(req),
4708 Writer: &buf,
4709 closec: make(chan bool, 1),
4710 }
4711 ln := &oneConnListener{conn: conn}
4712 numReq := 0
4713 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4714 numReq++
4715 }))
4716 <-conn.closec
4717 if numReq != 2 {
4718 t.Errorf("num requests = %d; want 2", numReq)
4719 t.Logf("Res: %s", buf.Bytes())
4720 }
4721 }
4722
4723 func TestIssue13893_Expect100(t *testing.T) {
4724
4725 req := reqBytes(`PUT /readbody HTTP/1.1
4726 User-Agent: PycURL/7.22.0
4727 Host: 127.0.0.1:9000
4728 Accept: */*
4729 Expect: 100-continue
4730 Content-Length: 10
4731
4732 HelloWorld
4733
4734 `)
4735 var buf bytes.Buffer
4736 conn := &rwTestConn{
4737 Reader: bytes.NewReader(req),
4738 Writer: &buf,
4739 closec: make(chan bool, 1),
4740 }
4741 ln := &oneConnListener{conn: conn}
4742 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4743 if _, ok := r.Header["Expect"]; !ok {
4744 t.Error("Expect header should not be filtered out")
4745 }
4746 }))
4747 <-conn.closec
4748 }
4749
4750 func TestIssue11549_Expect100(t *testing.T) {
4751 req := reqBytes(`PUT /readbody HTTP/1.1
4752 User-Agent: PycURL/7.22.0
4753 Host: 127.0.0.1:9000
4754 Accept: */*
4755 Expect: 100-continue
4756 Content-Length: 10
4757
4758 HelloWorldPUT /noreadbody HTTP/1.1
4759 User-Agent: PycURL/7.22.0
4760 Host: 127.0.0.1:9000
4761 Accept: */*
4762 Expect: 100-continue
4763 Content-Length: 10
4764
4765 GET /should-be-ignored HTTP/1.1
4766 Host: foo
4767
4768 `)
4769 var buf strings.Builder
4770 conn := &rwTestConn{
4771 Reader: bytes.NewReader(req),
4772 Writer: &buf,
4773 closec: make(chan bool, 1),
4774 }
4775 ln := &oneConnListener{conn: conn}
4776 numReq := 0
4777 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4778 numReq++
4779 if r.URL.Path == "/readbody" {
4780 io.ReadAll(r.Body)
4781 }
4782 io.WriteString(w, "Hello world!")
4783 }))
4784 <-conn.closec
4785 if numReq != 2 {
4786 t.Errorf("num requests = %d; want 2", numReq)
4787 }
4788 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4789 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4790 }
4791 }
4792
4793
4794
4795 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4796 setParallel(t)
4797 conn := newTestConn()
4798 conn.readBuf.WriteString(
4799 "POST / HTTP/1.1\r\n" +
4800 "Host: test\r\n" +
4801 "Content-Length: 9999999999\r\n" +
4802 "\r\n" + strings.Repeat("a", 1<<20))
4803
4804 ls := &oneConnListener{conn}
4805 var inHandlerLen int
4806 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4807 inHandlerLen = conn.readBuf.Len()
4808 rw.WriteHeader(404)
4809 }))
4810 <-conn.closec
4811 afterHandlerLen := conn.readBuf.Len()
4812
4813 if afterHandlerLen != inHandlerLen {
4814 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4815 }
4816 }
4817
4818 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4819 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4820 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4821 r.Body = nil
4822 fmt.Fprintf(w, "%v", r.RemoteAddr)
4823 }))
4824 get := func() string {
4825 res, err := cst.c.Get(cst.ts.URL)
4826 if err != nil {
4827 t.Fatal(err)
4828 }
4829 defer res.Body.Close()
4830 slurp, err := io.ReadAll(res.Body)
4831 if err != nil {
4832 t.Fatal(err)
4833 }
4834 return string(slurp)
4835 }
4836 a, b := get(), get()
4837 if a != b {
4838 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4839 }
4840 }
4841
4842
4843
4844 func TestServerValidatesHostHeader(t *testing.T) {
4845 tests := []struct {
4846 proto string
4847 host string
4848 want int
4849 }{
4850 {"HTTP/0.9", "", 505},
4851
4852 {"HTTP/1.1", "", 400},
4853 {"HTTP/1.1", "Host: \r\n", 200},
4854 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4855 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4856 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4857 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4858 {"HTTP/1.1", "Host: ::1\r\n", 200},
4859 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4860 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4861 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4862 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4863 {"HTTP/1.1", "Host: \x06\r\n", 400},
4864 {"HTTP/1.1", "Host: \xff\r\n", 400},
4865 {"HTTP/1.1", "Host: {\r\n", 400},
4866 {"HTTP/1.1", "Host: }\r\n", 400},
4867 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4868
4869
4870
4871 {"HTTP/1.0", "", 200},
4872 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4873 {"HTTP/1.0", "Host: \xff\r\n", 400},
4874
4875
4876 {"PRI * HTTP/2.0", "", 200},
4877
4878
4879 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4880
4881
4882 {"PRI / HTTP/2.0", "", 505},
4883 {"GET / HTTP/2.0", "", 505},
4884 {"GET / HTTP/3.0", "", 505},
4885 }
4886 for _, tt := range tests {
4887 conn := newTestConn()
4888 methodTarget := "GET / "
4889 if !strings.HasPrefix(tt.proto, "HTTP/") {
4890 methodTarget = ""
4891 }
4892 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4893
4894 ln := &oneConnListener{conn}
4895 srv := Server{
4896 ErrorLog: quietLog,
4897 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4898 }
4899 go srv.Serve(ln)
4900 <-conn.closec
4901 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4902 if err != nil {
4903 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4904 continue
4905 }
4906 if res.StatusCode != tt.want {
4907 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4908 }
4909 }
4910 }
4911
4912 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4913 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4914 }
4915 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
4916 const upgradeResponse = "upgrade here"
4917 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4918 conn, br, err := w.(Hijacker).Hijack()
4919 if err != nil {
4920 t.Error(err)
4921 return
4922 }
4923 defer conn.Close()
4924 if r.Method != "PRI" || r.RequestURI != "*" {
4925 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4926 return
4927 }
4928 if !r.Close {
4929 t.Errorf("Request.Close = true; want false")
4930 }
4931 const want = "SM\r\n\r\n"
4932 buf := make([]byte, len(want))
4933 n, err := io.ReadFull(br, buf)
4934 if err != nil || string(buf[:n]) != want {
4935 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4936 return
4937 }
4938 io.WriteString(conn, upgradeResponse)
4939 })).ts
4940
4941 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4942 if err != nil {
4943 t.Fatalf("Dial: %v", err)
4944 }
4945 defer c.Close()
4946 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4947 slurp, err := io.ReadAll(c)
4948 if err != nil {
4949 t.Fatal(err)
4950 }
4951 if string(slurp) != upgradeResponse {
4952 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4953 }
4954 }
4955
4956
4957
4958 func TestServerValidatesHeaders(t *testing.T) {
4959 setParallel(t)
4960 tests := []struct {
4961 header string
4962 want int
4963 }{
4964 {"", 200},
4965 {"Foo: bar\r\n", 200},
4966 {"X-Foo: bar\r\n", 200},
4967 {"Foo: a space\r\n", 200},
4968
4969 {"A space: foo\r\n", 400},
4970 {"foo\xffbar: foo\r\n", 400},
4971 {"foo\x00bar: foo\r\n", 400},
4972 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4973
4974
4975 {"Foo : bar\r\n", 400},
4976 {"Foo\t: bar\r\n", 400},
4977
4978
4979
4980 {": empty key\r\n", 400},
4981
4982
4983
4984
4985 {"Content-Length: notdigits\r\n", 400},
4986 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
4987
4988 {"foo: foo foo\r\n", 200},
4989 {"foo: foo\tfoo\r\n", 200},
4990 {"foo: foo\x00foo\r\n", 400},
4991 {"foo: foo\x7ffoo\r\n", 400},
4992 {"foo: foo\xfffoo\r\n", 200},
4993 }
4994 for _, tt := range tests {
4995 conn := newTestConn()
4996 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
4997
4998 ln := &oneConnListener{conn}
4999 srv := Server{
5000 ErrorLog: quietLog,
5001 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5002 }
5003 go srv.Serve(ln)
5004 <-conn.closec
5005 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5006 if err != nil {
5007 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
5008 continue
5009 }
5010 if res.StatusCode != tt.want {
5011 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
5012 }
5013 }
5014 }
5015
5016 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
5017 run(t, testServerRequestContextCancel_ServeHTTPDone)
5018 }
5019 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5020 ctxc := make(chan context.Context, 1)
5021 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5022 ctx := r.Context()
5023 select {
5024 case <-ctx.Done():
5025 t.Error("should not be Done in ServeHTTP")
5026 default:
5027 }
5028 ctxc <- ctx
5029 }))
5030 res, err := cst.c.Get(cst.ts.URL)
5031 if err != nil {
5032 t.Fatal(err)
5033 }
5034 res.Body.Close()
5035 ctx := <-ctxc
5036 select {
5037 case <-ctx.Done():
5038 default:
5039 t.Error("context should be done after ServeHTTP completes")
5040 }
5041 }
5042
5043
5044
5045
5046
5047 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5048 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5049 }
5050 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5051 inHandler := make(chan struct{})
5052 handlerDone := make(chan struct{})
5053 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5054 close(inHandler)
5055 <-r.Context().Done()
5056 close(handlerDone)
5057 })).ts
5058 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5059 if err != nil {
5060 t.Fatal(err)
5061 }
5062 defer c.Close()
5063 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5064 <-inHandler
5065 c.Close()
5066 <-handlerDone
5067 }
5068
5069 func TestServerContext_ServerContextKey(t *testing.T) {
5070 run(t, testServerContext_ServerContextKey)
5071 }
5072 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5073 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5074 ctx := r.Context()
5075 got := ctx.Value(ServerContextKey)
5076 if _, ok := got.(*Server); !ok {
5077 t.Errorf("context value = %T; want *http.Server", got)
5078 }
5079 }))
5080 res, err := cst.c.Get(cst.ts.URL)
5081 if err != nil {
5082 t.Fatal(err)
5083 }
5084 res.Body.Close()
5085 }
5086
5087 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5088 run(t, testServerContext_LocalAddrContextKey)
5089 }
5090 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5091 ch := make(chan any, 1)
5092 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5093 ch <- r.Context().Value(LocalAddrContextKey)
5094 }))
5095 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5096 t.Fatal(err)
5097 }
5098
5099 host := cst.ts.Listener.Addr().String()
5100 got := <-ch
5101 if addr, ok := got.(net.Addr); !ok {
5102 t.Errorf("local addr value = %T; want net.Addr", got)
5103 } else if fmt.Sprint(addr) != host {
5104 t.Errorf("local addr = %v; want %v", addr, host)
5105 }
5106 }
5107
5108
5109 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5110 setParallel(t)
5111 defer afterTest(t)
5112 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5113 w.Header().Set("Transfer-Encoding", "chunked")
5114 w.Write([]byte("hello"))
5115 }))
5116 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5117 const hdr = "Transfer-Encoding: chunked"
5118 if n := strings.Count(resp, hdr); n != 1 {
5119 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5120 }
5121 }
5122
5123
5124 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5125 setParallel(t)
5126 defer afterTest(t)
5127 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5128 w.Header().Set("Transfer-Encoding", "gzip")
5129 gz := gzip.NewWriter(w)
5130 gz.Write([]byte("hello"))
5131 gz.Close()
5132 }))
5133 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5134 for _, v := range []string{"gzip", "chunked"} {
5135 hdr := "Transfer-Encoding: " + v
5136 if n := strings.Count(resp, hdr); n != 1 {
5137 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5138 }
5139 }
5140 }
5141
5142 func BenchmarkClientServer(b *testing.B) {
5143 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5144 }
5145 func benchmarkClientServer(b *testing.B, mode testMode) {
5146 b.ReportAllocs()
5147 b.StopTimer()
5148 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5149 fmt.Fprintf(rw, "Hello world.\n")
5150 })).ts
5151 b.StartTimer()
5152
5153 c := ts.Client()
5154 for i := 0; i < b.N; i++ {
5155 res, err := c.Get(ts.URL)
5156 if err != nil {
5157 b.Fatal("Get:", err)
5158 }
5159 all, err := io.ReadAll(res.Body)
5160 res.Body.Close()
5161 if err != nil {
5162 b.Fatal("ReadAll:", err)
5163 }
5164 body := string(all)
5165 if body != "Hello world.\n" {
5166 b.Fatal("Got body:", body)
5167 }
5168 }
5169
5170 b.StopTimer()
5171 }
5172
5173 func BenchmarkClientServerParallel(b *testing.B) {
5174 for _, parallelism := range []int{4, 64} {
5175 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5176 run(b, func(b *testing.B, mode testMode) {
5177 benchmarkClientServerParallel(b, parallelism, mode)
5178 }, []testMode{http1Mode, https1Mode, http2Mode})
5179 })
5180 }
5181 }
5182
5183 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5184 b.ReportAllocs()
5185 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5186 fmt.Fprintf(rw, "Hello world.\n")
5187 })).ts
5188 b.ResetTimer()
5189 b.SetParallelism(parallelism)
5190 b.RunParallel(func(pb *testing.PB) {
5191 c := ts.Client()
5192 for pb.Next() {
5193 res, err := c.Get(ts.URL)
5194 if err != nil {
5195 b.Logf("Get: %v", err)
5196 continue
5197 }
5198 all, err := io.ReadAll(res.Body)
5199 res.Body.Close()
5200 if err != nil {
5201 b.Logf("ReadAll: %v", err)
5202 continue
5203 }
5204 body := string(all)
5205 if body != "Hello world.\n" {
5206 panic("Got body: " + body)
5207 }
5208 }
5209 })
5210 }
5211
5212
5213
5214
5215
5216
5217
5218
5219
5220
5221 func BenchmarkServer(b *testing.B) {
5222 b.ReportAllocs()
5223
5224 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5225 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5226 if err != nil {
5227 panic(err)
5228 }
5229 for i := 0; i < n; i++ {
5230 res, err := Get(url)
5231 if err != nil {
5232 log.Panicf("Get: %v", err)
5233 }
5234 all, err := io.ReadAll(res.Body)
5235 res.Body.Close()
5236 if err != nil {
5237 log.Panicf("ReadAll: %v", err)
5238 }
5239 body := string(all)
5240 if body != "Hello world.\n" {
5241 log.Panicf("Got body: %q", body)
5242 }
5243 }
5244 os.Exit(0)
5245 return
5246 }
5247
5248 var res = []byte("Hello world.\n")
5249 b.StopTimer()
5250 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5251 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5252 rw.Write(res)
5253 }))
5254 defer ts.Close()
5255 b.StartTimer()
5256
5257 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5258 cmd.Env = append([]string{
5259 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5260 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5261 }, os.Environ()...)
5262 out, err := cmd.CombinedOutput()
5263 if err != nil {
5264 b.Errorf("Test failure: %v, with output: %s", err, out)
5265 }
5266 }
5267
5268
// getNoBody issues a GET request for urlStr with the default client and
// closes the response body without reading it. On success it returns the
// response (status and headers remain usable); on failure it returns a nil
// response and the error.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5277
5278
5279
// BenchmarkClient measures the client side of HTTP round trips against a server
// running in a separate process.
//
// When TEST_BENCH_SERVER is set, this function runs as the child process. It
// listens on localhost (port TEST_BENCH_SERVER_PORT, or an ephemeral port if
// unset), prints the listening address on stdout, and serves data until a
// request with a "stop" form value makes it exit.
//
// Otherwise it is the parent. It starts the child, reads the address, issues
// one untimed probe request, times b.N GET requests, and then asks the child
// to stop.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Child process: run the server.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// The parent reads this first stdout line to learn where to connect.
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			// A "stop" form value is the parent's signal to exit.
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent process: start the child server. cancel is called on every exit
	// path (see the deferred func below). Presumably, via
	// testenv.CommandContext, that terminates a child that is still running.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// done receives cmd.Wait's result once and is then closed. Both the final
	// check below and the deferred cleanup receive from it; whichever receives
	// second gets the zero value from the closed channel instead of blocking.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()
	defer func() {
		cancel()
		<-done
	}()

	// Learn the child's address from its first line of output, then make one
	// untimed request to confirm it is serving before timing starts.
	// NOTE: the local variable url shadows the net/url package in this function.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Timed section: b.N sequential GETs, each body checked against data.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit. The result is deliberately ignored: the child's
	// handler calls os.Exit, so the request may fail instead of getting a
	// response. A non-nil exit status from the child is still reported.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5368
5369 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5370 b.ReportAllocs()
5371 req := reqBytes(`GET / HTTP/1.0
5372 Host: golang.org
5373 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5374 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5375 Accept-Encoding: gzip,deflate,sdch
5376 Accept-Language: en-US,en;q=0.8
5377 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5378 `)
5379 res := []byte("Hello world!\n")
5380
5381 conn := newTestConn()
5382 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5383 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5384 rw.Write(res)
5385 })
5386 ln := new(oneConnListener)
5387 for i := 0; i < b.N; i++ {
5388 conn.readBuf.Reset()
5389 conn.writeBuf.Reset()
5390 conn.readBuf.Write(req)
5391 ln.conn = conn
5392 Serve(ln, handler)
5393 <-conn.closec
5394 }
5395 }
5396
5397
// repeatReader is an io.Reader that yields content count times back to back
// and then reports io.EOF. off records how much of the current repetition has
// already been returned.
type repeatReader struct {
	content []byte
	count   int
	off     int
}

// Read copies as much of the current repetition as fits into p. When the
// current repetition is exhausted, it decrements count and rewinds off. Once
// count is zero or negative, it returns io.EOF.
func (r *repeatReader) Read(p []byte) (n int, err error) {
	if r.count <= 0 {
		return 0, io.EOF
	}
	rest := r.content[r.off:]
	n = copy(p, rest)
	if n == len(rest) {
		// Finished this repetition; start the next one from the beginning.
		r.count--
		r.off = 0
	} else {
		r.off += n
	}
	return n, nil
}
5416
5417 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5418 b.ReportAllocs()
5419
5420 req := reqBytes(`GET / HTTP/1.1
5421 Host: golang.org
5422 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5423 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5424 Accept-Encoding: gzip,deflate,sdch
5425 Accept-Language: en-US,en;q=0.8
5426 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5427 `)
5428 res := []byte("Hello world!\n")
5429
5430 conn := &rwTestConn{
5431 Reader: &repeatReader{content: req, count: b.N},
5432 Writer: io.Discard,
5433 closec: make(chan bool, 1),
5434 }
5435 handled := 0
5436 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5437 handled++
5438 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5439 rw.Write(res)
5440 })
5441 ln := &oneConnListener{conn: conn}
5442 go Serve(ln, handler)
5443 <-conn.closec
5444 if b.N != handled {
5445 b.Errorf("b.N=%d but handled %d", b.N, handled)
5446 }
5447 }
5448
5449
5450
5451 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5452 b.ReportAllocs()
5453
5454 req := reqBytes(`GET / HTTP/1.1
5455 Host: golang.org
5456 `)
5457 res := []byte("Hello world!\n")
5458
5459 conn := &rwTestConn{
5460 Reader: &repeatReader{content: req, count: b.N},
5461 Writer: io.Discard,
5462 closec: make(chan bool, 1),
5463 }
5464 handled := 0
5465 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5466 handled++
5467 rw.Write(res)
5468 })
5469 ln := &oneConnListener{conn: conn}
5470 go Serve(ln, handler)
5471 <-conn.closec
5472 if b.N != handled {
5473 b.Errorf("b.N=%d but handled %d", b.N, handled)
5474 }
5475 }
5476
// someResponse is the fragment repeated to build the handler benchmarks' body.
const someResponse = "<html>some response</html>"

// response is someResponse repeated as many whole times as fit in 2048 bytes.
// It is the body written by the BenchmarkServerHandler* benchmarks.
var response = bytes.Repeat([]byte(someResponse), 2048/len(someResponse))
5481
5482
5483 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5484 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5485 w.Header().Set("Content-Type", "text/html")
5486 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5487 w.Write(response)
5488 }))
5489 }
5490
5491
5492 func BenchmarkServerHandlerNoLen(b *testing.B) {
5493 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5494 w.Header().Set("Content-Type", "text/html")
5495 w.Write(response)
5496 }))
5497 }
5498
5499
5500 func BenchmarkServerHandlerNoType(b *testing.B) {
5501 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5502 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5503 w.Write(response)
5504 }))
5505 }
5506
5507
5508 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5509 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5510 w.Write(response)
5511 }))
5512 }
5513
5514 func benchmarkHandler(b *testing.B, h Handler) {
5515 b.ReportAllocs()
5516 req := reqBytes(`GET / HTTP/1.1
5517 Host: golang.org
5518 `)
5519 conn := &rwTestConn{
5520 Reader: &repeatReader{content: req, count: b.N},
5521 Writer: io.Discard,
5522 closec: make(chan bool, 1),
5523 }
5524 handled := 0
5525 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5526 handled++
5527 h.ServeHTTP(rw, r)
5528 })
5529 ln := &oneConnListener{conn: conn}
5530 go Serve(ln, handler)
5531 <-conn.closec
5532 if b.N != handled {
5533 b.Errorf("b.N=%d but handled %d", b.N, handled)
5534 }
5535 }
5536
5537 func BenchmarkServerHijack(b *testing.B) {
5538 b.ReportAllocs()
5539 req := reqBytes(`GET / HTTP/1.1
5540 Host: golang.org
5541 `)
5542 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5543 conn, _, err := w.(Hijacker).Hijack()
5544 if err != nil {
5545 panic(err)
5546 }
5547 conn.Close()
5548 })
5549 conn := &rwTestConn{
5550 Writer: io.Discard,
5551 closec: make(chan bool, 1),
5552 }
5553 ln := &oneConnListener{conn: conn}
5554 for i := 0; i < b.N; i++ {
5555 conn.Reader = bytes.NewReader(req)
5556 ln.conn = conn
5557 Serve(ln, h)
5558 <-conn.closec
5559 }
5560 }
5561
5562 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5563 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5564 b.ReportAllocs()
5565 b.StopTimer()
5566 sawClose := make(chan bool)
5567 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5568 <-rw.(CloseNotifier).CloseNotify()
5569 sawClose <- true
5570 })).ts
5571 b.StartTimer()
5572 for i := 0; i < b.N; i++ {
5573 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5574 if err != nil {
5575 b.Fatalf("error dialing: %v", err)
5576 }
5577 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5578 if err != nil {
5579 b.Fatal(err)
5580 }
5581 conn.Close()
5582 <-sawClose
5583 }
5584 b.StopTimer()
5585 }
5586
5587
5588 func TestConcurrentServerServe(t *testing.T) {
5589 setParallel(t)
5590 for i := 0; i < 100; i++ {
5591 ln1 := &oneConnListener{conn: nil}
5592 ln2 := &oneConnListener{conn: nil}
5593 srv := Server{}
5594 go func() { srv.Serve(ln1) }()
5595 go func() { srv.Serve(ln2) }()
5596 }
5597 }
5598
// TestServerIdleTimeout checks Server.IdleTimeout and Server.ReadHeaderTimeout
// on an HTTP/1 server:
//   - two back-to-back requests share one connection;
//   - a request made after the connection has been idle for longer than
//     IdleTimeout uses a new connection;
//   - a connection that sends an incomplete request header is closed after
//     ReadHeaderTimeout.
//
// The callback is handed each listed ReadHeaderTimeout in turn, and
// IdleTimeout is set to twice that value. Presumably runTimeSensitiveTest
// moves to the next, longer value when the callback returns an error (to
// absorb scheduling jitter); confirm against its definition.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler echoes the client address, so callers can tell whether
		// two requests used the same connection.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		// get returns the client address the server saw. Request errors are
		// returned to the caller; body read errors fail the test immediately.
		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// Response headers have already arrived, so the connection
				// is not idle and no header-read timeout applies. A failure
				// here is treated as a real bug rather than timing noise,
				// hence t.Fatal instead of returning the error.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		// Back-to-back requests should reuse the same connection.
		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Stay idle for 1.5x IdleTimeout; the next request is expected to
		// arrive on a different connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a request header without the terminating blank line and wait
		// for twice ReadHeaderTimeout. Reading from the connection is then
		// expected to fail because the server has given up on it.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5674
// get sends a GET request for url using c and returns the response body as a
// string. Any request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	var body []byte
	if body, err = io.ReadAll(res.Body); err != nil {
		t.Fatal(err)
	}
	return string(body)
}
5687
5688
5689
5690 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5691 run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
5692 }
5693 func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
5694 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5695 io.WriteString(w, r.RemoteAddr)
5696 })).ts
5697
5698 c := ts.Client()
5699 tr := c.Transport.(*Transport)
5700
5701 get := func() string { return get(t, c, ts.URL) }
5702
5703 a1, a2 := get(), get()
5704 if a1 == a2 {
5705 t.Logf("made two requests from a single conn %q (as expected)", a1)
5706 } else {
5707 t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
5708 }
5709
5710
5711
5712
5713
5714 if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
5715 t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
5716 }
5717
5718
5719 ts.Config.SetKeepAlivesEnabled(false)
5720
5721 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5722 if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
5723 if d > 0 {
5724 t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
5725 }
5726 return false
5727 }
5728 return true
5729 })
5730
5731
5732
5733
5734 }
5735
// TestServerShutdown checks graceful shutdown while a request is in flight.
// The first request's handler records connection states and starts
// Server.Shutdown in the background. It then waits for the RegisterOnShutdown
// hook and checks that new requests can no longer get through.
//
// After that, the test checks that:
//   - the in-flight request still completes;
//   - Shutdown returns nil;
//   - exactly one connection was active when the states were recorded.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
func testServerShutdown(t *testing.T, mode testMode) {
	// cst is assigned below, before any request is sent. The handler closure
	// refers to it so that it can reach the server and client.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// Only the first request snapshots connection states and starts
		// Shutdown. Shutdown runs in its own goroutine because it waits for
		// this very request to finish.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until Shutdown has invoked the RegisterOnShutdown hook.
			<-gotOnShutdown

			// By now the listener is expected to be closed, so new requests
			// should fail. Loop until one fails (or the test has already
			// failed). In HTTP/2 mode an unexpected success is only logged
			// and retried, per the linked issue.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// The in-flight request must still complete despite the shutdown.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// At snapshot time, the one request in progress should have been the
	// single active connection.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5796
// TestServerShutdownStateNew checks how Shutdown treats a connection that was
// accepted but never sent a request. The test expects:
//   - Shutdown does not return before 5 seconds have passed;
//   - the server has not closed the connection by then;
//   - shortly after that, Shutdown completes successfully and the server has
//     closed the connection.
//
// Time here is synctest's fake clock (via runSynctest and synctest.Wait), so
// the sleeps below are not expected to cost real wall time.
// NOTE(review): the testing.Short skip message still cites 5-6 real seconds.
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
func testServerShutdownStateNew(t testing.TB, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	// Swap the test server onto the fake listener so the test controls the
	// client side of the connection. Server errors are silenced.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Never called: the client connection below sends no request.
	}), func(ts *httptest.Server) {
		ts.Listener.Close()
		ts.Listener = listener

		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Open a connection but send nothing on it.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// expectTimeout is how long this test expects Shutdown to keep waiting on
	// a connection that has not yet sent a request.
	const expectTimeout = 5 * time.Second

	// One nanosecond before the deadline, Shutdown must still be pending and
	// the connection still open.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Well past the deadline, Shutdown must have completed and the server
	// must have closed the connection.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
5852
5853
// TestServerCloseDeadlock checks that calling Close twice on a Server that was
// never started returns instead of deadlocking.
func TestServerCloseDeadlock(t *testing.T) {
	srv := new(Server)
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
5859
5860
5861
5862 func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
5863 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
5864 if mode == http2Mode {
5865 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5866 defer restore()
5867 }
5868
5869 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
5870 defer cst.close()
5871 srv := cst.ts.Config
5872 srv.SetKeepAlivesEnabled(false)
5873 for try := 0; try < 2; try++ {
5874 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5875 if !srv.ExportAllConnsIdle() {
5876 if d > 0 {
5877 t.Logf("test server still has active conns after %v", d)
5878 }
5879 return false
5880 }
5881 return true
5882 })
5883 conns := 0
5884 var info httptrace.GotConnInfo
5885 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5886 GotConn: func(v httptrace.GotConnInfo) {
5887 conns++
5888 info = v
5889 },
5890 })
5891 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
5892 if err != nil {
5893 t.Fatal(err)
5894 }
5895 res, err := cst.c.Do(req)
5896 if err != nil {
5897 t.Fatal(err)
5898 }
5899 res.Body.Close()
5900 if conns != 1 {
5901 t.Fatalf("request %v: got %v conns, want 1", try, conns)
5902 }
5903 if info.Reused || info.WasIdle {
5904 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
5905 }
5906 }
5907 }
5908
5909
5910
5911
5912 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
5913 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
5914 runTimeSensitiveTest(t, []time.Duration{
5915 10 * time.Millisecond,
5916 50 * time.Millisecond,
5917 250 * time.Millisecond,
5918 time.Second,
5919 2 * time.Second,
5920 }, func(t *testing.T, timeout time.Duration) error {
5921 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5922 select {
5923 case <-time.After(2 * timeout):
5924 fmt.Fprint(w, "ok")
5925 case <-r.Context().Done():
5926 fmt.Fprint(w, r.Context().Err())
5927 }
5928 }), func(ts *httptest.Server) {
5929 ts.Config.ReadTimeout = timeout
5930 t.Logf("Server.Config.ReadTimeout = %v", timeout)
5931 })
5932 defer cst.close()
5933 ts := cst.ts
5934
5935 var retries atomic.Int32
5936 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
5937 if retries.Add(1) != 1 {
5938 return nil, errors.New("too many retries")
5939 }
5940 return nil, nil
5941 }
5942
5943 c := ts.Client()
5944
5945 res, err := c.Get(ts.URL)
5946 if err != nil {
5947 return fmt.Errorf("Get: %v", err)
5948 }
5949 slurp, err := io.ReadAll(res.Body)
5950 res.Body.Close()
5951 if err != nil {
5952 return fmt.Errorf("Body ReadAll: %v", err)
5953 }
5954 if string(slurp) != "ok" {
5955 return fmt.Errorf("got: %q, want ok", slurp)
5956 }
5957 return nil
5958 })
5959 }
5960
5961
5962
5963
5964 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
5965 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
5966 }
5967 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
5968 runTimeSensitiveTest(t, []time.Duration{
5969 10 * time.Millisecond,
5970 50 * time.Millisecond,
5971 250 * time.Millisecond,
5972 time.Second,
5973 2 * time.Second,
5974 }, func(t *testing.T, timeout time.Duration) error {
5975 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
5976 ts.Config.ReadHeaderTimeout = timeout
5977 ts.Config.IdleTimeout = 0
5978 })
5979 defer cst.close()
5980 ts := cst.ts
5981
5982
5983
5984 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5985 if err != nil {
5986 t.Fatalf("dial failed: %v", err)
5987 }
5988 br := bufio.NewReader(conn)
5989 defer conn.Close()
5990
5991 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
5992 return fmt.Errorf("writing first request failed: %v", err)
5993 }
5994
5995 if _, err := ReadResponse(br, nil); err != nil {
5996 return fmt.Errorf("first response (before timeout) failed: %v", err)
5997 }
5998
5999
6000
6001 time.Sleep(timeout * 3 / 2)
6002
6003 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6004 return fmt.Errorf("writing second request failed: %v", err)
6005 }
6006
6007 if _, err := ReadResponse(br, nil); err != nil {
6008 return fmt.Errorf("second response (after timeout) failed: %v", err)
6009 }
6010
6011 return nil
6012 })
6013 }
6014
6015
6016
// runTimeSensitiveTest calls test once per entry of durations, in order, and
// stops at the first call that returns nil.
//
// A non-nil error causes a retry with the next duration. The test is stopped
// with Fatalf instead when the error came from the last duration, or when
// test has already marked t as failed (such a failure is not a timing flake).
//
// An empty durations slice is a caller bug. It is reported as a failure
// rather than letting the test pass without running anything.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	t.Helper()
	if len(durations) == 0 {
		t.Fatal("runTimeSensitiveTest: no durations to try")
	}
	for i, d := range durations {
		err := test(t, d)
		if err == nil {
			return
		}
		if i == len(durations)-1 || t.Failed() {
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6029
6030
6031
6032 func TestServerDuplicateBackgroundRead(t *testing.T) {
6033 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6034 }
6035 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6036 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6037 testenv.SkipFlaky(t, 24826)
6038 }
6039
6040 goroutines := 5
6041 requests := 2000
6042 if testing.Short() {
6043 goroutines = 3
6044 requests = 100
6045 }
6046
6047 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6048
6049 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6050
6051 var wg sync.WaitGroup
6052 for i := 0; i < goroutines; i++ {
6053 wg.Add(1)
6054 go func() {
6055 defer wg.Done()
6056 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6057 if err != nil {
6058 t.Error(err)
6059 return
6060 }
6061 defer cn.Close()
6062
6063 wg.Add(1)
6064 go func() {
6065 defer wg.Done()
6066 io.Copy(io.Discard, cn)
6067 }()
6068
6069 for j := 0; j < requests; j++ {
6070 if t.Failed() {
6071 return
6072 }
6073 _, err := cn.Write(reqBytes)
6074 if err != nil {
6075 t.Error(err)
6076 return
6077 }
6078 }
6079 }()
6080 }
6081 wg.Wait()
6082 }
6083
6084
6085
6086
6087
6088
6089 func TestServerHijackGetsBackgroundByte(t *testing.T) {
6090 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
6091 }
6092 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
6093 if runtime.GOOS == "plan9" {
6094 t.Skip("skipping test; see https://golang.org/issue/18657")
6095 }
6096 done := make(chan struct{})
6097 inHandler := make(chan bool, 1)
6098 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6099 defer close(done)
6100
6101
6102 inHandler <- true
6103
6104 conn, buf, err := w.(Hijacker).Hijack()
6105 if err != nil {
6106 t.Error(err)
6107 return
6108 }
6109 defer conn.Close()
6110
6111 peek, err := buf.Reader.Peek(3)
6112 if string(peek) != "foo" || err != nil {
6113 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
6114 }
6115
6116 select {
6117 case <-r.Context().Done():
6118 t.Error("context unexpectedly canceled")
6119 default:
6120 }
6121 })).ts
6122
6123 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6124 if err != nil {
6125 t.Fatal(err)
6126 }
6127 defer cn.Close()
6128 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6129 t.Fatal(err)
6130 }
6131 <-inHandler
6132 if _, err := cn.Write([]byte("foo")); err != nil {
6133 t.Fatal(err)
6134 }
6135
6136 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6137 t.Fatal(err)
6138 }
6139 <-done
6140 }
6141
6142
6143
6144
6145 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6146 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6147 }
6148 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6149 if runtime.GOOS == "plan9" {
6150 t.Skip("skipping test; see https://golang.org/issue/18657")
6151 }
6152 done := make(chan struct{})
6153 const size = 8 << 10
6154 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6155 defer close(done)
6156
6157 conn, buf, err := w.(Hijacker).Hijack()
6158 if err != nil {
6159 t.Error(err)
6160 return
6161 }
6162 defer conn.Close()
6163 slurp, err := io.ReadAll(buf.Reader)
6164 if err != nil {
6165 t.Errorf("Copy: %v", err)
6166 }
6167 allX := true
6168 for _, v := range slurp {
6169 if v != 'x' {
6170 allX = false
6171 }
6172 }
6173 if len(slurp) != size {
6174 t.Errorf("read %d; want %d", len(slurp), size)
6175 } else if !allX {
6176 t.Errorf("read %q; want %d 'x'", slurp, size)
6177 }
6178 })).ts
6179
6180 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6181 if err != nil {
6182 t.Fatal(err)
6183 }
6184 defer cn.Close()
6185 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6186 strings.Repeat("x", size)); err != nil {
6187 t.Fatal(err)
6188 }
6189 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6190 t.Fatal(err)
6191 }
6192
6193 <-done
6194 }
6195
6196
6197 func TestServerValidatesMethod(t *testing.T) {
6198 tests := []struct {
6199 method string
6200 want int
6201 }{
6202 {"GET", 200},
6203 {"GE(T", 400},
6204 }
6205 for _, tt := range tests {
6206 conn := newTestConn()
6207 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6208
6209 ln := &oneConnListener{conn}
6210 go Serve(ln, serve(200))
6211 <-conn.closec
6212 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6213 if err != nil {
6214 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6215 continue
6216 }
6217 if res.StatusCode != tt.want {
6218 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6219 }
6220 }
6221 }
6222
6223
// eofListenerNotComparable is a net.Listener whose underlying type is a slice,
// so values of it cannot be compared with ==. Accept always fails with io.EOF,
// Addr is nil, and Close is a no-op.
type eofListenerNotComparable []int

// Accept never yields a connection.
func (l eofListenerNotComparable) Accept() (net.Conn, error) {
	var none net.Conn
	return none, io.EOF
}

// Addr has no address to report.
func (l eofListenerNotComparable) Addr() net.Addr {
	var none net.Addr
	return none
}

// Close has nothing to release.
func (l eofListenerNotComparable) Close() error {
	return nil
}
6229
6230
6231 func TestServerListenNotComparableListener(t *testing.T) {
6232 var s Server
6233 s.Serve(make(eofListenerNotComparable, 1))
6234 }
6235
6236
// countCloseListener wraps a net.Listener and counts calls to Close. Only the
// first call is forwarded to the wrapped listener, and only when there is
// one. Every later call just returns nil.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; accessed atomically
}

func (p *countCloseListener) Close() error {
	first := atomic.AddInt32(&p.closes, 1) == 1
	if !first || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6249
6250
6251 func TestServerCloseListenerOnce(t *testing.T) {
6252 setParallel(t)
6253 defer afterTest(t)
6254
6255 ln := newLocalListener(t)
6256 defer ln.Close()
6257
6258 cl := &countCloseListener{Listener: ln}
6259 server := &Server{}
6260 sdone := make(chan bool, 1)
6261
6262 go func() {
6263 server.Serve(cl)
6264 sdone <- true
6265 }()
6266 time.Sleep(10 * time.Millisecond)
6267 server.Shutdown(context.Background())
6268 ln.Close()
6269 <-sdone
6270
6271 nclose := atomic.LoadInt32(&cl.closes)
6272 if nclose != 1 {
6273 t.Errorf("Close calls = %v; want 1", nclose)
6274 }
6275 }
6276
6277
6278 func TestServerShutdownThenServe(t *testing.T) {
6279 var srv Server
6280 cl := &countCloseListener{Listener: nil}
6281 srv.Shutdown(context.Background())
6282 got := srv.Serve(cl)
6283 if got != ErrServerClosed {
6284 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6285 }
6286 nclose := atomic.LoadInt32(&cl.closes)
6287 if nclose != 1 {
6288 t.Errorf("Close calls = %v; want 1", nclose)
6289 }
6290 }
6291
6292
// TestStripPortFromHost checks that ServeMux ignores the port in the request
// host when matching host-specific patterns. A request for example.com:9000
// must be routed to "example.com/", not to "example.com:9000/".
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "uh-oh!")
	})

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6313
6314 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6315 func testServerContexts(t *testing.T, mode testMode) {
6316 type baseKey struct{}
6317 type connKey struct{}
6318 ch := make(chan context.Context, 1)
6319 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6320 ch <- r.Context()
6321 }), func(ts *httptest.Server) {
6322 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6323 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6324 t.Errorf("unexpected onceClose listener type %T", ln)
6325 }
6326 return context.WithValue(context.Background(), baseKey{}, "base")
6327 }
6328 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6329 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6330 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6331 }
6332 return context.WithValue(ctx, connKey{}, "conn")
6333 }
6334 }).ts
6335 res, err := ts.Client().Get(ts.URL)
6336 if err != nil {
6337 t.Fatal(err)
6338 }
6339 res.Body.Close()
6340 ctx := <-ch
6341 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6342 t.Errorf("base context key = %#v; want %q", got, want)
6343 }
6344 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6345 t.Errorf("conn context key = %#v; want %q", got, want)
6346 }
6347 }
6348
6349
6350 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6351 run(t, testConnContextNotModifyingAllContexts)
6352 }
6353 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6354 type connKey struct{}
6355 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6356 rw.Header().Set("Connection", "close")
6357 }), func(ts *httptest.Server) {
6358 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6359 if got := ctx.Value(connKey{}); got != nil {
6360 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6361 }
6362 return context.WithValue(ctx, connKey{}, "conn")
6363 }
6364 }).ts
6365
6366 var res *Response
6367 var err error
6368
6369 res, err = ts.Client().Get(ts.URL)
6370 if err != nil {
6371 t.Fatal(err)
6372 }
6373 res.Body.Close()
6374
6375 res, err = ts.Client().Get(ts.URL)
6376 if err != nil {
6377 t.Fatal(err)
6378 }
6379 res.Body.Close()
6380 }
6381
6382
6383
6384 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6385 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6386 }
6387 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6388 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6389 w.Write([]byte("Hello, World!"))
6390 })).ts
6391
6392 serverURL, err := url.Parse(cst.URL)
6393 if err != nil {
6394 t.Fatalf("Failed to parse server URL: %v", err)
6395 }
6396
6397 unsupportedTEs := []string{
6398 "fugazi",
6399 "foo-bar",
6400 "unknown",
6401 `" chunked"`,
6402 }
6403
6404 for _, badTE := range unsupportedTEs {
6405 http1ReqBody := fmt.Sprintf(""+
6406 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6407 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6408
6409 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6410 if err != nil {
6411 t.Errorf("%q. unexpected error: %v", badTE, err)
6412 continue
6413 }
6414
6415 wantBody := fmt.Sprintf("" +
6416 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6417 "Connection: close\r\n\r\nUnsupported transfer encoding")
6418
6419 if string(gotBody) != wantBody {
6420 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6421 }
6422 }
6423 }
6424
6425
6426 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6427 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6428 type setting struct {
6429 name string
6430 body []byte
6431
6432
6433
6434
6435 contentEncoding any
6436 wantContentType string
6437 }
6438
6439 settings := []*setting{
6440 {
6441 name: "gzip content-encoding, gzipped",
6442 contentEncoding: "application/gzip",
6443 wantContentType: "",
6444 body: func() []byte {
6445 buf := new(bytes.Buffer)
6446 gzw := gzip.NewWriter(buf)
6447 gzw.Write([]byte("doctype html><p>Hello</p>"))
6448 gzw.Close()
6449 return buf.Bytes()
6450 }(),
6451 },
6452 {
6453 name: "zlib content-encoding, zlibbed",
6454 contentEncoding: "application/zlib",
6455 wantContentType: "",
6456 body: func() []byte {
6457 buf := new(bytes.Buffer)
6458 zw := zlib.NewWriter(buf)
6459 zw.Write([]byte("doctype html><p>Hello</p>"))
6460 zw.Close()
6461 return buf.Bytes()
6462 }(),
6463 },
6464 {
6465 name: "no content-encoding",
6466 wantContentType: "application/x-gzip",
6467 body: func() []byte {
6468 buf := new(bytes.Buffer)
6469 gzw := gzip.NewWriter(buf)
6470 gzw.Write([]byte("doctype html><p>Hello</p>"))
6471 gzw.Close()
6472 return buf.Bytes()
6473 }(),
6474 },
6475 {
6476 name: "phony content-encoding",
6477 contentEncoding: "foo/bar",
6478 body: []byte("doctype html><p>Hello</p>"),
6479 },
6480 {
6481 name: "empty but set content-encoding",
6482 contentEncoding: "",
6483 wantContentType: "audio/mpeg",
6484 body: []byte("ID3"),
6485 },
6486 }
6487
6488 for _, tt := range settings {
6489 t.Run(tt.name, func(t *testing.T) {
6490 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6491 if tt.contentEncoding != nil {
6492 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6493 }
6494 rw.Write(tt.body)
6495 }))
6496
6497 res, err := cst.c.Get(cst.ts.URL)
6498 if err != nil {
6499 t.Fatalf("Failed to fetch URL: %v", err)
6500 }
6501 defer res.Body.Close()
6502
6503 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6504 if w != nil {
6505 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6506 } else if g != "" {
6507 t.Errorf("Unexpected Content-Encoding %q", g)
6508 }
6509 }
6510
6511 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6512 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6513 }
6514 })
6515 }
6516 }
6517
6518
6519
6520 func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
6521 run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
6522 }
6523 func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
6524 if testing.Short() {
6525 t.Skip("skipping in short mode")
6526 }
6527
6528 pc, curFile, _, _ := runtime.Caller(0)
6529 curFileBaseName := filepath.Base(curFile)
6530 testFuncName := runtime.FuncForPC(pc).Name()
6531
6532 timeoutMsg := "timed out here!"
6533
6534 tests := []struct {
6535 name string
6536 mustTimeout bool
6537 wantResp string
6538 }{
6539 {
6540 name: "return before timeout",
6541 wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
6542 },
6543 {
6544 name: "return after timeout",
6545 mustTimeout: true,
6546 wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
6547 len(timeoutMsg), timeoutMsg),
6548 },
6549 }
6550
6551 for _, tt := range tests {
6552 tt := tt
6553 t.Run(tt.name, func(t *testing.T) {
6554 exitHandler := make(chan bool, 1)
6555 defer close(exitHandler)
6556 lastLine := make(chan int, 1)
6557
6558 sh := HandlerFunc(func(w ResponseWriter, r *Request) {
6559 w.WriteHeader(404)
6560 w.WriteHeader(404)
6561 w.WriteHeader(404)
6562 w.WriteHeader(404)
6563 _, _, line, _ := runtime.Caller(0)
6564 lastLine <- line
6565 <-exitHandler
6566 })
6567
6568 if !tt.mustTimeout {
6569 exitHandler <- true
6570 }
6571
6572 logBuf := new(strings.Builder)
6573 srvLog := log.New(logBuf, "", 0)
6574
6575 dur := 20 * time.Millisecond
6576 if !tt.mustTimeout {
6577
6578 dur = 10 * time.Second
6579 }
6580 th := TimeoutHandler(sh, dur, timeoutMsg)
6581 cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
6582 defer cst.close()
6583
6584 res, err := cst.c.Get(cst.ts.URL)
6585 if err != nil {
6586 t.Fatalf("Unexpected error: %v", err)
6587 }
6588
6589
6590
6591 res.Header.Del("Date")
6592 res.Header.Del("Content-Type")
6593
6594
6595 blob, _ := httputil.DumpResponse(res, true)
6596 if g, w := string(blob), tt.wantResp; g != w {
6597 t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
6598 }
6599
6600
6601
6602 logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
6603 if g, w := len(logEntries), 3; g != w {
6604 blob, _ := json.MarshalIndent(logEntries, "", " ")
6605 t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
6606 }
6607
6608 lastSpuriousLine := <-lastLine
6609 firstSpuriousLine := lastSpuriousLine - 3
6610
6611
6612 for i, logEntry := range logEntries {
6613 wantLine := firstSpuriousLine + i
6614 pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
6615 testFuncName, curFileBaseName, wantLine)
6616 re := regexp.MustCompile(pat)
6617 if !re.MatchString(logEntry) {
6618 t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
6619 }
6620 }
6621 })
6622 }
6623 }
6624
6625
6626
6627
// fetchWireResponse dials host over TCP and writes http1ReqBody verbatim. It
// returns every byte read back from the connection until EOF or a read error.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	conn, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer conn.Close()

	if _, writeErr := conn.Write(http1ReqBody); writeErr != nil {
		return nil, writeErr
	}
	raw, readErr := io.ReadAll(conn)
	return raw, readErr
}
6640
6641 func BenchmarkResponseStatusLine(b *testing.B) {
6642 b.ReportAllocs()
6643 b.RunParallel(func(pb *testing.PB) {
6644 bw := bufio.NewWriter(io.Discard)
6645 var buf3 [3]byte
6646 for pb.Next() {
6647 Export_writeStatusLine(bw, true, 200, buf3[:])
6648 }
6649 })
6650 }
6651
6652 func TestDisableKeepAliveUpgrade(t *testing.T) {
6653 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6654 }
6655 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6656 if testing.Short() {
6657 t.Skip("skipping in short mode")
6658 }
6659
6660 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6661 w.Header().Set("Connection", "Upgrade")
6662 w.Header().Set("Upgrade", "someProto")
6663 w.WriteHeader(StatusSwitchingProtocols)
6664 c, buf, err := w.(Hijacker).Hijack()
6665 if err != nil {
6666 return
6667 }
6668 defer c.Close()
6669
6670
6671
6672 io.Copy(c, buf)
6673 }), func(ts *httptest.Server) {
6674 ts.Config.SetKeepAlivesEnabled(false)
6675 }).ts
6676
6677 cl := s.Client()
6678 cl.Transport.(*Transport).DisableKeepAlives = true
6679
6680 resp, err := cl.Get(s.URL)
6681 if err != nil {
6682 t.Fatalf("failed to perform request: %v", err)
6683 }
6684 defer resp.Body.Close()
6685
6686 if resp.StatusCode != StatusSwitchingProtocols {
6687 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6688 }
6689
6690 rwc, ok := resp.Body.(io.ReadWriteCloser)
6691 if !ok {
6692 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6693 }
6694
6695 _, err = rwc.Write([]byte("hello"))
6696 if err != nil {
6697 t.Fatalf("failed to write to body: %v", err)
6698 }
6699
6700 b := make([]byte, 5)
6701 _, err = io.ReadFull(rwc, b)
6702 if err != nil {
6703 t.Fatalf("failed to read from body: %v", err)
6704 }
6705
6706 if string(b) != "hello" {
6707 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6708 }
6709 }
6710
// tlogWriter adapts a *testing.T into an io.Writer. Each Write is passed to
// t.Log as a single string and always reports full success.
type tlogWriter struct{ t *testing.T }

func (w tlogWriter) Write(p []byte) (int, error) {
	msg := string(p)
	w.t.Log(msg)
	return len(p), nil
}
6717
6718 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6719 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6720 }
6721 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6722 const wantBody = "want"
6723 const wantUpgrade = "someProto"
6724 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6725 w.Header().Set("Connection", "Upgrade")
6726 w.Header().Set("Upgrade", wantUpgrade)
6727 w.WriteHeader(StatusSwitchingProtocols)
6728 NewResponseController(w).Flush()
6729
6730
6731 w.WriteHeader(200)
6732 if _, err := w.Write([]byte("x")); err == nil {
6733 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6734 }
6735
6736 c, _, err := NewResponseController(w).Hijack()
6737 if err != nil {
6738 t.Errorf("Hijack: %v", err)
6739 return
6740 }
6741 defer c.Close()
6742 if _, err := c.Write([]byte(wantBody)); err != nil {
6743 t.Errorf("Write to hijacked body: %v", err)
6744 }
6745 }), func(ts *httptest.Server) {
6746
6747 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6748 }).ts
6749
6750 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6751 if err != nil {
6752 t.Fatalf("net.Dial: %v", err)
6753 }
6754 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6755 if err != nil {
6756 t.Fatalf("conn.Write: %v", err)
6757 }
6758 defer conn.Close()
6759
6760 r := bufio.NewReader(conn)
6761 res, err := ReadResponse(r, &Request{Method: "GET"})
6762 if err != nil {
6763 t.Fatal("ReadResponse error:", err)
6764 }
6765 if res.StatusCode != StatusSwitchingProtocols {
6766 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6767 }
6768 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6769 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6770 }
6771 body, err := io.ReadAll(r)
6772 if err != nil {
6773 t.Error(err)
6774 }
6775 if string(body) != wantBody {
6776 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6777 }
6778 }
6779
6780 func TestMuxRedirectRelative(t *testing.T) {
6781 setParallel(t)
6782 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6783 if err != nil {
6784 t.Errorf("%s", err)
6785 }
6786 mux := NewServeMux()
6787 resp := httptest.NewRecorder()
6788 mux.ServeHTTP(resp, req)
6789 if got, want := resp.Header().Get("Location"), "/"; got != want {
6790 t.Errorf("Location header expected %q; got %q", want, got)
6791 }
6792 if got, want := resp.Code, StatusMovedPermanently; got != want {
6793 t.Errorf("Expected response code %d; got %d", want, got)
6794 }
6795 }
6796
6797
6798 func TestQuerySemicolon(t *testing.T) {
6799 t.Cleanup(func() { afterTest(t) })
6800
6801 tests := []struct {
6802 query string
6803 xNoSemicolons string
6804 xWithSemicolons string
6805 expectParseFormErr bool
6806 }{
6807 {"?a=1;x=bad&x=good", "good", "bad", true},
6808 {"?a=1;b=bad&x=good", "good", "good", true},
6809 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6810 {"?a=1;x=good;x=bad", "", "good", true},
6811 }
6812
6813 run(t, func(t *testing.T, mode testMode) {
6814 for _, tt := range tests {
6815 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6816 allowSemicolons := false
6817 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6818 })
6819 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6820 allowSemicolons, expectParseFormErr := true, false
6821 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6822 })
6823 }
6824 })
6825 }
6826
6827 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6828 writeBackX := func(w ResponseWriter, r *Request) {
6829 x := r.URL.Query().Get("x")
6830 if expectParseFormErr {
6831 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6832 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6833 }
6834 } else {
6835 if err := r.ParseForm(); err != nil {
6836 t.Errorf("expected no error from ParseForm, got %v", err)
6837 }
6838 }
6839 if got := r.FormValue("x"); x != got {
6840 t.Errorf("got %q from FormValue, want %q", got, x)
6841 }
6842 fmt.Fprintf(w, "%s", x)
6843 }
6844
6845 h := Handler(HandlerFunc(writeBackX))
6846 if allowSemicolons {
6847 h = AllowQuerySemicolons(h)
6848 }
6849
6850 logBuf := &strings.Builder{}
6851 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
6852 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6853 }).ts
6854
6855 req, _ := NewRequest("GET", ts.URL+query, nil)
6856 res, err := ts.Client().Do(req)
6857 if err != nil {
6858 t.Fatal(err)
6859 }
6860 slurp, _ := io.ReadAll(res.Body)
6861 res.Body.Close()
6862 if got, want := res.StatusCode, 200; got != want {
6863 t.Errorf("Status = %d; want = %d", got, want)
6864 }
6865 if got, want := string(slurp), wantX; got != want {
6866 t.Errorf("Body = %q; want = %q", got, want)
6867 }
6868 }
6869
6870 func TestMaxBytesHandler(t *testing.T) {
6871
6872 defer afterTest(t)
6873
6874 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
6875 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
6876 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
6877 func(t *testing.T) {
6878 run(t, func(t *testing.T, mode testMode) {
6879 testMaxBytesHandler(t, mode, maxSize, requestSize)
6880 }, testNotParallel)
6881 })
6882 }
6883 }
6884 }
6885
6886 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
6887 runTimeSensitiveTest(t, []time.Duration{
6888 1 * time.Millisecond,
6889 5 * time.Millisecond,
6890 10 * time.Millisecond,
6891 50 * time.Millisecond,
6892 100 * time.Millisecond,
6893 500 * time.Millisecond,
6894 time.Second,
6895 5 * time.Second,
6896 }, func(t *testing.T, timeout time.Duration) error {
6897 SetRSTAvoidanceDelay(t, timeout)
6898 t.Logf("set RST avoidance delay to %v", timeout)
6899
6900 var (
6901 handlerN int64
6902 handlerErr error
6903 )
6904 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
6905 var buf bytes.Buffer
6906 handlerN, handlerErr = io.Copy(&buf, r.Body)
6907 io.Copy(w, &buf)
6908 })
6909
6910 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
6911
6912
6913 defer cst.close()
6914 ts := cst.ts
6915 c := ts.Client()
6916
6917 body := strings.Repeat("a", int(requestSize))
6918 var wg sync.WaitGroup
6919 defer wg.Wait()
6920 getBody := func() (io.ReadCloser, error) {
6921 wg.Add(1)
6922 body := &wgReadCloser{
6923 Reader: strings.NewReader(body),
6924 wg: &wg,
6925 }
6926 return body, nil
6927 }
6928 reqBody, _ := getBody()
6929 req, err := NewRequest("POST", ts.URL, reqBody)
6930 if err != nil {
6931 reqBody.Close()
6932 t.Fatal(err)
6933 }
6934 req.ContentLength = int64(len(body))
6935 req.GetBody = getBody
6936 req.Header.Set("Content-Type", "text/plain")
6937
6938 var buf strings.Builder
6939 res, err := c.Do(req)
6940 if err != nil {
6941 return fmt.Errorf("unexpected connection error: %v", err)
6942 } else {
6943 _, err = io.Copy(&buf, res.Body)
6944 res.Body.Close()
6945 if err != nil {
6946 return fmt.Errorf("unexpected read error: %v", err)
6947 }
6948 }
6949
6950
6951
6952
6953 if handlerN > maxSize {
6954 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
6955 }
6956 if requestSize > maxSize && handlerErr == nil {
6957 t.Error("expected error on handler side; got nil")
6958 }
6959 if requestSize <= maxSize {
6960 if handlerErr != nil {
6961 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
6962 }
6963 if handlerN != requestSize {
6964 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
6965 }
6966 }
6967 if buf.Len() != int(handlerN) {
6968 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
6969 }
6970
6971 return nil
6972 })
6973 }
6974
6975 func TestEarlyHints(t *testing.T) {
6976 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6977 h := w.Header()
6978 h.Add("Link", "</style.css>; rel=preload; as=style")
6979 h.Add("Link", "</script.js>; rel=preload; as=script")
6980 w.WriteHeader(StatusEarlyHints)
6981
6982 h.Add("Link", "</foo.js>; rel=preload; as=script")
6983 w.WriteHeader(StatusEarlyHints)
6984
6985 w.Write([]byte("stuff"))
6986 }))
6987
6988 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
6989 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
6990 if !strings.Contains(got, expected) {
6991 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
6992 }
6993 }
6994 func TestProcessing(t *testing.T) {
6995 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6996 w.WriteHeader(StatusProcessing)
6997 w.Write([]byte("stuff"))
6998 }))
6999
7000 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7001 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7002 if !strings.Contains(got, expected) {
7003 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7004 }
7005 }
7006
7007 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7008 func testParseFormCleanup(t *testing.T, mode testMode) {
7009 if mode == http2Mode {
7010 t.Skip("https://go.dev/issue/20253")
7011 }
7012
7013 const maxMemory = 1024
7014 const key = "file"
7015
7016 if runtime.GOOS == "windows" {
7017
7018 t.Skip("https://go.dev/issue/25965")
7019 }
7020
7021 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7022 r.ParseMultipartForm(maxMemory)
7023 f, _, err := r.FormFile(key)
7024 if err != nil {
7025 t.Errorf("r.FormFile(%q) = %v", key, err)
7026 return
7027 }
7028 of, ok := f.(*os.File)
7029 if !ok {
7030 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7031 return
7032 }
7033 w.Write([]byte(of.Name()))
7034 }))
7035
7036 fBuf := new(bytes.Buffer)
7037 mw := multipart.NewWriter(fBuf)
7038 mf, err := mw.CreateFormFile(key, "myfile.txt")
7039 if err != nil {
7040 t.Fatal(err)
7041 }
7042 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7043 t.Fatal(err)
7044 }
7045 if err := mw.Close(); err != nil {
7046 t.Fatal(err)
7047 }
7048 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7049 if err != nil {
7050 t.Fatal(err)
7051 }
7052 req.Header.Set("Content-Type", mw.FormDataContentType())
7053 res, err := cst.c.Do(req)
7054 if err != nil {
7055 t.Fatal(err)
7056 }
7057 defer res.Body.Close()
7058 fname, err := io.ReadAll(res.Body)
7059 if err != nil {
7060 t.Fatal(err)
7061 }
7062 cst.close()
7063 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7064 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7065 }
7066 }
7067
7068 func TestHeadBody(t *testing.T) {
7069 const identityMode = false
7070 const chunkedMode = true
7071 run(t, func(t *testing.T, mode testMode) {
7072 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7073 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7074 })
7075 }
7076
7077 func TestGetBody(t *testing.T) {
7078 const identityMode = false
7079 const chunkedMode = true
7080 run(t, func(t *testing.T, mode testMode) {
7081 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7082 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7083 })
7084 }
7085
7086 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7087 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7088 b, err := io.ReadAll(r.Body)
7089 if err != nil {
7090 t.Errorf("server reading body: %v", err)
7091 return
7092 }
7093 w.Header().Set("X-Request-Body", string(b))
7094 w.Header().Set("Content-Length", "0")
7095 }))
7096 defer cst.close()
7097 for _, reqBody := range []string{
7098 "",
7099 "",
7100 "request_body",
7101 "",
7102 } {
7103 var bodyReader io.Reader
7104 if reqBody != "" {
7105 bodyReader = strings.NewReader(reqBody)
7106 if chunked {
7107 bodyReader = bufio.NewReader(bodyReader)
7108 }
7109 }
7110 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7111 if err != nil {
7112 t.Fatal(err)
7113 }
7114 res, err := cst.c.Do(req)
7115 if err != nil {
7116 t.Fatal(err)
7117 }
7118 res.Body.Close()
7119 if got, want := res.StatusCode, 200; got != want {
7120 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7121 }
7122 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7123 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7124 }
7125 }
7126 }
7127
7128
7129
7130 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7131 func testDisableContentLength(t *testing.T, mode testMode) {
7132 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7133 w.Header()["Content-Length"] = nil
7134 fmt.Fprintf(w, "OK")
7135 }))
7136
7137 res, err := noCL.c.Get(noCL.ts.URL)
7138 if err != nil {
7139 t.Fatal(err)
7140 }
7141 if got, haveCL := res.Header["Content-Length"]; haveCL {
7142 t.Errorf("Unexpected Content-Length: %q", got)
7143 }
7144 if err := res.Body.Close(); err != nil {
7145 t.Fatal(err)
7146 }
7147
7148 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7149 fmt.Fprintf(w, "OK")
7150 }))
7151
7152 res, err = withCL.c.Get(withCL.ts.URL)
7153 if err != nil {
7154 t.Fatal(err)
7155 }
7156 if got := res.Header.Get("Content-Length"); got != "2" {
7157 t.Errorf("Content-Length: %q; want 2", got)
7158 }
7159 if err := res.Body.Close(); err != nil {
7160 t.Fatal(err)
7161 }
7162 }
7163
7164 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7165 func testErrorContentLength(t *testing.T, mode testMode) {
7166 const errorBody = "an error occurred"
7167 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7168 w.Header().Set("Content-Length", "1000")
7169 Error(w, errorBody, 400)
7170 }))
7171 res, err := cst.c.Get(cst.ts.URL)
7172 if err != nil {
7173 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7174 }
7175 defer res.Body.Close()
7176 body, err := io.ReadAll(res.Body)
7177 if err != nil {
7178 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7179 }
7180 if string(body) != errorBody+"\n" {
7181 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7182 }
7183 }
7184
// TestError checks the response produced by Error:
//   - a previously set Content-Length is removed;
//   - Content-Type is forced to plain text;
//   - X-Content-Type-Options is forced to "nosniff";
//   - unrelated headers are left untouched;
//   - the status code and message (plus a trailing newline) are written.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	h := w.Header()
	for _, hdr := range []string{"Content-Length"} {
		if v, ok := h[hdr]; ok {
			t.Errorf("%s: %q, want not present", hdr, v)
		}
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
	// "Other" has nothing to do with the error message and must survive.
	if v := h.Get("Other"); v != "foo" {
		t.Errorf("Other: %q, want %q", v, "foo")
	}
	if w.Code != 432 {
		t.Errorf("status code = %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("body = %q, want %q", got, want)
	}
}
7205
7206 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7207 run(t, testServerReadAfterWriteHeader100Continue)
7208 }
7209 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7210 t.Skip("https://go.dev/issue/67555")
7211 body := []byte("body")
7212 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7213 w.WriteHeader(200)
7214 NewResponseController(w).Flush()
7215 io.ReadAll(r.Body)
7216 w.Write(body)
7217 }), func(tr *Transport) {
7218 tr.ExpectContinueTimeout = 24 * time.Hour
7219 })
7220
7221 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7222 req.Header.Set("Expect", "100-continue")
7223 res, err := cst.c.Do(req)
7224 if err != nil {
7225 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7226 }
7227 defer res.Body.Close()
7228 got, err := io.ReadAll(res.Body)
7229 if err != nil {
7230 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7231 }
7232 if !bytes.Equal(got, body) {
7233 t.Fatalf("response body = %q, want %q", got, body)
7234 }
7235 }
7236
// TestServerReadAfterHandlerDone100Continue exercises reading a request body,
// sent with "Expect: 100-continue", from a goroutine that keeps running after
// the handler has returned. The test is currently skipped; see the linked issue.
func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
	run(t, testServerReadAfterHandlerDone100Continue)
}
func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	// readyc is unbuffered. The first send below releases the goroutine to read
	// the body. The second send blocks until that read has returned.
	readyc := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler itself returns immediately. The body is read only
		// after the client has received the response and signaled readyc.
		go func() {
			<-readyc
			io.ReadAll(r.Body)
			<-readyc
		}()
	}), func(tr *Transport) {
		// Very long timeout so the client keeps waiting on the
		// 100-continue exchange for the whole test.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
	}
	res.Body.Close()
	// Let the goroutine read the body, then wait for it to finish.
	readyc <- struct{}{}
	readyc <- struct{}{}
}
7263
// TestServerReadAfterHandlerAbort100Continue is like
// TestServerReadAfterHandlerDone100Continue, except that the handler aborts by
// panicking with ErrAbortHandler after starting the reading goroutine. The test
// is currently skipped; see the linked issue.
func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
	run(t, testServerReadAfterHandlerAbort100Continue)
}
func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
	t.Skip("https://go.dev/issue/67555")
	// readyc is unbuffered. The first send below releases the goroutine to read
	// the body. The second send blocks until that read has returned.
	readyc := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		go func() {
			<-readyc
			io.ReadAll(r.Body)
			<-readyc
		}()
		panic(ErrAbortHandler)
	}), func(tr *Transport) {
		// Very long timeout so the client keeps waiting on the
		// 100-continue exchange for the whole test.
		tr.ExpectContinueTimeout = 24 * time.Hour
	})

	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
	req.Header.Set("Expect", "100-continue")
	// The aborted handler may or may not produce a response, so a Do error is
	// tolerated here; only a successful response needs its body closed.
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
	}
	// Let the goroutine read the body, then wait for it to finish.
	readyc <- struct{}{}
	readyc <- struct{}{}
}
7290
7291 func TestInvalidChunkedBodies(t *testing.T) {
7292 for _, test := range []struct {
7293 name string
7294 b string
7295 }{{
7296 name: "bare LF in chunk size",
7297 b: "1\na\r\n0\r\n\r\n",
7298 }, {
7299 name: "bare LF at body end",
7300 b: "1\r\na\r\n0\r\n\n",
7301 }} {
7302 t.Run(test.name, func(t *testing.T) {
7303 reqc := make(chan error)
7304 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7305 got, err := io.ReadAll(r.Body)
7306 if err == nil {
7307 t.Logf("read body: %q", got)
7308 }
7309 reqc <- err
7310 })).ts
7311
7312 serverURL, err := url.Parse(ts.URL)
7313 if err != nil {
7314 t.Fatal(err)
7315 }
7316
7317 conn, err := net.Dial("tcp", serverURL.Host)
7318 if err != nil {
7319 t.Fatal(err)
7320 }
7321
7322 if _, err := conn.Write([]byte(
7323 "POST / HTTP/1.1\r\n" +
7324 "Host: localhost\r\n" +
7325 "Transfer-Encoding: chunked\r\n" +
7326 "Connection: close\r\n" +
7327 "\r\n" +
7328 test.b)); err != nil {
7329 t.Fatal(err)
7330 }
7331 conn.(*net.TCPConn).CloseWrite()
7332
7333 if err := <-reqc; err == nil {
7334 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7335 }
7336 })
7337 }
7338 }
7339
7340
7341 func TestServerTLSNextProtos(t *testing.T) {
7342 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7343 }
7344 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7345 CondSkipHTTP2(t)
7346
7347 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7348 if err != nil {
7349 t.Fatal(err)
7350 }
7351 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7352 if err != nil {
7353 t.Fatal(err)
7354 }
7355 certpool := x509.NewCertPool()
7356 certpool.AddCert(leafCert)
7357
7358 protos := new(Protocols)
7359 switch mode {
7360 case https1Mode:
7361 protos.SetHTTP1(true)
7362 case http2Mode:
7363 protos.SetHTTP2(true)
7364 }
7365
7366 wantNextProtos := []string{"http/1.1", "h2", "other"}
7367 nextProtos := slices.Clone(wantNextProtos)
7368
7369
7370 srv := &Server{
7371 TLSConfig: &tls.Config{
7372 Certificates: []tls.Certificate{cert},
7373 NextProtos: nextProtos,
7374 },
7375 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7376 Protocols: protos,
7377 }
7378 tr := &Transport{
7379 TLSClientConfig: &tls.Config{
7380 RootCAs: certpool,
7381 NextProtos: nextProtos,
7382 },
7383 Protocols: protos,
7384 }
7385
7386 listener := newLocalListener(t)
7387 srvc := make(chan error, 1)
7388 go func() {
7389 srvc <- srv.ServeTLS(listener, "", "")
7390 }()
7391 t.Cleanup(func() {
7392 srv.Close()
7393 <-srvc
7394 })
7395
7396 client := &Client{Transport: tr}
7397 resp, err := client.Get("https://" + listener.Addr().String())
7398 if err != nil {
7399 t.Fatal(err)
7400 }
7401 resp.Body.Close()
7402
7403 if !slices.Equal(nextProtos, wantNextProtos) {
7404 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7405 }
7406 }
7407
View as plain text