Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "crypto/x509"
17 "encoding/json"
18 "errors"
19 "fmt"
20 "internal/synctest"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "time"
47 )
48
// dummyAddr is a fake net.Addr: both Network and String return the
// underlying string.
type dummyAddr string

// oneConnListener is a net.Listener that yields one pre-supplied
// connection. The first Accept returns conn; every later Accept returns
// io.EOF, which is what makes Serve return in these tests.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out l.conn exactly once and reports io.EOF afterwards.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close is a no-op.
func (l *oneConnListener) Close() error { return nil }

// Addr reports the fixed address "test-address".
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network returns a unchanged.
func (a dummyAddr) Network() string { return string(a) }

// String returns a unchanged.
func (a dummyAddr) String() string { return string(a) }

// noopConn supplies the address and deadline methods of net.Conn. The
// addresses are fixed dummyAddr values and every deadline setter succeeds
// without doing anything; embedding types add Read, Write and Close.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr                { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr               { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(t time.Time) error      { return nil }
func (noopConn) SetReadDeadline(t time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(t time.Time) error { return nil }

// rwTestConn is a net.Conn assembled from an arbitrary Reader and Writer.
type rwTestConn struct {
	io.Reader
	io.Writer
	noopConn

	closeFunc func() error // when non-nil, Close just returns closeFunc()
	closec    chan bool    // otherwise Close does a non-blocking send here
}

// Close delegates to closeFunc when set. Otherwise it signals closec
// without blocking; a nil or full channel means no signal is delivered.
func (c *rwTestConn) Close() error {
	if fn := c.closeFunc; fn != nil {
		return fn()
	}
	select {
	case c.closec <- true:
	default:
	}
	return nil
}

// testConn is a net.Conn backed by in-memory buffers. Tests pre-load
// readBuf with request bytes and inspect writeBuf for what was written.
type testConn struct {
	readMu   sync.Mutex // held by Read while draining readBuf
	readBuf  bytes.Buffer
	writeBuf bytes.Buffer
	closec   chan bool // Close does a non-blocking send on it
	noopConn
}

// newTestConn returns a testConn whose closec can hold one close signal.
func newTestConn() *testConn {
	c := new(testConn)
	c.closec = make(chan bool, 1)
	return c
}

// Read drains readBuf under readMu.
func (c *testConn) Read(b []byte) (int, error) {
	c.readMu.Lock()
	n, err := c.readBuf.Read(b)
	c.readMu.Unlock()
	return n, err
}

// Write appends to writeBuf.
func (c *testConn) Write(b []byte) (int, error) {
	return c.writeBuf.Write(b)
}

// Close signals closec without blocking and always succeeds.
func (c *testConn) Close() error {
	select {
	case c.closec <- true:
	default:
	}
	return nil
}
138
139
140
// reqBytes converts a human-written, "\n"-separated request into wire form.
// Surrounding whitespace is trimmed, lines are rejoined with "\r\n", and the
// header-terminating "\r\n\r\n" is appended.
func reqBytes(req string) []byte {
	lines := strings.Split(strings.TrimSpace(req), "\n")
	wire := strings.Join(lines, "\r\n") + "\r\n\r\n"
	return []byte(wire)
}
144
145 type handlerTest struct {
146 logbuf bytes.Buffer
147 handler Handler
148 }
149
150 func newHandlerTest(h Handler) handlerTest {
151 return handlerTest{handler: h}
152 }
153
154 func (ht *handlerTest) rawResponse(req string) string {
155 reqb := reqBytes(req)
156 var output strings.Builder
157 conn := &rwTestConn{
158 Reader: bytes.NewReader(reqb),
159 Writer: &output,
160 closec: make(chan bool, 1),
161 }
162 ln := &oneConnListener{conn: conn}
163 srv := &Server{
164 ErrorLog: log.New(&ht.logbuf, "", 0),
165 Handler: ht.handler,
166 }
167 go srv.Serve(ln)
168 <-conn.closec
169 return output.String()
170 }
171
172 func TestConsumingBodyOnNextConn(t *testing.T) {
173 t.Parallel()
174 defer afterTest(t)
175 conn := new(testConn)
176 for i := 0; i < 2; i++ {
177 conn.readBuf.Write([]byte(
178 "POST / HTTP/1.1\r\n" +
179 "Host: test\r\n" +
180 "Content-Length: 11\r\n" +
181 "\r\n" +
182 "foo=1&bar=1"))
183 }
184
185 reqNum := 0
186 ch := make(chan *Request)
187 servech := make(chan error)
188 listener := &oneConnListener{conn}
189 handler := func(res ResponseWriter, req *Request) {
190 reqNum++
191 ch <- req
192 }
193
194 go func() {
195 servech <- Serve(listener, HandlerFunc(handler))
196 }()
197
198 var req *Request
199 req = <-ch
200 if req == nil {
201 t.Fatal("Got nil first request.")
202 }
203 if req.Method != "POST" {
204 t.Errorf("For request #1's method, got %q; expected %q",
205 req.Method, "POST")
206 }
207
208 req = <-ch
209 if req == nil {
210 t.Fatal("Got nil first request.")
211 }
212 if req.Method != "POST" {
213 t.Errorf("For request #2's method, got %q; expected %q",
214 req.Method, "POST")
215 }
216
217 if serveerr := <-servech; serveerr != io.EOF {
218 t.Errorf("Serve returned %q; expected EOF", serveerr)
219 }
220 }
221
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

// ServeHTTP sets the "Result" header to s; the request is not inspected.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
227
// handlers lists the patterns testHostHandlers registers. Each pattern is
// served by a stringHandler that reports msg in the "Result" header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests lists request URLs for testHostHandlers. expected is the "Result"
// header for a 200 response, or the "Location" header for a 301 redirect.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Subtree patterns redirect the slash-less path to the slashed one.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
253
254 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
255 func testHostHandlers(t *testing.T, mode testMode) {
256 mux := NewServeMux()
257 for _, h := range handlers {
258 mux.Handle(h.pattern, stringHandler(h.msg))
259 }
260 ts := newClientServerTest(t, mode, mux).ts
261
262 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
263 if err != nil {
264 t.Fatal(err)
265 }
266 defer conn.Close()
267 cc := httputil.NewClientConn(conn, nil)
268 for _, vt := range vtests {
269 var r *Response
270 var req Request
271 if req.URL, err = url.Parse(vt.url); err != nil {
272 t.Errorf("cannot parse url: %v", err)
273 continue
274 }
275 if err := cc.Write(&req); err != nil {
276 t.Errorf("writing request: %v", err)
277 continue
278 }
279 r, err := cc.Read(&req)
280 if err != nil {
281 t.Errorf("reading response: %v", err)
282 continue
283 }
284 switch r.StatusCode {
285 case StatusOK:
286 s := r.Header.Get("Result")
287 if s != vt.expected {
288 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
289 }
290 case StatusMovedPermanently:
291 s := r.Header.Get("Location")
292 if s != vt.expected {
293 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
294 }
295 default:
296 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
297 }
298 }
299 }
300
301 var serveMuxRegister = []struct {
302 pattern string
303 h Handler
304 }{
305 {"/dir/", serve(200)},
306 {"/search", serve(201)},
307 {"codesearch.google.com/search", serve(202)},
308 {"codesearch.google.com/", serve(203)},
309 {"example.com/", HandlerFunc(checkQueryStringHandler)},
310 }
311
312
// serve returns a handler that replies with the given status code and an
// empty body, ignoring the request.
func serve(code int) HandlerFunc {
	reply := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	}
	return reply
}
318
319
320
321
// checkQueryStringHandler rebuilds the request URL with scheme "http", the
// request's Host and no query. It replies 200 when that equals "http://"
// followed by the raw query, and 500 otherwise. serveMuxTests2 puts the
// expected host and path in the query, so a 200 shows the request
// (possibly after a redirect) arrived where the query said it would.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	rebuilt := *r.URL
	rebuilt.Scheme, rebuilt.Host, rebuilt.RawQuery = "http", r.Host, ""
	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
333
// serveMuxTests lists requests resolved against a mux built from
// serveMuxRegister. code is the status the chosen handler writes and
// pattern is what ServeMux.Handler reports ("" when nothing matched or for
// a path-cleaning redirect). Two exact duplicate rows were removed; they
// exercised nothing extra.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 301, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 301, "/search"},
	{"GET", "google.com", "/dir/..", 301, ""},
	{"GET", "google.com", "/dir/./file", 301, "/dir/"},

	// CONNECT requests still get the "/dir" -> "/dir/" redirect, but their
	// paths are not cleaned: "/../search" 404s, while "/dir/.." and
	// "/dir/./file" are matched as-is by "/dir/".
	{"CONNECT", "google.com", "/dir", 301, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
369
370 func TestServeMuxHandler(t *testing.T) {
371 setParallel(t)
372 mux := NewServeMux()
373 for _, e := range serveMuxRegister {
374 mux.Handle(e.pattern, e.h)
375 }
376
377 for _, tt := range serveMuxTests {
378 r := &Request{
379 Method: tt.method,
380 Host: tt.host,
381 URL: &url.URL{
382 Path: tt.path,
383 },
384 }
385 h, pattern := mux.Handler(r)
386 rr := httptest.NewRecorder()
387 h.ServeHTTP(rr, r)
388 if pattern != tt.pattern || rr.Code != tt.code {
389 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
390 }
391 }
392 }
393
394
395 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
396 setParallel(t)
397 defer func() {
398 if err := recover(); err == nil {
399 t.Error("expected call to mux.HandleFunc to panic")
400 }
401 }()
402 mux := NewServeMux()
403 mux.HandleFunc("/", nil)
404 }
405
// serveMuxTests2 drives TestServeMuxHandlerRedirects. code is the expected
// final (non-redirect) status. redirOk says whether one redirect may occur
// on the way there.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
417
418
419
420 func TestServeMuxHandlerRedirects(t *testing.T) {
421 setParallel(t)
422 mux := NewServeMux()
423 for _, e := range serveMuxRegister {
424 mux.Handle(e.pattern, e.h)
425 }
426
427 for _, tt := range serveMuxTests2 {
428 tries := 1
429 turl := tt.url
430 for {
431 u, e := url.Parse(turl)
432 if e != nil {
433 t.Fatal(e)
434 }
435 r := &Request{
436 Method: tt.method,
437 Host: tt.host,
438 URL: u,
439 }
440 h, _ := mux.Handler(r)
441 rr := httptest.NewRecorder()
442 h.ServeHTTP(rr, r)
443 if rr.Code != 301 {
444 if rr.Code != tt.code {
445 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
446 }
447 break
448 }
449 if !tt.redirOk {
450 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
451 break
452 }
453 turl = rr.HeaderMap.Get("Location")
454 tries--
455 }
456 if tries < 0 {
457 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
458 }
459 }
460 }
461
462
463 func TestMuxRedirectLeadingSlashes(t *testing.T) {
464 setParallel(t)
465 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
466 for _, path := range paths {
467 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
468 if err != nil {
469 t.Errorf("%s", err)
470 }
471 mux := NewServeMux()
472 resp := httptest.NewRecorder()
473
474 mux.ServeHTTP(resp, req)
475
476 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
477 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
478 return
479 }
480
481 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
482 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
483 return
484 }
485 }
486 }
487
488
489
490
491
492 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
493 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
494 }
495 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
496 writeBackQuery := func(w ResponseWriter, r *Request) {
497 fmt.Fprintf(w, "%s", r.URL.RawQuery)
498 }
499
500 mux := NewServeMux()
501 mux.HandleFunc("/testOne", writeBackQuery)
502 mux.HandleFunc("/testTwo/", writeBackQuery)
503 mux.HandleFunc("/testThree", writeBackQuery)
504 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
505 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
506 })
507
508 ts := newClientServerTest(t, mode, mux).ts
509
510 tests := [...]struct {
511 path string
512 method string
513 want string
514 statusOk bool
515 }{
516 0: {"/testOne?this=that", "GET", "this=that", true},
517 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
518 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
519 3: {"/testTwo?", "GET", "", true},
520 4: {"/testThree?foo", "GET", "foo", true},
521 5: {"/testThree/?foo", "GET", "foo:bar", true},
522 6: {"/testThree?foo", "CONNECT", "foo", true},
523 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
524
525
526 8: {"/testOne/foo/..?foo", "GET", "foo", true},
527 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
528 }
529
530 for i, tt := range tests {
531 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
532 res, err := ts.Client().Do(req)
533 if err != nil {
534 continue
535 }
536 slurp, _ := io.ReadAll(res.Body)
537 res.Body.Close()
538 if !tt.statusOk {
539 if got, want := res.StatusCode, 404; got != want {
540 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
541 }
542 }
543 if got, want := string(slurp), tt.want; got != want {
544 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
545 }
546 }
547 }
548
549 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
550 setParallel(t)
551
552 mux := NewServeMux()
553 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
554 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
555 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
556 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
557 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
558 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
559
560 tests := []struct {
561 method string
562 url string
563 code int
564 loc string
565 want string
566 }{
567 {"GET", "http://example.com/", 404, "", ""},
568 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
569 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
570 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
571 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
572 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
573 {"CONNECT", "http://example.com/", 404, "", ""},
574 {"CONNECT", "http://example.com:3000/", 404, "", ""},
575 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
576 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
577 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
578 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
579 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
580 }
581
582 for i, tt := range tests {
583 req, _ := NewRequest(tt.method, tt.url, nil)
584 w := httptest.NewRecorder()
585 mux.ServeHTTP(w, req)
586
587 if got, want := w.Code, tt.code; got != want {
588 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
589 }
590
591 if tt.code == 301 {
592 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
593 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
594 }
595 } else {
596 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
597 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
598 }
599 }
600 }
601 }
602
603
604
605
// TestMuxNoSlashRedirectWithTrailingSlash checks that a GET of "/" against
// a mux whose only pattern is "/{x}/" yields 404 rather than a redirect.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
618
619
620
621
// TestMuxNoSlash405WithTrailingSlash checks that a GET of "/" against a mux
// whose only pattern is the method-qualified "GET /{x}/" yields 404. It must
// be neither a redirect nor a 405.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
634
635 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
636 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
637 mux := NewServeMux()
638 newClientServerTest(t, mode, mux)
639 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
640 }
641
642 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
643 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
644 func benchmarkServeMux(b *testing.B, runHandler bool) {
645 type test struct {
646 path string
647 code int
648 req *Request
649 }
650
651
652 var tests []test
653 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
654 for _, e := range endpoints {
655 for i := 200; i < 230; i++ {
656 p := fmt.Sprintf("/%s/%d/", e, i)
657 tests = append(tests, test{
658 path: p,
659 code: i,
660 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
661 })
662 }
663 }
664 mux := NewServeMux()
665 for _, tt := range tests {
666 mux.Handle(tt.path, serve(tt.code))
667 }
668
669 rw := httptest.NewRecorder()
670 b.ReportAllocs()
671 b.ResetTimer()
672 for i := 0; i < b.N; i++ {
673 for _, tt := range tests {
674 *rw = httptest.ResponseRecorder{}
675 h, pattern := mux.Handler(tt.req)
676 if runHandler {
677 h.ServeHTTP(rw, tt.req)
678 if pattern != tt.path || rw.Code != tt.code {
679 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
680 }
681 }
682 }
683 }
684 }
685
686 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
687 func testServerTimeouts(t *testing.T, mode testMode) {
688 runTimeSensitiveTest(t, []time.Duration{
689 10 * time.Millisecond,
690 50 * time.Millisecond,
691 100 * time.Millisecond,
692 500 * time.Millisecond,
693 1 * time.Second,
694 }, func(t *testing.T, timeout time.Duration) error {
695 return testServerTimeoutsWithTimeout(t, timeout, mode)
696 })
697 }
698
699 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
700 var reqNum atomic.Int32
701 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
702 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
703 }), func(ts *httptest.Server) {
704 ts.Config.ReadTimeout = timeout
705 ts.Config.WriteTimeout = timeout
706 })
707 defer cst.close()
708 ts := cst.ts
709
710
711 c := ts.Client()
712 r, err := c.Get(ts.URL)
713 if err != nil {
714 return fmt.Errorf("http Get #1: %v", err)
715 }
716 got, err := io.ReadAll(r.Body)
717 expected := "req=1"
718 if string(got) != expected || err != nil {
719 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
720 string(got), err, expected)
721 }
722
723
724 t1 := time.Now()
725 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
726 if err != nil {
727 return fmt.Errorf("Dial: %v", err)
728 }
729 buf := make([]byte, 1)
730 n, err := conn.Read(buf)
731 conn.Close()
732 latency := time.Since(t1)
733 if n != 0 || err != io.EOF {
734 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
735 }
736 minLatency := timeout / 5 * 4
737 if latency < minLatency {
738 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
739 }
740
741
742
743
744 r, err = c.Get(ts.URL)
745 if err != nil {
746 return fmt.Errorf("http Get #2: %v", err)
747 }
748 got, err = io.ReadAll(r.Body)
749 r.Body.Close()
750 expected = "req=2"
751 if string(got) != expected || err != nil {
752 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
753 }
754
755 if !testing.Short() {
756 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
757 if err != nil {
758 return fmt.Errorf("long Dial: %v", err)
759 }
760 defer conn.Close()
761 go io.Copy(io.Discard, conn)
762 for i := 0; i < 5; i++ {
763 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
764 if err != nil {
765 return fmt.Errorf("on write %d: %v", i, err)
766 }
767 time.Sleep(timeout / 2)
768 }
769 }
770 return nil
771 }
772
773 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
774 func testServerReadTimeout(t *testing.T, mode testMode) {
775 respBody := "response body"
776 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
777 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
778 _, err := io.Copy(io.Discard, req.Body)
779 if !errors.Is(err, os.ErrDeadlineExceeded) {
780 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
781 }
782 res.Write([]byte(respBody))
783 }), func(ts *httptest.Server) {
784 ts.Config.ReadHeaderTimeout = -1
785 ts.Config.ReadTimeout = timeout
786 t.Logf("Server.Config.ReadTimeout = %v", timeout)
787 })
788
789 var retries atomic.Int32
790 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
791 if retries.Add(1) != 1 {
792 return nil, errors.New("too many retries")
793 }
794 return nil, nil
795 }
796
797 pr, pw := io.Pipe()
798 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
799 if err != nil {
800 t.Logf("Get error, retrying: %v", err)
801 cst.close()
802 continue
803 }
804 defer res.Body.Close()
805 got, err := io.ReadAll(res.Body)
806 if string(got) != respBody || err != nil {
807 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
808 }
809 pw.Close()
810 break
811 }
812 }
813
814 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
815 func testServerNoReadTimeout(t *testing.T, mode testMode) {
816 reqBody := "Hello, Gophers!"
817 resBody := "Hi, Gophers!"
818 for _, timeout := range []time.Duration{0, -1} {
819 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
820 ctl := NewResponseController(res)
821 ctl.EnableFullDuplex()
822 res.WriteHeader(StatusOK)
823
824
825 if err := ctl.Flush(); err != nil {
826 t.Errorf("server flush response: %v", err)
827 return
828 }
829 got, err := io.ReadAll(req.Body)
830 if string(got) != reqBody || err != nil {
831 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
832 }
833 res.Write([]byte(resBody))
834 }), func(ts *httptest.Server) {
835 ts.Config.ReadTimeout = timeout
836 t.Logf("Server.Config.ReadTimeout = %d", timeout)
837 })
838
839 pr, pw := io.Pipe()
840 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
841 if err != nil {
842 t.Fatal(err)
843 }
844 defer res.Body.Close()
845
846
847 time.Sleep(10 * time.Millisecond)
848 pw.Write([]byte(reqBody))
849 pw.Close()
850
851 got, err := io.ReadAll(res.Body)
852 if string(got) != resBody || err != nil {
853 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
854 }
855 }
856 }
857
858 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
859 func testServerWriteTimeout(t *testing.T, mode testMode) {
860 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
861 errc := make(chan error, 2)
862 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
863 errc <- nil
864 _, err := io.Copy(res, neverEnding('a'))
865 errc <- err
866 }), func(ts *httptest.Server) {
867 ts.Config.WriteTimeout = timeout
868 t.Logf("Server.Config.WriteTimeout = %v", timeout)
869 })
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888 var retries atomic.Int32
889 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
890 if retries.Add(1) != 1 {
891 return nil, errors.New("too many retries")
892 }
893 return nil, nil
894 }
895
896 res, err := cst.c.Get(cst.ts.URL)
897 if err != nil {
898
899 t.Logf("Get error, retrying: %v", err)
900 cst.close()
901 continue
902 }
903 defer res.Body.Close()
904 _, err = io.Copy(io.Discard, res.Body)
905 if err == nil {
906 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
907 }
908 select {
909 case <-errc:
910 err = <-errc
911 if !errors.Is(err, os.ErrDeadlineExceeded) {
912 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
913 }
914 return
915 default:
916
917 t.Logf("handler didn't run, retrying")
918 cst.close()
919 }
920 }
921 }
922
923 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
924 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
925 for _, timeout := range []time.Duration{0, -1} {
926 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
927 _, err := io.Copy(res, neverEnding('a'))
928 t.Logf("server write response: %v", err)
929 }), func(ts *httptest.Server) {
930 ts.Config.WriteTimeout = timeout
931 t.Logf("Server.Config.WriteTimeout = %d", timeout)
932 })
933
934 res, err := cst.c.Get(cst.ts.URL)
935 if err != nil {
936 t.Fatal(err)
937 }
938 defer res.Body.Close()
939 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
940 if n != 1<<20 || err != nil {
941 t.Errorf("client read response body: %d, %v", n, err)
942 }
943
944
945 res.Body.Close()
946 cst.ts.Config.Shutdown(context.Background())
947 }
948 }
949
950
951 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
952 run(t, testWriteDeadlineExtendedOnNewRequest)
953 }
954 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
955 if testing.Short() {
956 t.Skip("skipping in short mode")
957 }
958 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
959 func(ts *httptest.Server) {
960 ts.Config.WriteTimeout = 250 * time.Millisecond
961 },
962 ).ts
963
964 c := ts.Client()
965
966 for i := 1; i <= 3; i++ {
967 req, err := NewRequest("GET", ts.URL, nil)
968 if err != nil {
969 t.Fatal(err)
970 }
971
972 r, err := c.Do(req)
973 if err != nil {
974 t.Fatalf("http2 Get #%d: %v", i, err)
975 }
976 r.Body.Close()
977 time.Sleep(ts.Config.WriteTimeout / 2)
978 }
979 }
980
981
982
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, stopping at
// the first call that returns nil. Each failure is logged. If every attempt
// fails, t.Fatal is called.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := 0; i < len(tries); i++ {
		timeout := tries[i]
		err := testFunc(timeout)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeout, err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
997
998
999 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1000 if testing.Short() {
1001 t.Skip("skipping in short mode")
1002 }
1003 setParallel(t)
1004 run(t, func(t *testing.T, mode testMode) {
1005 tryTimeouts(t, func(timeout time.Duration) error {
1006 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1007 })
1008 })
1009 }
1010
1011 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1012 firstRequest := make(chan bool, 1)
1013 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1014 select {
1015 case firstRequest <- true:
1016
1017 default:
1018
1019 time.Sleep(timeout)
1020 }
1021 }), func(ts *httptest.Server) {
1022 ts.Config.WriteTimeout = timeout / 2
1023 })
1024 defer cst.close()
1025 ts := cst.ts
1026
1027 c := ts.Client()
1028
1029 req, err := NewRequest("GET", ts.URL, nil)
1030 if err != nil {
1031 return fmt.Errorf("NewRequest: %v", err)
1032 }
1033 r, err := c.Do(req)
1034 if err != nil {
1035 return fmt.Errorf("Get #1: %v", err)
1036 }
1037 r.Body.Close()
1038
1039 req, err = NewRequest("GET", ts.URL, nil)
1040 if err != nil {
1041 return fmt.Errorf("NewRequest: %v", err)
1042 }
1043 r, err = c.Do(req)
1044 if err == nil {
1045 r.Body.Close()
1046 return fmt.Errorf("Get #2 expected error, got nil")
1047 }
1048 if mode == http2Mode {
1049 expected := "stream ID 3; INTERNAL_ERROR"
1050 if !strings.Contains(err.Error(), expected) {
1051 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1052 }
1053 }
1054 return nil
1055 }
1056
1057
1058 func TestNoWriteDeadline(t *testing.T) {
1059 if testing.Short() {
1060 t.Skip("skipping in short mode")
1061 }
1062 setParallel(t)
1063 defer afterTest(t)
1064 run(t, func(t *testing.T, mode testMode) {
1065 tryTimeouts(t, func(timeout time.Duration) error {
1066 return testNoWriteDeadline(t, mode, timeout)
1067 })
1068 })
1069 }
1070
1071 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1072 firstRequest := make(chan bool, 1)
1073 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1074 select {
1075 case firstRequest <- true:
1076
1077 default:
1078
1079 time.Sleep(timeout)
1080 }
1081 }))
1082 defer cst.close()
1083 ts := cst.ts
1084
1085 c := ts.Client()
1086
1087 for i := 0; i < 2; i++ {
1088 req, err := NewRequest("GET", ts.URL, nil)
1089 if err != nil {
1090 return fmt.Errorf("NewRequest: %v", err)
1091 }
1092 r, err := c.Do(req)
1093 if err != nil {
1094 return fmt.Errorf("Get #%d: %v", i, err)
1095 }
1096 r.Body.Close()
1097 }
1098 return nil
1099 }
1100
1101
1102
1103
1104 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1105 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1106 var (
1107 mu sync.RWMutex
1108 conn net.Conn
1109 )
1110 var afterTimeoutErrc = make(chan error, 1)
1111 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1112 buf := make([]byte, 512<<10)
1113 _, err := w.Write(buf)
1114 if err != nil {
1115 t.Errorf("handler Write error: %v", err)
1116 return
1117 }
1118 mu.RLock()
1119 defer mu.RUnlock()
1120 if conn == nil {
1121 t.Error("no established connection found")
1122 return
1123 }
1124 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1125 _, err = w.Write(buf)
1126 afterTimeoutErrc <- err
1127 }), func(ts *httptest.Server) {
1128 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1129 }).ts
1130
1131 c := ts.Client()
1132
1133 err := func() error {
1134 res, err := c.Get(ts.URL)
1135 if err != nil {
1136 return err
1137 }
1138 _, err = io.Copy(io.Discard, res.Body)
1139 res.Body.Close()
1140 return err
1141 }()
1142 if err == nil {
1143 t.Errorf("expected an error copying body from Get request")
1144 }
1145
1146 if err := <-afterTimeoutErrc; err == nil {
1147 t.Error("expected write error after timeout")
1148 }
1149 }
1150
1151
// trackLastConnListener wraps a net.Listener and stores the most recently
// accepted connection in *last, writing it under mu.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex
	last *net.Conn
}

// Accept accepts from the wrapped listener. On success it records the new
// connection in *l.last while holding the write lock. On failure *l.last is
// left untouched.
func (l trackLastConnListener) Accept() (c net.Conn, err error) {
	if c, err = l.Listener.Accept(); err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1168
1169
1170 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1171 func testIdentityResponse(t *testing.T, mode testMode) {
1172 if mode == http2Mode {
1173 t.Skip("https://go.dev/issue/56019")
1174 }
1175
1176 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1177 rw.Header().Set("Content-Length", "3")
1178 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1179 switch {
1180 case req.FormValue("overwrite") == "1":
1181 _, err := rw.Write([]byte("foo TOO LONG"))
1182 if err != ErrContentLength {
1183 t.Errorf("expected ErrContentLength; got %v", err)
1184 }
1185 case req.FormValue("underwrite") == "1":
1186 rw.Header().Set("Content-Length", "500")
1187 rw.Write([]byte("too short"))
1188 default:
1189 rw.Write([]byte("foo"))
1190 }
1191 })
1192
1193 ts := newClientServerTest(t, mode, handler).ts
1194 c := ts.Client()
1195
1196
1197
1198
1199
1200 for _, te := range []string{"", "identity"} {
1201 url := ts.URL + "/?te=" + te
1202 res, err := c.Get(url)
1203 if err != nil {
1204 t.Fatalf("error with Get of %s: %v", url, err)
1205 }
1206 if cl, expected := res.ContentLength, int64(3); cl != expected {
1207 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1208 }
1209 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1210 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1211 }
1212 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1213 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1214 url, expected, tl, res.TransferEncoding)
1215 }
1216 res.Body.Close()
1217 }
1218
1219
1220 url := ts.URL + "/?overwrite=1"
1221 res, err := c.Get(url)
1222 if err != nil {
1223 t.Fatalf("error with Get of %s: %v", url, err)
1224 }
1225 res.Body.Close()
1226
1227 if mode != http1Mode {
1228 return
1229 }
1230
1231
1232
1233 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1234 if err != nil {
1235 t.Fatalf("error dialing: %v", err)
1236 }
1237 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1238 if err != nil {
1239 t.Fatalf("error writing: %v", err)
1240 }
1241
1242
1243 got, _ := io.ReadAll(conn)
1244 expectedSuffix := "\r\n\r\ntoo short"
1245 if !strings.HasSuffix(string(got), expectedSuffix) {
1246 t.Errorf("Expected output to end with %q; got response body %q",
1247 expectedSuffix, string(got))
1248 }
1249 }
1250
1251 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1252 setParallel(t)
1253 s := newClientServerTest(t, http1Mode, h).ts
1254
1255 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1256 if err != nil {
1257 t.Fatal("dial error:", err)
1258 }
1259 defer conn.Close()
1260
1261 _, err = fmt.Fprint(conn, req)
1262 if err != nil {
1263 t.Fatal("print error:", err)
1264 }
1265
1266 r := bufio.NewReader(conn)
1267 res, err := ReadResponse(r, &Request{Method: "GET"})
1268 if err != nil {
1269 t.Fatal("ReadResponse error:", err)
1270 }
1271
1272 _, err = io.ReadAll(r)
1273 if err != nil {
1274 t.Fatal("read error:", err)
1275 }
1276
1277 if !res.Close {
1278 t.Errorf("Response.Close = false; want true")
1279 }
1280 }
1281
1282 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1283 setParallel(t)
1284 ts := newClientServerTest(t, http1Mode, handler).ts
1285 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1286 if err != nil {
1287 t.Fatal(err)
1288 }
1289 defer conn.Close()
1290 br := bufio.NewReader(conn)
1291 for i := 0; i < 2; i++ {
1292 if _, err := io.WriteString(conn, req); err != nil {
1293 t.Fatal(err)
1294 }
1295 res, err := ReadResponse(br, nil)
1296 if err != nil {
1297 t.Fatalf("res %d: %v", i+1, err)
1298 }
1299 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1300 t.Fatalf("res %d body copy: %v", i+1, err)
1301 }
1302 res.Body.Close()
1303 }
1304 }
1305
1306
1307 func TestServeHTTP10Close(t *testing.T) {
1308 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1309 ServeFile(w, r, "testdata/file")
1310 }))
1311 }
1312
1313
1314 func TestClientCanClose(t *testing.T) {
1315 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1316
1317 }))
1318 }
1319
1320
1321
1322 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1323 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1324 w.Header().Set("Connection", "close")
1325 }))
1326 }
1327
1328 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1329 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1330 w.Header().Set("Connection", "close")
1331 }))
1332 }
1333
1334 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1335 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1336
1337
1338 }))
1339 }
1340
// send204 replies with 204 No Content and writes no body.
func send204(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) }

// send304 replies with 304 Not Modified and writes no body.
func send304(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) }
1343
1344
1345 func TestHTTP10KeepAlive204Response(t *testing.T) {
1346 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1347 }
1348
1349 func TestHTTP11KeepAlive204Response(t *testing.T) {
1350 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1351 }
1352
1353 func TestHTTP10KeepAlive304Response(t *testing.T) {
1354 testTCPConnectionStaysOpen(t,
1355 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1356 HandlerFunc(send304))
1357 }
1358
1359
1360 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1361 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1362 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1363 w.(Flusher).Flush()
1364 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1365 }))
1366 type data struct {
1367 Addr string
1368 }
1369 var addrs [2]data
1370 for i := range addrs {
1371 res, err := cst.c.Get(cst.ts.URL)
1372 if err != nil {
1373 t.Fatal(err)
1374 }
1375 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1376 t.Fatal(err)
1377 }
1378 if addrs[i].Addr == "" {
1379 t.Fatal("no address")
1380 }
1381 res.Body.Close()
1382 }
1383 if addrs[0] != addrs[1] {
1384 t.Fatalf("connection not reused")
1385 }
1386 }
1387
1388 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1389 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1390 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1391 fmt.Fprintf(w, "%s", r.RemoteAddr)
1392 }))
1393
1394 res, err := cst.c.Get(cst.ts.URL)
1395 if err != nil {
1396 t.Fatalf("Get error: %v", err)
1397 }
1398 body, err := io.ReadAll(res.Body)
1399 if err != nil {
1400 t.Fatalf("ReadAll error: %v", err)
1401 }
1402 ip := string(body)
1403 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1404 t.Fatalf("Expected local addr; got %q", ip)
1405 }
1406 }
1407
// blockingRemoteAddrListener wraps a net.Listener. Each accepted connection is
// returned as a *blockingRemoteAddrConn and is also sent on conns, so a test
// can later supply that connection's remote address.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn // receives every connection returned by Accept
}
1412
1413 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1414 c, err := l.Listener.Accept()
1415 if err != nil {
1416 return nil, err
1417 }
1418 brac := &blockingRemoteAddrConn{
1419 Conn: c,
1420 addrs: make(chan net.Addr, 1),
1421 }
1422 l.conns <- brac
1423 return brac, nil
1424 }
1425
// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until a value
// is sent on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr // each RemoteAddr call consumes one address from here
}
1430
1431 func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
1432 return <-c.addrs
1433 }
1434
1435
1436 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1437 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1438 }
1439 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1440 conns := make(chan net.Conn)
1441 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1442 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1443 }), func(ts *httptest.Server) {
1444 ts.Listener = &blockingRemoteAddrListener{
1445 Listener: ts.Listener,
1446 conns: conns,
1447 }
1448 }).ts
1449
1450 c := ts.Client()
1451
1452 c.Transport.(*Transport).DisableKeepAlives = true
1453
1454 fetch := func(num int, response chan<- string) {
1455 resp, err := c.Get(ts.URL)
1456 if err != nil {
1457 t.Errorf("Request %d: %v", num, err)
1458 response <- ""
1459 return
1460 }
1461 defer resp.Body.Close()
1462 body, err := io.ReadAll(resp.Body)
1463 if err != nil {
1464 t.Errorf("Request %d: %v", num, err)
1465 response <- ""
1466 return
1467 }
1468 response <- string(body)
1469 }
1470
1471
1472 response1c := make(chan string, 1)
1473 go fetch(1, response1c)
1474
1475
1476 conn1 := <-conns
1477
1478
1479 response2c := make(chan string, 1)
1480 go fetch(2, response2c)
1481 conn2 := <-conns
1482
1483
1484 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1485 IP: net.ParseIP("12.12.12.12"), Port: 12}
1486
1487
1488 response2 := <-response2c
1489 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1490 t.Fatalf("response 2 addr = %q; want %q", g, e)
1491 }
1492
1493
1494 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1495 IP: net.ParseIP("21.21.21.21"), Port: 21}
1496
1497
1498 response1 := <-response1c
1499 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1500 t.Fatalf("response 1 addr = %q; want %q", g, e)
1501 }
1502 }
1503
1504
1505
1506 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1507 func testHeadResponses(t *testing.T, mode testMode) {
1508 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1509 _, err := w.Write([]byte("<html>"))
1510 if err != nil {
1511 t.Errorf("ResponseWriter.Write: %v", err)
1512 }
1513
1514
1515 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1516 if err != nil {
1517 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1518 }
1519 }))
1520 res, err := cst.c.Head(cst.ts.URL)
1521 if err != nil {
1522 t.Error(err)
1523 }
1524 if len(res.TransferEncoding) > 0 {
1525 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1526 }
1527 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1528 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1529 }
1530 if v := res.ContentLength; v != 10 {
1531 t.Errorf("Content-Length: %d; want 10", v)
1532 }
1533 body, err := io.ReadAll(res.Body)
1534 if err != nil {
1535 t.Error(err)
1536 }
1537 if len(body) > 0 {
1538 t.Errorf("got unexpected body %q", string(body))
1539 }
1540 }
1541
1542
1543
1544 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1545 func testHeadReaderFrom(t *testing.T, mode testMode) {
1546
1547 wantBody := strings.Repeat("a", 4096)
1548 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1549 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1550 }))
1551 res, err := cst.c.Head(cst.ts.URL)
1552 if err != nil {
1553 t.Fatal(err)
1554 }
1555 res.Body.Close()
1556 res, err = cst.c.Get(cst.ts.URL)
1557 if err != nil {
1558 t.Fatal(err)
1559 }
1560 gotBody, err := io.ReadAll(res.Body)
1561 res.Body.Close()
1562 if err != nil {
1563 t.Fatal(err)
1564 }
1565 if string(gotBody) != wantBody {
1566 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1567 }
1568 }
1569
1570 func TestTLSHandshakeTimeout(t *testing.T) {
1571 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1572 }
1573 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1574 errLog := new(strings.Builder)
1575 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1576 func(ts *httptest.Server) {
1577 ts.Config.ReadTimeout = 250 * time.Millisecond
1578 ts.Config.ErrorLog = log.New(errLog, "", 0)
1579 },
1580 )
1581 ts := cst.ts
1582
1583 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1584 if err != nil {
1585 t.Fatalf("Dial: %v", err)
1586 }
1587 var buf [1]byte
1588 n, err := conn.Read(buf[:])
1589 if err == nil || n != 0 {
1590 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1591 }
1592 conn.Close()
1593
1594 cst.close()
1595 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1596 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1597 }
1598 }
1599
1600 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1601 func testTLSServer(t *testing.T, mode testMode) {
1602 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1603 if r.TLS != nil {
1604 w.Header().Set("X-TLS-Set", "true")
1605 if r.TLS.HandshakeComplete {
1606 w.Header().Set("X-TLS-HandshakeComplete", "true")
1607 }
1608 }
1609 }), func(ts *httptest.Server) {
1610 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1611 }).ts
1612
1613
1614
1615
1616
1617
1618 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1619 if err != nil {
1620 t.Fatalf("Dial: %v", err)
1621 }
1622 defer idleConn.Close()
1623
1624 if !strings.HasPrefix(ts.URL, "https://") {
1625 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1626 return
1627 }
1628 client := ts.Client()
1629 res, err := client.Get(ts.URL)
1630 if err != nil {
1631 t.Error(err)
1632 return
1633 }
1634 if res == nil {
1635 t.Errorf("got nil Response")
1636 return
1637 }
1638 defer res.Body.Close()
1639 if res.Header.Get("X-TLS-Set") != "true" {
1640 t.Errorf("expected X-TLS-Set response header")
1641 return
1642 }
1643 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1644 t.Errorf("expected X-TLS-HandshakeComplete header")
1645 }
1646 }
1647
1648 func TestServeTLS(t *testing.T) {
1649 CondSkipHTTP2(t)
1650
1651 defer afterTest(t)
1652 defer SetTestHookServerServe(nil)
1653
1654 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1655 if err != nil {
1656 t.Fatal(err)
1657 }
1658 tlsConf := &tls.Config{
1659 Certificates: []tls.Certificate{cert},
1660 }
1661
1662 ln := newLocalListener(t)
1663 defer ln.Close()
1664 addr := ln.Addr().String()
1665
1666 serving := make(chan bool, 1)
1667 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1668 serving <- true
1669 })
1670 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1671 s := &Server{
1672 Addr: addr,
1673 TLSConfig: tlsConf,
1674 Handler: handler,
1675 }
1676 errc := make(chan error, 1)
1677 go func() { errc <- s.ServeTLS(ln, "", "") }()
1678 select {
1679 case err := <-errc:
1680 t.Fatalf("ServeTLS: %v", err)
1681 case <-serving:
1682 }
1683
1684 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1685 InsecureSkipVerify: true,
1686 NextProtos: []string{"h2", "http/1.1"},
1687 })
1688 if err != nil {
1689 t.Fatal(err)
1690 }
1691 defer c.Close()
1692 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1693 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1694 }
1695 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1696 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1697 }
1698 }
1699
1700
1701 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1702 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1703 }
1704 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1705 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1706 t.Error("unexpected HTTPS request")
1707 }), func(ts *httptest.Server) {
1708 var errBuf bytes.Buffer
1709 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1710 }).ts
1711 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1712 if err != nil {
1713 t.Fatal(err)
1714 }
1715 defer conn.Close()
1716 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1717 slurp, err := io.ReadAll(conn)
1718 if err != nil {
1719 t.Fatal(err)
1720 }
1721 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1722 if !strings.HasPrefix(string(slurp), wantPrefix) {
1723 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1724 }
1725 }
1726
1727
1728 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1729 testAutomaticHTTP2_Serve(t, nil, true)
1730 }
1731
1732 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1733 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1734 }
1735
1736 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1737 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1738 }
1739
1740 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1741 setParallel(t)
1742 defer afterTest(t)
1743 ln := newLocalListener(t)
1744 ln.Close()
1745 var s Server
1746 s.TLSConfig = tlsConf
1747 if err := s.Serve(ln); err == nil {
1748 t.Fatal("expected an error")
1749 }
1750 gotH2 := s.TLSNextProto["h2"] != nil
1751 if gotH2 != wantH2 {
1752 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1753 }
1754 }
1755
1756 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1757 setParallel(t)
1758 defer afterTest(t)
1759 ln := newLocalListener(t)
1760 ln.Close()
1761 var s Server
1762
1763
1764 s.TLSConfig = &tls.Config{
1765 NextProtos: []string{"h2"},
1766 }
1767 if err := s.Serve(ln); err == nil {
1768 t.Fatal("expected an error")
1769 }
1770 on := s.TLSNextProto["h2"] != nil
1771 if !on {
1772 t.Errorf("http2 wasn't automatically enabled")
1773 }
1774 }
1775
1776 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1777 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1778 if err != nil {
1779 t.Fatal(err)
1780 }
1781 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1782 Certificates: []tls.Certificate{cert},
1783 })
1784 }
1785
1786 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1787 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1788 if err != nil {
1789 t.Fatal(err)
1790 }
1791 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1792 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1793 return &cert, nil
1794 },
1795 })
1796 }
1797
1798 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1799 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1800 if err != nil {
1801 t.Fatal(err)
1802 }
1803 conf := &tls.Config{
1804
1805
1806 NextProtos: []string{"h2"},
1807 Certificates: []tls.Certificate{cert},
1808 }
1809 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1810 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1811 return conf, nil
1812 },
1813 })
1814 }
1815
1816 func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
1817 CondSkipHTTP2(t)
1818
1819 defer afterTest(t)
1820 defer SetTestHookServerServe(nil)
1821 var ok bool
1822 var s *Server
1823 const maxTries = 5
1824 var ln net.Listener
1825 Try:
1826 for try := 0; try < maxTries; try++ {
1827 ln = newLocalListener(t)
1828 addr := ln.Addr().String()
1829 ln.Close()
1830 t.Logf("Got %v", addr)
1831 lnc := make(chan net.Listener, 1)
1832 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1833 lnc <- ln
1834 })
1835 s = &Server{
1836 Addr: addr,
1837 TLSConfig: tlsConf,
1838 }
1839 errc := make(chan error, 1)
1840 go func() { errc <- s.ListenAndServeTLS("", "") }()
1841 select {
1842 case err := <-errc:
1843 t.Logf("On try #%v: %v", try+1, err)
1844 continue
1845 case ln = <-lnc:
1846 ok = true
1847 t.Logf("Listening on %v", ln.Addr().String())
1848 break Try
1849 }
1850 }
1851 if !ok {
1852 t.Fatalf("Failed to start up after %d tries", maxTries)
1853 }
1854 defer ln.Close()
1855 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1856 InsecureSkipVerify: true,
1857 NextProtos: []string{"h2", "http/1.1"},
1858 })
1859 if err != nil {
1860 t.Fatal(err)
1861 }
1862 defer c.Close()
1863 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1864 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1865 }
1866 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1867 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1868 }
1869 }
1870
// serverExpectTest is one case for testServerExpect.
type serverExpectTest struct {
	contentLength    int    // body size; sent as Content-Length unless chunked
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the request's Expect header
	readBody         bool   // whether the handler reads the body (sent as ?readbody=)
	expectedResponse string // text that must appear in the first response line
}
1878
1879 func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
1880 return serverExpectTest{
1881 contentLength: contentLength,
1882 expectation: expectation,
1883 readBody: readBody,
1884 expectedResponse: expectedResponse,
1885 }
1886 }
1887
// serverExpectTests are the cases run by testServerExpect. Each case gives the
// Expect header value, whether the handler reads the body, and the status
// text expected in the first response line.
var serverExpectTests = []serverExpectTest{
	// Expect: 100-continue, and the handler reads the body: the first
	// line is an interim "100 Continue". The expectation value is matched
	// case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// Empty Expect value: the client sends the body up front, and the
	// handler's 200 is the first response line.
	expectTest(100, "", true, "200 OK"),

	// Expect: 100-continue, but the handler never reads the body: the
	// final 401 is sent without a preceding "100 Continue".
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// The same handler behavior with an empty Expect value.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unsupported expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// With a zero Content-Length, "100-continue" yields no interim
	// response; the first line is the handler's final status.
	expectTest(0, "100-continue", true, "200 OK"),
	expectTest(0, "100-continue", false, "401 Unauthorized"),

	// A chunked request with 100-continue still gets "100 Continue" when
	// the handler reads the body.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1917
1918
1919
// TestServerExpect runs every serverExpectTests case against an HTTP/1 test
// server over a raw TCP connection and checks the first response line.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The client uses the "readbody=true" query parameter to say
		// whether this handler should consume the request body. If so,
		// the handler reads it and replies with "Hi". Otherwise it replies
		// 401 without touching the body.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest performs one case on its own connection.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send a body only when there is one (non-zero contentLength) and
		// the client is not waiting for "100 Continue" (compared
		// case-insensitively). "100-continue" cases never send their body.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		wg := sync.WaitGroup{}
		wg.Add(1)
		// Deferred after conn.Close, so it runs first: the writer goroutine
		// is waited for while the connection is still open.
		defer wg.Wait()

		// Write the request headers (and body, if any) concurrently with
		// reading the response below.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// For a plain body, io.NopCloser(nil) supplies a Close that
				// does nothing. For a chunked body, the chunked writer's
				// Close writes the final zero-length chunk.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler does not read the body, so the server
						// may close the connection before the body is fully
						// sent. A write error is tolerated here.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// As with the body write above: the server may drop the
				// connection while an unread body is still being sent. That
				// can surface here as a read error instead of a status line.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2015
2016
2017
2018 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2019 setParallel(t)
2020 defer afterTest(t)
2021 conn := new(testConn)
2022 body := strings.Repeat("x", 100<<10)
2023 conn.readBuf.Write([]byte(fmt.Sprintf(
2024 "POST / HTTP/1.1\r\n"+
2025 "Host: test\r\n"+
2026 "Content-Length: %d\r\n"+
2027 "\r\n", len(body))))
2028 conn.readBuf.Write([]byte(body))
2029
2030 done := make(chan bool)
2031
2032 readBufLen := func() int {
2033 conn.readMu.Lock()
2034 defer conn.readMu.Unlock()
2035 return conn.readBuf.Len()
2036 }
2037
2038 ls := &oneConnListener{conn}
2039 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2040 defer close(done)
2041 if bufLen := readBufLen(); bufLen < len(body)/2 {
2042 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2043 }
2044 rw.WriteHeader(200)
2045 rw.(Flusher).Flush()
2046 if g, e := readBufLen(), 0; g != e {
2047 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2048 }
2049 if c := rw.Header().Get("Connection"); c != "" {
2050 t.Errorf(`Connection header = %q; want ""`, c)
2051 }
2052 }))
2053 <-done
2054 }
2055
2056
2057
2058
2059 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2060 setParallel(t)
2061 if testing.Short() && testenv.Builder() == "" {
2062 t.Log("skipping in short mode")
2063 }
2064 conn := new(testConn)
2065 body := strings.Repeat("x", 1<<20)
2066 conn.readBuf.Write([]byte(fmt.Sprintf(
2067 "POST / HTTP/1.1\r\n"+
2068 "Host: test\r\n"+
2069 "Content-Length: %d\r\n"+
2070 "\r\n", len(body))))
2071 conn.readBuf.Write([]byte(body))
2072 conn.closec = make(chan bool, 1)
2073
2074 ls := &oneConnListener{conn}
2075 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2076 if conn.readBuf.Len() < len(body)/2 {
2077 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2078 }
2079 rw.WriteHeader(200)
2080 rw.(Flusher).Flush()
2081 if conn.readBuf.Len() < len(body)/2 {
2082 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2083 }
2084 }))
2085 <-conn.closec
2086
2087 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2088 t.Errorf("Expected a Connection: close header; got response: %s", res)
2089 }
2090 }
2091
// handlerBodyCloseTest describes one request scenario for
// testHandlerBodyClose, together with the expected outcome when the handler
// closes the request body without reading it.
type handlerBodyCloseTest struct {
	bodySize     int  // size of the POST body in bytes
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // add "Connection: close" to the POST and send no follow-up GET

	wantEOFSearch bool // expect Body.Close to consume input from the connection
	wantNextReq   bool // expect the follow-up GET on the same connection to be served
}
2100
2101 func (t handlerBodyCloseTest) connectionHeader() string {
2102 if t.reqConnClose {
2103 return "Connection: close\r\n"
2104 }
2105 return ""
2106 }
2107
// handlerBodyCloseTests are the cases run by testHandlerBodyClose. A 20KB
// body is "small" and a 1MB body is "large". NOTE(review): the different
// outcomes for large bodies presumably reflect a server-side limit on how
// much unread body it will consume — confirm.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small Content-Length body, keep-alive: Close consumes the rest of
	// the body and the follow-up GET is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body, keep-alive: same outcome as case 0.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small Content-Length body with "Connection: close": Close consumes
	// nothing, and no further request is expected.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with "Connection: close": Close still consumes
	// input, but no further request is expected.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large Content-Length body, keep-alive: Close consumes nothing and the
	// follow-up GET is not served.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Large chunked body, keep-alive: Close consumes some input, but the
	// follow-up GET is not served.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large chunked body with "Connection: close": Close consumes some
	// input, and no further request is expected.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Large Content-Length body with "Connection: close": Close consumes
	// nothing, and no further request is expected.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2192
2193 func TestHandlerBodyClose(t *testing.T) {
2194 setParallel(t)
2195 if testing.Short() && testenv.Builder() == "" {
2196 t.Skip("skipping in -short mode")
2197 }
2198 for i, tt := range handlerBodyCloseTests {
2199 testHandlerBodyClose(t, i, tt)
2200 }
2201 }
2202
2203 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2204 conn := new(testConn)
2205 body := strings.Repeat("x", tt.bodySize)
2206 if tt.bodyChunked {
2207 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2208 "Host: test\r\n" +
2209 tt.connectionHeader() +
2210 "Transfer-Encoding: chunked\r\n" +
2211 "\r\n")
2212 cw := internal.NewChunkedWriter(&conn.readBuf)
2213 io.WriteString(cw, body)
2214 cw.Close()
2215 conn.readBuf.WriteString("\r\n")
2216 } else {
2217 conn.readBuf.Write([]byte(fmt.Sprintf(
2218 "POST / HTTP/1.1\r\n"+
2219 "Host: test\r\n"+
2220 tt.connectionHeader()+
2221 "Content-Length: %d\r\n"+
2222 "\r\n", len(body))))
2223 conn.readBuf.Write([]byte(body))
2224 }
2225 if !tt.reqConnClose {
2226 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2227 }
2228 conn.closec = make(chan bool, 1)
2229
2230 readBufLen := func() int {
2231 conn.readMu.Lock()
2232 defer conn.readMu.Unlock()
2233 return conn.readBuf.Len()
2234 }
2235
2236 ls := &oneConnListener{conn}
2237 var numReqs int
2238 var size0, size1 int
2239 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2240 numReqs++
2241 if numReqs == 1 {
2242 size0 = readBufLen()
2243 req.Body.Close()
2244 size1 = readBufLen()
2245 }
2246 }))
2247 <-conn.closec
2248 if numReqs < 1 || numReqs > 2 {
2249 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2250 }
2251 didSearch := size0 != size1
2252 if didSearch != tt.wantEOFSearch {
2253 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2254 }
2255 if tt.wantNextReq && numReqs != 2 {
2256 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2257 }
2258 }
2259
2260
2261
// testHandlerBodyConsumer is a named strategy that a test handler applies to
// the request body it receives.
type testHandlerBodyConsumer struct {
	name string              // label used in failure messages
	f    func(io.ReadCloser) // called with the request body
}
2266
2267 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2268 {"nil", func(io.ReadCloser) {}},
2269 {"close", func(r io.ReadCloser) { r.Close() }},
2270 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2271 }
2272
2273 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2274 setParallel(t)
2275 defer afterTest(t)
2276 for _, handler := range testHandlerBodyConsumers {
2277 conn := new(testConn)
2278 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2279 "Host: test\r\n" +
2280 "Transfer-Encoding: chunked\r\n" +
2281 "\r\n" +
2282 "hax\r\n" +
2283 "GET /secret HTTP/1.1\r\n" +
2284 "Host: test\r\n" +
2285 "\r\n")
2286
2287 conn.closec = make(chan bool, 1)
2288 ls := &oneConnListener{conn}
2289 var numReqs int
2290 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2291 numReqs++
2292 if strings.Contains(req.URL.Path, "secret") {
2293 t.Error("Request for /secret encountered, should not have happened.")
2294 }
2295 handler.f(req.Body)
2296 }))
2297 <-conn.closec
2298 if numReqs != 1 {
2299 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2300 }
2301 }
2302 }
2303
2304 func TestInvalidTrailerClosesConnection(t *testing.T) {
2305 setParallel(t)
2306 defer afterTest(t)
2307 for _, handler := range testHandlerBodyConsumers {
2308 conn := new(testConn)
2309 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2310 "Host: test\r\n" +
2311 "Trailer: hack\r\n" +
2312 "Transfer-Encoding: chunked\r\n" +
2313 "\r\n" +
2314 "3\r\n" +
2315 "hax\r\n" +
2316 "0\r\n" +
2317 "I'm not a valid trailer\r\n" +
2318 "GET /secret HTTP/1.1\r\n" +
2319 "Host: test\r\n" +
2320 "\r\n")
2321
2322 conn.closec = make(chan bool, 1)
2323 ln := &oneConnListener{conn}
2324 var numReqs int
2325 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2326 numReqs++
2327 if strings.Contains(req.URL.Path, "secret") {
2328 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2329 }
2330 handler.f(req.Body)
2331 }))
2332 <-conn.closec
2333 if numReqs != 1 {
2334 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2335 }
2336 }
2337 }
2338
2339
2340
2341
// slowTestConn is a fake net.Conn whose reads follow a script and whose
// writes are discarded. Read and Write return syscall.ETIMEDOUT once their
// respective deadline has passed.
type slowTestConn struct {
	// script is consumed by Read. Its elements must be strings (bytes to
	// return) or time.Durations (pauses); any other type makes Read panic.
	script []any
	closec chan bool // Close sends true here without blocking

	mu     sync.Mutex // held by Read and by the deadline setters; Write reads wd without it
	rd, wd time.Time  // read and write deadlines; the zero Time means none
	noopConn
}
2351
2352 func (c *slowTestConn) SetDeadline(t time.Time) error {
2353 c.SetReadDeadline(t)
2354 c.SetWriteDeadline(t)
2355 return nil
2356 }
2357
2358 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2359 c.mu.Lock()
2360 defer c.mu.Unlock()
2361 c.rd = t
2362 return nil
2363 }
2364
2365 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2366 c.mu.Lock()
2367 defer c.mu.Unlock()
2368 c.wd = t
2369 return nil
2370 }
2371
2372 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2373 c.mu.Lock()
2374 defer c.mu.Unlock()
2375 restart:
2376 if !c.rd.IsZero() && time.Now().After(c.rd) {
2377 return 0, syscall.ETIMEDOUT
2378 }
2379 if len(c.script) == 0 {
2380 return 0, io.EOF
2381 }
2382
2383 switch cue := c.script[0].(type) {
2384 case time.Duration:
2385 if !c.rd.IsZero() {
2386
2387
2388 if remaining := time.Until(c.rd); remaining < cue {
2389 c.script[0] = cue - remaining
2390 time.Sleep(remaining)
2391 return 0, syscall.ETIMEDOUT
2392 }
2393 }
2394 c.script = c.script[1:]
2395 time.Sleep(cue)
2396 goto restart
2397
2398 case string:
2399 n = copy(b, cue)
2400
2401 if len(cue) > n {
2402 c.script[0] = cue[n:]
2403 } else {
2404 c.script = c.script[1:]
2405 }
2406
2407 default:
2408 panic("unknown cue in slowTestConn script")
2409 }
2410
2411 return
2412 }
2413
2414 func (c *slowTestConn) Close() error {
2415 select {
2416 case c.closec <- true:
2417 default:
2418 }
2419 return nil
2420 }
2421
2422 func (c *slowTestConn) Write(b []byte) (int, error) {
2423 if !c.wd.IsZero() && time.Now().After(c.wd) {
2424 return 0, syscall.ETIMEDOUT
2425 }
2426 return len(b), nil
2427 }
2428
2429 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2430 if testing.Short() {
2431 t.Skip("skipping in -short mode")
2432 }
2433 defer afterTest(t)
2434 for _, handler := range testHandlerBodyConsumers {
2435 conn := &slowTestConn{
2436 script: []any{
2437 "POST /public HTTP/1.1\r\n" +
2438 "Host: test\r\n" +
2439 "Content-Length: 10000\r\n" +
2440 "\r\n",
2441 "foo bar baz",
2442 600 * time.Millisecond,
2443 "GET /secret HTTP/1.1\r\n" +
2444 "Host: test\r\n" +
2445 "\r\n",
2446 },
2447 closec: make(chan bool, 1),
2448 }
2449 ls := &oneConnListener{conn}
2450
2451 var numReqs int
2452 s := Server{
2453 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2454 numReqs++
2455 if strings.Contains(req.URL.Path, "secret") {
2456 t.Error("Request for /secret encountered, should not have happened.")
2457 }
2458 handler.f(req.Body)
2459 }),
2460 ReadTimeout: 400 * time.Millisecond,
2461 }
2462 go s.Serve(ls)
2463 <-conn.closec
2464
2465 if numReqs != 1 {
2466 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2467 }
2468 }
2469 }
2470
2471
// cancelableTimeoutContext wraps a Context so that cancellation looks like a
// deadline expiry. Err returns context.DeadlineExceeded whenever the wrapped
// Context reports any error, and nil otherwise. The tests here cancel the
// wrapped context to trigger a TimeoutHandler timeout on demand.
type cancelableTimeoutContext struct {
	context.Context
}

// Err maps any error from the wrapped Context to context.DeadlineExceeded.
func (c cancelableTimeoutContext) Err() error {
	switch c.Context.Err() {
	case nil:
		return nil
	default:
		return context.DeadlineExceeded
	}
}
2482
2483 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2484 func testTimeoutHandler(t *testing.T, mode testMode) {
2485 sendHi := make(chan bool, 1)
2486 writeErrors := make(chan error, 1)
2487 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2488 <-sendHi
2489 _, werr := w.Write([]byte("hi"))
2490 writeErrors <- werr
2491 })
2492 ctx, cancel := context.WithCancel(context.Background())
2493 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2494 cst := newClientServerTest(t, mode, h)
2495
2496
2497 sendHi <- true
2498 res, err := cst.c.Get(cst.ts.URL)
2499 if err != nil {
2500 t.Error(err)
2501 }
2502 if g, e := res.StatusCode, StatusOK; g != e {
2503 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2504 }
2505 body, _ := io.ReadAll(res.Body)
2506 if g, e := string(body), "hi"; g != e {
2507 t.Errorf("got body %q; expected %q", g, e)
2508 }
2509 if g := <-writeErrors; g != nil {
2510 t.Errorf("got unexpected Write error on first request: %v", g)
2511 }
2512
2513
2514 cancel()
2515
2516 res, err = cst.c.Get(cst.ts.URL)
2517 if err != nil {
2518 t.Error(err)
2519 }
2520 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2521 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2522 }
2523 body, _ = io.ReadAll(res.Body)
2524 if !strings.Contains(string(body), "<title>Timeout</title>") {
2525 t.Errorf("expected timeout body; got %q", string(body))
2526 }
2527 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2528 t.Errorf("response content-type = %q; want %q", g, w)
2529 }
2530
2531
2532
2533 sendHi <- true
2534 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2535 t.Errorf("expected Write error of %v; got %v", e, g)
2536 }
2537 }
2538
2539
2540 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2541 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2542 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2543 ms, _ := strconv.Atoi(r.URL.Path[1:])
2544 if ms == 0 {
2545 ms = 1
2546 }
2547 for i := 0; i < ms; i++ {
2548 w.Write([]byte("hi"))
2549 time.Sleep(time.Millisecond)
2550 }
2551 })
2552
2553 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2554
2555 c := ts.Client()
2556
2557 var wg sync.WaitGroup
2558 gate := make(chan bool, 10)
2559 n := 50
2560 if testing.Short() {
2561 n = 10
2562 gate = make(chan bool, 3)
2563 }
2564 for i := 0; i < n; i++ {
2565 gate <- true
2566 wg.Add(1)
2567 go func() {
2568 defer wg.Done()
2569 defer func() { <-gate }()
2570 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2571 if err == nil {
2572 io.Copy(io.Discard, res.Body)
2573 res.Body.Close()
2574 }
2575 }()
2576 }
2577 wg.Wait()
2578 }
2579
2580
2581
2582 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2583 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2584 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2585 w.WriteHeader(204)
2586 })
2587
2588 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2589
2590 var wg sync.WaitGroup
2591 gate := make(chan bool, 50)
2592 n := 500
2593 if testing.Short() {
2594 n = 10
2595 }
2596
2597 c := ts.Client()
2598 for i := 0; i < n; i++ {
2599 gate <- true
2600 wg.Add(1)
2601 go func() {
2602 defer wg.Done()
2603 defer func() { <-gate }()
2604 res, err := c.Get(ts.URL)
2605 if err != nil {
2606
2607
2608 t.Log(err)
2609 return
2610 }
2611 defer res.Body.Close()
2612 io.Copy(io.Discard, res.Body)
2613 }()
2614 }
2615 wg.Wait()
2616 }
2617
2618
2619 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2620 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2621 sendHi := make(chan bool, 1)
2622 writeErrors := make(chan error, 1)
2623 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2624 w.Header().Set("Content-Type", "text/plain")
2625 <-sendHi
2626 _, werr := w.Write([]byte("hi"))
2627 writeErrors <- werr
2628 })
2629 ctx, cancel := context.WithCancel(context.Background())
2630 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2631 cst := newClientServerTest(t, mode, h)
2632
2633
2634 sendHi <- true
2635 res, err := cst.c.Get(cst.ts.URL)
2636 if err != nil {
2637 t.Error(err)
2638 }
2639 if g, e := res.StatusCode, StatusOK; g != e {
2640 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2641 }
2642 body, _ := io.ReadAll(res.Body)
2643 if g, e := string(body), "hi"; g != e {
2644 t.Errorf("got body %q; expected %q", g, e)
2645 }
2646 if g := <-writeErrors; g != nil {
2647 t.Errorf("got unexpected Write error on first request: %v", g)
2648 }
2649
2650
2651 cancel()
2652
2653 res, err = cst.c.Get(cst.ts.URL)
2654 if err != nil {
2655 t.Error(err)
2656 }
2657 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2658 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2659 }
2660 body, _ = io.ReadAll(res.Body)
2661 if !strings.Contains(string(body), "<title>Timeout</title>") {
2662 t.Errorf("expected timeout body; got %q", string(body))
2663 }
2664
2665
2666
2667 sendHi <- true
2668 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2669 t.Errorf("expected Write error of %v; got %v", e, g)
2670 }
2671 }
2672
2673
2674 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2675 run(t, testTimeoutHandlerStartTimerWhenServing)
2676 }
2677 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2678 if testing.Short() {
2679 t.Skip("skipping sleeping test in -short mode")
2680 }
2681 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2682 w.WriteHeader(StatusNoContent)
2683 }
2684 timeout := 300 * time.Millisecond
2685 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2686 defer ts.Close()
2687
2688 c := ts.Client()
2689
2690
2691
2692
2693 time.Sleep(2 * timeout)
2694 res, err := c.Get(ts.URL)
2695 if err != nil {
2696 t.Fatal(err)
2697 }
2698 defer res.Body.Close()
2699 if res.StatusCode != StatusNoContent {
2700 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2701 }
2702 }
2703
2704 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2705 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2706 writeErrors := make(chan error, 1)
2707 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2708 w.Header().Set("Content-Type", "text/plain")
2709 var err error
2710
2711
2712
2713 for i := 0; i < 100; i++ {
2714 _, err = w.Write([]byte("a"))
2715 if err != nil {
2716 break
2717 }
2718 time.Sleep(1 * time.Millisecond)
2719 }
2720 writeErrors <- err
2721 })
2722 ctx, cancel := context.WithCancel(context.Background())
2723 cancel()
2724 h := NewTestTimeoutHandler(sayHi, ctx)
2725 cst := newClientServerTest(t, mode, h)
2726 defer cst.close()
2727
2728 res, err := cst.c.Get(cst.ts.URL)
2729 if err != nil {
2730 t.Error(err)
2731 }
2732 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2733 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2734 }
2735 body, _ := io.ReadAll(res.Body)
2736 if g, e := string(body), ""; g != e {
2737 t.Errorf("got body %q; expected %q", g, e)
2738 }
2739 if g, e := <-writeErrors, context.Canceled; g != e {
2740 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2741 }
2742 }
2743
2744
2745 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2746 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2747 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2748
2749 }
2750 timeout := 300 * time.Millisecond
2751 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2752
2753 c := ts.Client()
2754
2755 res, err := c.Get(ts.URL)
2756 if err != nil {
2757 t.Fatal(err)
2758 }
2759 defer res.Body.Close()
2760 if res.StatusCode != StatusOK {
2761 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2762 }
2763 }
2764
2765
2766 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2767 wrapper := func(h Handler) Handler {
2768 return TimeoutHandler(h, time.Second, "")
2769 }
2770 run(t, func(t *testing.T, mode testMode) {
2771 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2772 }, testNotParallel)
2773 }
2774
// TestRedirectBadPath checks that Redirect copes with a request whose URL path
// does not start with a slash and still writes the requested status code.
func TestRedirectBadPath(t *testing.T) {
	req := &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2791
2792
// TestRedirect checks the Location header that Redirect produces for a range
// of targets, relative to a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		target, location string
	}{
		// Absolute URLs pass through unchanged, whatever the scheme.
		{"http://foobar.com/baz", "http://foobar.com/baz"},
		{"https://foobar.com/baz", "https://foobar.com/baz"},
		{"test://foobar.com/baz", "test://foobar.com/baz"},
		// A scheme-relative URL is kept.
		{"//foobar.com/baz", "//foobar.com/baz"},
		// An absolute path is kept.
		{"/foobar.com/baz", "/foobar.com/baz"},
		// Relative paths resolve against the request path /qux/.
		{"foobar.com/baz", "/qux/foobar.com/baz"},
		{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
		// A leading triple slash is reduced to a single one.
		{"///foobar.com/baz", "/foobar.com/baz"},
		// URLs inside the query string are left alone.
		{"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
		{"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			"http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
		// Non-ASCII path bytes are percent-encoded.
		{"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.target, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.target, got, want)
		}
		if got := rec.Header().Get("Location"); got != tc.location {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.target, got, tc.location)
		}
	}
}
2837
2838
2839
// TestRedirectContentTypeAndBody checks the Content-Type header and body that
// Redirect writes for various request methods and pre-set Content-Type
// values. Per the table below, only GET and HEAD with no caller-set
// Content-Type get the default HTML type, and only GET gets the HTML link body.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	tests := []struct {
		method   string
		ct       *ctHeader // nil: leave Content-Type untouched before Redirect
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range tests {
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		Redirect(rec, req, "/foo", 302)

		if got := rec.Code; got != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, got, 302)
		}
		if got := rec.Header().Get("Content-Type"); got != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, tc.wantBody)
		}
	}
}
2883
2884
2885
2886
2887
2888
2889
2890 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2891
2892 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2893 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2894 all, err := io.ReadAll(r.Body)
2895 if err != nil {
2896 t.Fatalf("handler ReadAll: %v", err)
2897 }
2898 if len(all) != 0 {
2899 t.Errorf("handler got %d bytes; expected 0", len(all))
2900 }
2901 rw.Header().Set("Content-Length", "0")
2902 }))
2903
2904 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2905 if err != nil {
2906 t.Fatal(err)
2907 }
2908 req.ContentLength = 0
2909
2910 var resp [5]*Response
2911 for i := range resp {
2912 resp[i], err = cst.c.Do(req)
2913 if err != nil {
2914 t.Fatalf("client post #%d: %v", i, err)
2915 }
2916 }
2917
2918 for i := range resp {
2919 all, err := io.ReadAll(resp[i].Body)
2920 if err != nil {
2921 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2922 }
2923 if len(all) != 0 {
2924 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2925 }
2926 }
2927 }
2928
2929 func TestHandlerPanicNil(t *testing.T) {
2930 run(t, func(t *testing.T, mode testMode) {
2931 testHandlerPanic(t, false, mode, nil, nil)
2932 }, testNotParallel)
2933 }
2934
2935 func TestHandlerPanic(t *testing.T) {
2936 run(t, func(t *testing.T, mode testMode) {
2937 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2938 }, testNotParallel)
2939 }
2940
2941 func TestHandlerPanicWithHijack(t *testing.T) {
2942
2943 run(t, func(t *testing.T, mode testMode) {
2944 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2945 }, []testMode{http1Mode})
2946 }
2947
2948 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
2949
2950
2951
2952
2953
2954
2955
2956
2957 pr, pw := io.Pipe()
2958 defer pw.Close()
2959
2960 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
2961 if withHijack {
2962 rwc, _, err := w.(Hijacker).Hijack()
2963 if err != nil {
2964 t.Logf("unexpected error: %v", err)
2965 }
2966 defer rwc.Close()
2967 }
2968 panic(panicValue)
2969 })
2970 if wrapper != nil {
2971 handler = wrapper(handler)
2972 }
2973 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
2974 ts.Config.ErrorLog = log.New(pw, "", 0)
2975 })
2976
2977
2978 done := make(chan bool, 1)
2979 go func() {
2980 buf := make([]byte, 4<<10)
2981 _, err := pr.Read(buf)
2982 pr.Close()
2983 if err != nil && err != io.EOF {
2984 t.Error(err)
2985 }
2986 done <- true
2987 }()
2988
2989 _, err := cst.c.Get(cst.ts.URL)
2990 if err == nil {
2991 t.Logf("expected an error")
2992 }
2993
2994 if panicValue == nil {
2995 return
2996 }
2997
2998 <-done
2999 }
3000
// terrorWriter is an io.Writer that reports every write as an error on t.
// It is installed as a server ErrorLog destination where no log output is
// expected.
type terrorWriter struct{ t *testing.T }

// Write fails the test with p as the message and reports p as fully written.
func (tw terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	tw.t.Errorf("%s", p)
	return n, nil
}
3007
3008
3009
3010 func TestServerWriteHijackZeroBytes(t *testing.T) {
3011 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3012 }
3013 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3014 done := make(chan struct{})
3015 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3016 defer close(done)
3017 w.(Flusher).Flush()
3018 conn, _, err := w.(Hijacker).Hijack()
3019 if err != nil {
3020 t.Errorf("Hijack: %v", err)
3021 return
3022 }
3023 defer conn.Close()
3024 _, err = w.Write(nil)
3025 if err != ErrHijacked {
3026 t.Errorf("Write error = %v; want ErrHijacked", err)
3027 }
3028 }), func(ts *httptest.Server) {
3029 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3030 }).ts
3031
3032 c := ts.Client()
3033 res, err := c.Get(ts.URL)
3034 if err != nil {
3035 t.Fatal(err)
3036 }
3037 res.Body.Close()
3038 <-done
3039 }
3040
3041 func TestServerNoDate(t *testing.T) {
3042 run(t, func(t *testing.T, mode testMode) {
3043 testServerNoHeader(t, mode, "Date")
3044 })
3045 }
3046
3047 func TestServerContentType(t *testing.T) {
3048 run(t, func(t *testing.T, mode testMode) {
3049 testServerNoHeader(t, mode, "Content-Type")
3050 })
3051 }
3052
3053 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3054 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3055 w.Header()[header] = nil
3056 io.WriteString(w, "<html>foo</html>")
3057 }))
3058 res, err := cst.c.Get(cst.ts.URL)
3059 if err != nil {
3060 t.Fatal(err)
3061 }
3062 res.Body.Close()
3063 if got, ok := res.Header[header]; ok {
3064 t.Fatalf("Expected no %s header; got %q", header, got)
3065 }
3066 }
3067
3068 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3069 func testStripPrefix(t *testing.T, mode testMode) {
3070 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3071 w.Header().Set("X-Path", r.URL.Path)
3072 w.Header().Set("X-RawPath", r.URL.RawPath)
3073 })
3074 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3075
3076 c := ts.Client()
3077
3078 cases := []struct {
3079 reqPath string
3080 path string
3081 rawPath string
3082 }{
3083 {"/foo/bar/qux", "/qux", ""},
3084 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3085 {"/foo%2Fbar/qux", "", ""},
3086 {"/bar", "", ""},
3087 }
3088 for _, tc := range cases {
3089 t.Run(tc.reqPath, func(t *testing.T) {
3090 res, err := c.Get(ts.URL + tc.reqPath)
3091 if err != nil {
3092 t.Fatal(err)
3093 }
3094 res.Body.Close()
3095 if tc.path == "" {
3096 if res.StatusCode != StatusNotFound {
3097 t.Errorf("got %q, want 404 Not Found", res.Status)
3098 }
3099 return
3100 }
3101 if res.StatusCode != StatusOK {
3102 t.Fatalf("got %q, want 200 OK", res.Status)
3103 }
3104 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3105 t.Errorf("got Path %q, want %q", g, w)
3106 }
3107 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3108 t.Errorf("got RawPath %q, want %q", g, w)
3109 }
3110 })
3111 }
3112 }
3113
3114
// TestStripPrefixNotModifyRequest checks that serving a request through
// StripPrefix leaves the caller's Request.URL.Path unchanged.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	const path = "/foo/bar"
	req := httptest.NewRequest("GET", path, nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if req.URL.Path != path {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3123
3124 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3125 func testRequestLimit(t *testing.T, mode testMode) {
3126 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3127 t.Fatalf("didn't expect to get request in Handler")
3128 }), optQuietLog)
3129 req, _ := NewRequest("GET", cst.ts.URL, nil)
3130 var bytesPerHeader = len("header12345: val12345\r\n")
3131 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3132 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3133 }
3134 res, err := cst.c.Do(req)
3135 if res != nil {
3136 defer res.Body.Close()
3137 }
3138 if mode == http2Mode {
3139
3140
3141
3142
3143 if err == nil && res.StatusCode != 431 {
3144 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3145 }
3146 } else {
3147
3148
3149
3150
3151 if err != nil {
3152 t.Fatalf("Do: %v", err)
3153 }
3154 if res.StatusCode != 431 {
3155 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3156 }
3157 }
3158 }
3159
// neverEnding is an endless io.Reader. Every Read fills all of p with the
// byte value of the receiver and never returns an error.
type neverEnding byte

func (b neverEnding) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Seed the first byte, then keep doubling the filled prefix with copy.
	p[0] = byte(b)
	for filled := 1; filled < len(p); filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return len(p), nil
}
3168
// bodyLimitReader is an io.ReadCloser that returns an endless stream of 'a'
// bytes. Once count exceeds limit, Read fails with "at limit"; after Close,
// it fails with "closed". count tracks the bytes handed out so far. mu guards
// count and serializes Read with Close. Close must be called at most once.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int
	limit  int
	closed chan struct{}
}

func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	var err error
	select {
	case <-r.closed:
		err = errors.New("closed")
	default:
		if r.count > r.limit {
			err = errors.New("at limit")
		}
	}
	if err != nil {
		return 0, err
	}

	for i := range p {
		p[i] = 'a'
	}
	r.count += len(p)
	return len(p), nil
}

// Close marks the reader closed so that subsequent Reads fail.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()
	close(r.closed)
	return nil
}
3200
3201 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3202 func testRequestBodyLimit(t *testing.T, mode testMode) {
3203 const limit = 1 << 20
3204 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3205 r.Body = MaxBytesReader(w, r.Body, limit)
3206 n, err := io.Copy(io.Discard, r.Body)
3207 if err == nil {
3208 t.Errorf("expected error from io.Copy")
3209 }
3210 if n != limit {
3211 t.Errorf("io.Copy = %d, want %d", n, limit)
3212 }
3213 mbErr, ok := err.(*MaxBytesError)
3214 if !ok {
3215 t.Errorf("expected MaxBytesError, got %T", err)
3216 }
3217 if mbErr.Limit != limit {
3218 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3219 }
3220 }))
3221
3222 body := &bodyLimitReader{
3223 closed: make(chan struct{}),
3224 limit: limit * 200,
3225 }
3226 req, _ := NewRequest("POST", cst.ts.URL, body)
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237 resp, err := cst.c.Do(req)
3238 if err == nil {
3239 resp.Body.Close()
3240 }
3241
3242
3243 <-body.closed
3244
3245 if body.count > limit*100 {
3246 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3247 limit, body.count)
3248 }
3249 }
3250
3251
3252
3253 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3254 func testClientWriteShutdown(t *testing.T, mode testMode) {
3255 if runtime.GOOS == "plan9" {
3256 t.Skip("skipping test; see https://golang.org/issue/17906")
3257 }
3258 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3259 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3260 if err != nil {
3261 t.Fatalf("Dial: %v", err)
3262 }
3263 err = conn.(*net.TCPConn).CloseWrite()
3264 if err != nil {
3265 t.Fatalf("CloseWrite: %v", err)
3266 }
3267
3268 bs, err := io.ReadAll(conn)
3269 if err != nil {
3270 t.Errorf("ReadAll: %v", err)
3271 }
3272 got := string(bs)
3273 if got != "" {
3274 t.Errorf("read %q from server; want nothing", got)
3275 }
3276 }
3277
3278
3279
3280 func TestServerBufferedChunking(t *testing.T) {
3281 conn := new(testConn)
3282 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3283 conn.closec = make(chan bool, 1)
3284 ls := &oneConnListener{conn}
3285 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3286 rw.(Flusher).Flush()
3287 rw.Write([]byte{'x'})
3288 rw.Write([]byte{'y'})
3289 rw.Write([]byte{'z'})
3290 }))
3291 <-conn.closec
3292 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3293 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3294 conn.writeBuf.Bytes())
3295 }
3296 }
3297
3298
3299
3300
3301
3302 func TestServerGracefulClose(t *testing.T) {
3303
3304 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3305 }
3306 func testServerGracefulClose(t *testing.T, mode testMode) {
3307 runTimeSensitiveTest(t, []time.Duration{
3308 1 * time.Millisecond,
3309 5 * time.Millisecond,
3310 10 * time.Millisecond,
3311 50 * time.Millisecond,
3312 100 * time.Millisecond,
3313 500 * time.Millisecond,
3314 time.Second,
3315 5 * time.Second,
3316 }, func(t *testing.T, timeout time.Duration) error {
3317 SetRSTAvoidanceDelay(t, timeout)
3318 t.Logf("set RST avoidance delay to %v", timeout)
3319
3320 const bodySize = 5 << 20
3321 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3322 for i := 0; i < bodySize; i++ {
3323 req = append(req, 'x')
3324 }
3325
3326 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3327 Error(w, "bye", StatusUnauthorized)
3328 }))
3329
3330
3331 defer cst.close()
3332 ts := cst.ts
3333
3334 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3335 if err != nil {
3336 return err
3337 }
3338 writeErr := make(chan error)
3339 go func() {
3340 _, err := conn.Write(req)
3341 writeErr <- err
3342 }()
3343 defer func() {
3344 conn.Close()
3345
3346
3347
3348 <-writeErr
3349 }()
3350
3351 br := bufio.NewReader(conn)
3352 lineNum := 0
3353 for {
3354 line, err := br.ReadString('\n')
3355 if err == io.EOF {
3356 break
3357 }
3358 if err != nil {
3359 return fmt.Errorf("ReadLine: %v", err)
3360 }
3361 lineNum++
3362 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3363 t.Errorf("Response line = %q; want a 401", line)
3364 }
3365 }
3366 return nil
3367 })
3368 }
3369
3370 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3371 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3372 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3373 if r.Method != "get" {
3374 t.Errorf(`Got method %q; want "get"`, r.Method)
3375 }
3376 }))
3377 defer cst.close()
3378 req, _ := NewRequest("get", cst.ts.URL, nil)
3379 res, err := cst.c.Do(req)
3380 if err != nil {
3381 t.Error(err)
3382 return
3383 }
3384
3385 res.Body.Close()
3386 }
3387
3388
3389
3390
3391
3392 func TestContentLengthZero(t *testing.T) {
3393 run(t, testContentLengthZero, []testMode{http1Mode})
3394 }
3395 func testContentLengthZero(t *testing.T, mode testMode) {
3396 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3397
3398 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3399 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3400 if err != nil {
3401 t.Fatalf("error dialing: %v", err)
3402 }
3403 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3404 if err != nil {
3405 t.Fatalf("error writing: %v", err)
3406 }
3407 req, _ := NewRequest("GET", "/", nil)
3408 res, err := ReadResponse(bufio.NewReader(conn), req)
3409 if err != nil {
3410 t.Fatalf("error reading response: %v", err)
3411 }
3412 if te := res.TransferEncoding; len(te) > 0 {
3413 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3414 }
3415 if cl := res.ContentLength; cl != 0 {
3416 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3417 }
3418 conn.Close()
3419 }
3420 }
3421
3422 func TestCloseNotifier(t *testing.T) {
3423 run(t, testCloseNotifier, []testMode{http1Mode})
3424 }
3425 func testCloseNotifier(t *testing.T, mode testMode) {
3426 gotReq := make(chan bool, 1)
3427 sawClose := make(chan bool, 1)
3428 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3429 gotReq <- true
3430 cc := rw.(CloseNotifier).CloseNotify()
3431 <-cc
3432 sawClose <- true
3433 })).ts
3434 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3435 if err != nil {
3436 t.Fatalf("error dialing: %v", err)
3437 }
3438 diec := make(chan bool)
3439 go func() {
3440 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3441 if err != nil {
3442 t.Error(err)
3443 return
3444 }
3445 <-diec
3446 conn.Close()
3447 }()
3448 For:
3449 for {
3450 select {
3451 case <-gotReq:
3452 diec <- true
3453 case <-sawClose:
3454 break For
3455 }
3456 }
3457 ts.Close()
3458 }
3459
3460
3461
3462
3463
3464 func TestCloseNotifierPipelined(t *testing.T) {
3465 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3466 }
3467 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3468 gotReq := make(chan bool, 2)
3469 sawClose := make(chan bool, 2)
3470 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3471 gotReq <- true
3472 cc := rw.(CloseNotifier).CloseNotify()
3473 select {
3474 case <-cc:
3475 t.Error("unexpected CloseNotify")
3476 case <-time.After(100 * time.Millisecond):
3477 }
3478 sawClose <- true
3479 })).ts
3480 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3481 if err != nil {
3482 t.Fatalf("error dialing: %v", err)
3483 }
3484 diec := make(chan bool, 1)
3485 defer close(diec)
3486 go func() {
3487 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3488 _, err = io.WriteString(conn, req+req)
3489 if err != nil {
3490 t.Error(err)
3491 return
3492 }
3493 <-diec
3494 conn.Close()
3495 }()
3496 reqs := 0
3497 closes := 0
3498 for {
3499 select {
3500 case <-gotReq:
3501 reqs++
3502 if reqs > 2 {
3503 t.Fatal("too many requests")
3504 }
3505 case <-sawClose:
3506 closes++
3507 if closes > 1 {
3508 return
3509 }
3510 }
3511 }
3512 }
3513
3514 func TestCloseNotifierChanLeak(t *testing.T) {
3515 defer afterTest(t)
3516 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3517 for i := 0; i < 20; i++ {
3518 var output bytes.Buffer
3519 conn := &rwTestConn{
3520 Reader: bytes.NewReader(req),
3521 Writer: &output,
3522 closec: make(chan bool, 1),
3523 }
3524 ln := &oneConnListener{conn: conn}
3525 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3526
3527
3528
3529 _ = rw.(CloseNotifier).CloseNotify()
3530 })
3531 go Serve(ln, handler)
3532 <-conn.closec
3533 }
3534 }
3535
3536
3537
3538
3539
3540
3541
3542
3543
3544
3545 func TestHijackAfterCloseNotifier(t *testing.T) {
3546 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3547 }
3548 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3549 script := make(chan string, 2)
3550 script <- "closenotify"
3551 script <- "hijack"
3552 close(script)
3553 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3554 plan := <-script
3555 switch plan {
3556 default:
3557 panic("bogus plan; too many requests")
3558 case "closenotify":
3559 w.(CloseNotifier).CloseNotify()
3560 w.Header().Set("X-Addr", r.RemoteAddr)
3561 case "hijack":
3562 c, _, err := w.(Hijacker).Hijack()
3563 if err != nil {
3564 t.Errorf("Hijack in Handler: %v", err)
3565 return
3566 }
3567 if _, ok := c.(*net.TCPConn); !ok {
3568
3569
3570 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3571 }
3572 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3573 c.Close()
3574 return
3575 }
3576 })).ts
3577 res1, err := ts.Client().Get(ts.URL)
3578 if err != nil {
3579 log.Fatal(err)
3580 }
3581 res2, err := ts.Client().Get(ts.URL)
3582 if err != nil {
3583 log.Fatal(err)
3584 }
3585 addr1 := res1.Header.Get("X-Addr")
3586 addr2 := res2.Header.Get("X-Addr")
3587 if addr1 == "" || addr1 != addr2 {
3588 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3589 }
3590 }
3591
3592 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3593 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3594 }
3595 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3596 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3597 bodyOkay := make(chan bool, 1)
3598 gotCloseNotify := make(chan bool, 1)
3599 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3600 defer close(bodyOkay)
3601
3602 reqBody := r.Body
3603 r.Body = nil
3604
3605 gone := w.(CloseNotifier).CloseNotify()
3606 slurp, err := io.ReadAll(reqBody)
3607 if err != nil {
3608 t.Errorf("Body read: %v", err)
3609 return
3610 }
3611 if len(slurp) != len(requestBody) {
3612 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3613 return
3614 }
3615 if !bytes.Equal(slurp, requestBody) {
3616 t.Error("Backend read wrong request body.")
3617 return
3618 }
3619 bodyOkay <- true
3620 <-gone
3621 gotCloseNotify <- true
3622 })).ts
3623
3624 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3625 if err != nil {
3626 t.Fatal(err)
3627 }
3628 defer conn.Close()
3629
3630 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3631 len(requestBody), requestBody)
3632 if !<-bodyOkay {
3633
3634 return
3635 }
3636 conn.Close()
3637 <-gotCloseNotify
3638 }
3639
3640 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3641 func testOptions(t *testing.T, mode testMode) {
3642 uric := make(chan string, 2)
3643 mux := NewServeMux()
3644 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3645 uric <- r.RequestURI
3646 })
3647 ts := newClientServerTest(t, mode, mux).ts
3648
3649 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3650 if err != nil {
3651 t.Fatal(err)
3652 }
3653 defer conn.Close()
3654
3655
3656 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3657 if err != nil {
3658 t.Fatal(err)
3659 }
3660 br := bufio.NewReader(conn)
3661 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3662 if err != nil {
3663 t.Fatal(err)
3664 }
3665 if res.StatusCode != 200 {
3666 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3667 }
3668
3669
3670 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3671 if err != nil {
3672 t.Fatal(err)
3673 }
3674 res, err = ReadResponse(br, &Request{Method: "GET"})
3675 if err != nil {
3676 t.Fatal(err)
3677 }
3678 if res.StatusCode != 400 {
3679 t.Errorf("Got non-400 response to GET *: %#v", res)
3680 }
3681
3682 res, err = Get(ts.URL + "/second")
3683 if err != nil {
3684 t.Fatal(err)
3685 }
3686 res.Body.Close()
3687 if got := <-uric; got != "/second" {
3688 t.Errorf("Handler saw request for %q; want /second", got)
3689 }
3690 }
3691
3692 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3693 func testOptionsHandler(t *testing.T, mode testMode) {
3694 rc := make(chan *Request, 1)
3695
3696 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3697 rc <- r
3698 }), func(ts *httptest.Server) {
3699 ts.Config.DisableGeneralOptionsHandler = true
3700 }).ts
3701
3702 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3703 if err != nil {
3704 t.Fatal(err)
3705 }
3706 defer conn.Close()
3707
3708 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3709 if err != nil {
3710 t.Fatal(err)
3711 }
3712
3713 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3714 t.Errorf("Expected OPTIONS * request, got %v", got)
3715 }
3716 }
3717
3718
3719
3720
3721
3722
3723
3724
3725
3726
3727 func TestHeaderToWire(t *testing.T) {
3728 tests := []struct {
3729 name string
3730 handler func(ResponseWriter, *Request)
3731 check func(got, logs string) error
3732 }{
3733 {
3734 name: "write without Header",
3735 handler: func(rw ResponseWriter, r *Request) {
3736 rw.Write([]byte("hello world"))
3737 },
3738 check: func(got, logs string) error {
3739 if !strings.Contains(got, "Content-Length:") {
3740 return errors.New("no content-length")
3741 }
3742 if !strings.Contains(got, "Content-Type: text/plain") {
3743 return errors.New("no content-type")
3744 }
3745 return nil
3746 },
3747 },
3748 {
3749 name: "Header mutation before write",
3750 handler: func(rw ResponseWriter, r *Request) {
3751 h := rw.Header()
3752 h.Set("Content-Type", "some/type")
3753 rw.Write([]byte("hello world"))
3754 h.Set("Too-Late", "bogus")
3755 },
3756 check: func(got, logs string) error {
3757 if !strings.Contains(got, "Content-Length:") {
3758 return errors.New("no content-length")
3759 }
3760 if !strings.Contains(got, "Content-Type: some/type") {
3761 return errors.New("wrong content-type")
3762 }
3763 if strings.Contains(got, "Too-Late") {
3764 return errors.New("don't want too-late header")
3765 }
3766 return nil
3767 },
3768 },
3769 {
3770 name: "write then useless Header mutation",
3771 handler: func(rw ResponseWriter, r *Request) {
3772 rw.Write([]byte("hello world"))
3773 rw.Header().Set("Too-Late", "Write already wrote headers")
3774 },
3775 check: func(got, logs string) error {
3776 if strings.Contains(got, "Too-Late") {
3777 return errors.New("header appeared from after WriteHeader")
3778 }
3779 return nil
3780 },
3781 },
3782 {
3783 name: "flush then write",
3784 handler: func(rw ResponseWriter, r *Request) {
3785 rw.(Flusher).Flush()
3786 rw.Write([]byte("post-flush"))
3787 rw.Header().Set("Too-Late", "Write already wrote headers")
3788 },
3789 check: func(got, logs string) error {
3790 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3791 return errors.New("not chunked")
3792 }
3793 if strings.Contains(got, "Too-Late") {
3794 return errors.New("header appeared from after WriteHeader")
3795 }
3796 return nil
3797 },
3798 },
3799 {
3800 name: "header then flush",
3801 handler: func(rw ResponseWriter, r *Request) {
3802 rw.Header().Set("Content-Type", "some/type")
3803 rw.(Flusher).Flush()
3804 rw.Write([]byte("post-flush"))
3805 rw.Header().Set("Too-Late", "Write already wrote headers")
3806 },
3807 check: func(got, logs string) error {
3808 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3809 return errors.New("not chunked")
3810 }
3811 if strings.Contains(got, "Too-Late") {
3812 return errors.New("header appeared from after WriteHeader")
3813 }
3814 if !strings.Contains(got, "Content-Type: some/type") {
3815 return errors.New("wrong content-type")
3816 }
3817 return nil
3818 },
3819 },
3820 {
3821 name: "sniff-on-first-write content-type",
3822 handler: func(rw ResponseWriter, r *Request) {
3823 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3824 rw.Header().Set("Content-Type", "x/wrong")
3825 },
3826 check: func(got, logs string) error {
3827 if !strings.Contains(got, "Content-Type: text/html") {
3828 return errors.New("wrong content-type; want html")
3829 }
3830 return nil
3831 },
3832 },
3833 {
3834 name: "explicit content-type wins",
3835 handler: func(rw ResponseWriter, r *Request) {
3836 rw.Header().Set("Content-Type", "some/type")
3837 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3838 },
3839 check: func(got, logs string) error {
3840 if !strings.Contains(got, "Content-Type: some/type") {
3841 return errors.New("wrong content-type; want html")
3842 }
3843 return nil
3844 },
3845 },
3846 {
3847 name: "empty handler",
3848 handler: func(rw ResponseWriter, r *Request) {
3849 },
3850 check: func(got, logs string) error {
3851 if !strings.Contains(got, "Content-Length: 0") {
3852 return errors.New("want 0 content-length")
3853 }
3854 return nil
3855 },
3856 },
3857 {
3858 name: "only Header, no write",
3859 handler: func(rw ResponseWriter, r *Request) {
3860 rw.Header().Set("Some-Header", "some-value")
3861 },
3862 check: func(got, logs string) error {
3863 if !strings.Contains(got, "Some-Header") {
3864 return errors.New("didn't get header")
3865 }
3866 return nil
3867 },
3868 },
3869 {
3870 name: "WriteHeader call",
3871 handler: func(rw ResponseWriter, r *Request) {
3872 rw.WriteHeader(404)
3873 rw.Header().Set("Too-Late", "some-value")
3874 },
3875 check: func(got, logs string) error {
3876 if !strings.Contains(got, "404") {
3877 return errors.New("wrong status")
3878 }
3879 if strings.Contains(got, "Too-Late") {
3880 return errors.New("shouldn't have seen Too-Late")
3881 }
3882 return nil
3883 },
3884 },
3885 }
3886 for _, tc := range tests {
3887 ht := newHandlerTest(HandlerFunc(tc.handler))
3888 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3889 logs := ht.logbuf.String()
3890 if err := tc.check(got, logs); err != nil {
3891 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3892 }
3893 }
3894 }
3895
3896 type errorListener struct {
3897 errs []error
3898 }
3899
3900 func (l *errorListener) Accept() (c net.Conn, err error) {
3901 if len(l.errs) == 0 {
3902 return nil, io.EOF
3903 }
3904 err = l.errs[0]
3905 l.errs = l.errs[1:]
3906 return
3907 }
3908
3909 func (l *errorListener) Close() error {
3910 return nil
3911 }
3912
3913 func (l *errorListener) Addr() net.Addr {
3914 return dummyAddr("test-address")
3915 }
3916
3917 func TestAcceptMaxFds(t *testing.T) {
3918 setParallel(t)
3919
3920 ln := &errorListener{[]error{
3921 &net.OpError{
3922 Op: "accept",
3923 Err: syscall.EMFILE,
3924 }}}
3925 server := &Server{
3926 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3927 ErrorLog: log.New(io.Discard, "", 0),
3928 }
3929 err := server.Serve(ln)
3930 if err != io.EOF {
3931 t.Errorf("got error %v, want EOF", err)
3932 }
3933 }
3934
3935 func TestWriteAfterHijack(t *testing.T) {
3936 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3937 var buf strings.Builder
3938 wrotec := make(chan bool, 1)
3939 conn := &rwTestConn{
3940 Reader: bytes.NewReader(req),
3941 Writer: &buf,
3942 closec: make(chan bool, 1),
3943 }
3944 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3945 conn, bufrw, err := rw.(Hijacker).Hijack()
3946 if err != nil {
3947 t.Error(err)
3948 return
3949 }
3950 go func() {
3951 bufrw.Write([]byte("[hijack-to-bufw]"))
3952 bufrw.Flush()
3953 conn.Write([]byte("[hijack-to-conn]"))
3954 conn.Close()
3955 wrotec <- true
3956 }()
3957 })
3958 ln := &oneConnListener{conn: conn}
3959 go Serve(ln, handler)
3960 <-conn.closec
3961 <-wrotec
3962 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
3963 t.Errorf("wrote %q; want %q", g, w)
3964 }
3965 }
3966
3967 func TestDoubleHijack(t *testing.T) {
3968 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3969 var buf bytes.Buffer
3970 conn := &rwTestConn{
3971 Reader: bytes.NewReader(req),
3972 Writer: &buf,
3973 closec: make(chan bool, 1),
3974 }
3975 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3976 conn, _, err := rw.(Hijacker).Hijack()
3977 if err != nil {
3978 t.Error(err)
3979 return
3980 }
3981 _, _, err = rw.(Hijacker).Hijack()
3982 if err == nil {
3983 t.Errorf("got err = nil; want err != nil")
3984 }
3985 conn.Close()
3986 })
3987 ln := &oneConnListener{conn: conn}
3988 go Serve(ln, handler)
3989 <-conn.closec
3990 }
3991
3992
3993
3994
3995
3996
3997
3998 func TestHTTP10ConnectionHeader(t *testing.T) {
3999 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4000 }
4001 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4002 mux := NewServeMux()
4003 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4004 ts := newClientServerTest(t, mode, mux).ts
4005
4006
4007 tests := []struct {
4008 req string
4009 expect []string
4010 }{
4011 {
4012 req: "GET / HTTP/1.0\r\n\r\n",
4013 expect: nil,
4014 },
4015 {
4016 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4017 expect: nil,
4018 },
4019 {
4020 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4021 expect: []string{"keep-alive"},
4022 },
4023 }
4024
4025 for _, tt := range tests {
4026 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4027 if err != nil {
4028 t.Fatal("dial err:", err)
4029 }
4030
4031 _, err = fmt.Fprint(conn, tt.req)
4032 if err != nil {
4033 t.Fatal("conn write err:", err)
4034 }
4035
4036 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4037 if err != nil {
4038 t.Fatal("ReadResponse err:", err)
4039 }
4040 conn.Close()
4041 resp.Body.Close()
4042
4043 got := resp.Header["Connection"]
4044 if !slices.Equal(got, tt.expect) {
4045 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4046 }
4047 }
4048 }
4049
4050
4051 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4052 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4053 pr, pw := io.Pipe()
4054 const size = 3 << 20
4055 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4056 rw.Header().Set("Content-Type", "text/plain")
4057 done := make(chan bool)
4058 go func() {
4059 io.Copy(rw, pr)
4060 close(done)
4061 }()
4062 time.Sleep(25 * time.Millisecond)
4063 n, err := io.Copy(io.Discard, req.Body)
4064 if err != nil {
4065 t.Errorf("handler Copy: %v", err)
4066 return
4067 }
4068 if n != size {
4069 t.Errorf("handler Copy = %d; want %d", n, size)
4070 }
4071 pw.Write([]byte("hi"))
4072 pw.Close()
4073 <-done
4074 }))
4075
4076 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4077 if err != nil {
4078 t.Fatal(err)
4079 }
4080 res, err := cst.c.Do(req)
4081 if err != nil {
4082 t.Fatal(err)
4083 }
4084 all, err := io.ReadAll(res.Body)
4085 if err != nil {
4086 t.Fatal(err)
4087 }
4088 res.Body.Close()
4089 if string(all) != "hi" {
4090 t.Errorf("Body = %q; want hi", all)
4091 }
4092 }
4093
4094
4095 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4096 for _, code := range []int{StatusNotModified, StatusNoContent} {
4097 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4098 if r.URL.Path == "/header" {
4099 w.Header().Set("Content-Length", "123")
4100 }
4101 w.WriteHeader(code)
4102 if r.URL.Path == "/more" {
4103 w.Write([]byte("stuff"))
4104 }
4105 }))
4106 for _, req := range []string{
4107 "GET / HTTP/1.0",
4108 "GET /header HTTP/1.0",
4109 "GET /more HTTP/1.0",
4110 "GET / HTTP/1.1\nHost: foo",
4111 "GET /header HTTP/1.1\nHost: foo",
4112 "GET /more HTTP/1.1\nHost: foo",
4113 } {
4114 got := ht.rawResponse(req)
4115 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4116 if !strings.Contains(got, wantStatus) {
4117 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4118 } else if strings.Contains(got, "Content-Length") {
4119 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4120 } else if strings.Contains(got, "stuff") {
4121 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4122 }
4123 }
4124 }
4125 }
4126
4127 func TestContentTypeOkayOn204(t *testing.T) {
4128 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4129 w.Header().Set("Content-Length", "123")
4130 w.Header().Set("Content-Type", "foo/bar")
4131 w.WriteHeader(204)
4132 }))
4133 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4134 if !strings.Contains(got, "Content-Type: foo/bar") {
4135 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4136 }
4137 if strings.Contains(got, "Content-Length: 123") {
4138 t.Errorf("Response = %q; don't want a Content-Length", got)
4139 }
4140 }
4141
4142
4143
4144
4145
4146
4147
4148 func TestTransportAndServerSharedBodyRace(t *testing.T) {
4149 run(t, testTransportAndServerSharedBodyRace, testNotParallel)
4150 }
4151 func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
4152
4153
4154
4155
4156 runTimeSensitiveTest(t, []time.Duration{
4157 1 * time.Millisecond,
4158 5 * time.Millisecond,
4159 10 * time.Millisecond,
4160 50 * time.Millisecond,
4161 100 * time.Millisecond,
4162 500 * time.Millisecond,
4163 time.Second,
4164 5 * time.Second,
4165 }, func(t *testing.T, timeout time.Duration) error {
4166 SetRSTAvoidanceDelay(t, timeout)
4167 t.Logf("set RST avoidance delay to %v", timeout)
4168
4169 const bodySize = 1 << 20
4170
4171 var wg sync.WaitGroup
4172 backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4173
4174
4175
4176
4177
4178
4179
4180
4181 wg.Add(1)
4182 defer wg.Done()
4183
4184 n, err := io.CopyN(rw, req.Body, bodySize)
4185 t.Logf("backend CopyN: %v, %v", n, err)
4186 <-req.Context().Done()
4187 }))
4188
4189
4190 defer func() {
4191 wg.Wait()
4192 backend.close()
4193 }()
4194
4195 var proxy *clientServerTest
4196 proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4197 req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
4198 req2.ContentLength = bodySize
4199 cancel := make(chan struct{})
4200 req2.Cancel = cancel
4201
4202 bresp, err := proxy.c.Do(req2)
4203 if err != nil {
4204 t.Errorf("Proxy outbound request: %v", err)
4205 return
4206 }
4207 _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
4208 if err != nil {
4209 t.Errorf("Proxy copy error: %v", err)
4210 return
4211 }
4212 t.Cleanup(func() { bresp.Body.Close() })
4213
4214
4215
4216
4217
4218
4219 if mode == http2Mode {
4220 close(cancel)
4221 } else {
4222 proxy.c.Transport.(*Transport).CancelRequest(req2)
4223 }
4224 rw.Write([]byte("OK"))
4225 }))
4226 defer proxy.close()
4227
4228 req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
4229 res, err := proxy.c.Do(req)
4230 if err != nil {
4231 return fmt.Errorf("original request: %v", err)
4232 }
4233 res.Body.Close()
4234 return nil
4235 })
4236 }
4237
4238
4239
4240
4241 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4242 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4243 }
4244 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4245 if testing.Short() {
4246 t.Skip("skipping in -short mode")
4247 }
4248
4249 readErrCh := make(chan error, 1)
4250 errCh := make(chan error, 2)
4251
4252 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4253 go func(body io.Reader) {
4254 _, err := body.Read(make([]byte, 100))
4255 readErrCh <- err
4256 }(req.Body)
4257 time.Sleep(500 * time.Millisecond)
4258 })).ts
4259
4260 closeConn := make(chan bool)
4261 defer close(closeConn)
4262 go func() {
4263 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4264 if err != nil {
4265 errCh <- err
4266 return
4267 }
4268 defer conn.Close()
4269 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4270 if err != nil {
4271 errCh <- err
4272 return
4273 }
4274
4275
4276 <-closeConn
4277 }()
4278 select {
4279 case err := <-readErrCh:
4280 if err == nil {
4281 t.Error("Read was nil. Expected error.")
4282 }
4283 case err := <-errCh:
4284 t.Error(err)
4285 }
4286 }
4287
4288
4289 func TestResponseWriterWriteString(t *testing.T) {
4290 okc := make(chan bool, 1)
4291 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4292 _, ok := w.(io.StringWriter)
4293 okc <- ok
4294 }))
4295 ht.rawResponse("GET / HTTP/1.0")
4296 select {
4297 case ok := <-okc:
4298 if !ok {
4299 t.Error("ResponseWriter did not implement io.StringWriter")
4300 }
4301 default:
4302 t.Error("handler was never called")
4303 }
4304 }
4305
4306 func TestAppendTime(t *testing.T) {
4307 var b [len(TimeFormat)]byte
4308 t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
4309 res := ExportAppendTime(b[:0], t1)
4310 t2, err := ParseTime(string(res))
4311 if err != nil {
4312 t.Fatalf("Error parsing time: %s", err)
4313 }
4314 if !t1.Equal(t2) {
4315 t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
4316 }
4317 }
4318
4319 func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
4320 func testServerConnState(t *testing.T, mode testMode) {
4321 handler := map[string]func(w ResponseWriter, r *Request){
4322 "/": func(w ResponseWriter, r *Request) {
4323 fmt.Fprintf(w, "Hello.")
4324 },
4325 "/close": func(w ResponseWriter, r *Request) {
4326 w.Header().Set("Connection", "close")
4327 fmt.Fprintf(w, "Hello.")
4328 },
4329 "/hijack": func(w ResponseWriter, r *Request) {
4330 c, _, _ := w.(Hijacker).Hijack()
4331 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4332 c.Close()
4333 },
4334 "/hijack-panic": func(w ResponseWriter, r *Request) {
4335 c, _, _ := w.(Hijacker).Hijack()
4336 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4337 c.Close()
4338 panic("intentional panic")
4339 },
4340 }
4341
4342
4343 type stateLog struct {
4344 active net.Conn
4345 got []ConnState
4346 want []ConnState
4347 complete chan<- struct{}
4348 }
4349 activeLog := make(chan *stateLog, 1)
4350
4351
4352
4353
4354 wantLog := func(doRequests func(), want ...ConnState) {
4355 t.Helper()
4356 complete := make(chan struct{})
4357 activeLog <- &stateLog{want: want, complete: complete}
4358
4359 doRequests()
4360
4361 <-complete
4362 sl := <-activeLog
4363 if !slices.Equal(sl.got, sl.want) {
4364 t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
4365 }
4366
4367
4368
4369 }
4370
4371 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4372 handler[r.URL.Path](w, r)
4373 }), func(ts *httptest.Server) {
4374 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
4375 ts.Config.ConnState = func(c net.Conn, state ConnState) {
4376 if c == nil {
4377 t.Errorf("nil conn seen in state %s", state)
4378 return
4379 }
4380 sl := <-activeLog
4381 if sl.active == nil && state == StateNew {
4382 sl.active = c
4383 } else if sl.active != c {
4384 t.Errorf("unexpected conn in state %s", state)
4385 activeLog <- sl
4386 return
4387 }
4388 sl.got = append(sl.got, state)
4389 if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
4390 close(sl.complete)
4391 sl.complete = nil
4392 }
4393 activeLog <- sl
4394 }
4395 }).ts
4396 defer func() {
4397 activeLog <- &stateLog{}
4398 ts.Close()
4399 }()
4400
4401 c := ts.Client()
4402
4403 mustGet := func(url string, headers ...string) {
4404 t.Helper()
4405 req, err := NewRequest("GET", url, nil)
4406 if err != nil {
4407 t.Fatal(err)
4408 }
4409 for len(headers) > 0 {
4410 req.Header.Add(headers[0], headers[1])
4411 headers = headers[2:]
4412 }
4413 res, err := c.Do(req)
4414 if err != nil {
4415 t.Errorf("Error fetching %s: %v", url, err)
4416 return
4417 }
4418 _, err = io.ReadAll(res.Body)
4419 defer res.Body.Close()
4420 if err != nil {
4421 t.Errorf("Error reading %s: %v", url, err)
4422 }
4423 }
4424
4425 wantLog(func() {
4426 mustGet(ts.URL + "/")
4427 mustGet(ts.URL + "/close")
4428 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4429
4430 wantLog(func() {
4431 mustGet(ts.URL + "/")
4432 mustGet(ts.URL+"/", "Connection", "close")
4433 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4434
4435 wantLog(func() {
4436 mustGet(ts.URL + "/hijack")
4437 }, StateNew, StateActive, StateHijacked)
4438
4439 wantLog(func() {
4440 mustGet(ts.URL + "/hijack-panic")
4441 }, StateNew, StateActive, StateHijacked)
4442
4443 wantLog(func() {
4444 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4445 if err != nil {
4446 t.Fatal(err)
4447 }
4448 c.Close()
4449 }, StateNew, StateClosed)
4450
4451 wantLog(func() {
4452 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4453 if err != nil {
4454 t.Fatal(err)
4455 }
4456 if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
4457 t.Fatal(err)
4458 }
4459 c.Read(make([]byte, 1))
4460 c.Close()
4461 }, StateNew, StateActive, StateClosed)
4462
4463 wantLog(func() {
4464 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4465 if err != nil {
4466 t.Fatal(err)
4467 }
4468 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4469 t.Fatal(err)
4470 }
4471 res, err := ReadResponse(bufio.NewReader(c), nil)
4472 if err != nil {
4473 t.Fatal(err)
4474 }
4475 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4476 t.Fatal(err)
4477 }
4478 c.Close()
4479 }, StateNew, StateActive, StateIdle, StateClosed)
4480 }
4481
4482 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4483 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4484 }
4485 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4486 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4487 }), func(ts *httptest.Server) {
4488 ts.Config.SetKeepAlivesEnabled(false)
4489 }).ts
4490 res, err := ts.Client().Get(ts.URL)
4491 if err != nil {
4492 t.Fatal(err)
4493 }
4494 defer res.Body.Close()
4495 if !res.Close {
4496 t.Errorf("Body.Close == false; want true")
4497 }
4498 }
4499
4500
4501 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4502 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4503 var n int32
4504 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4505 atomic.AddInt32(&n, 1)
4506 }), optQuietLog)
4507 var wg sync.WaitGroup
4508 const reqs = 20
4509 for i := 0; i < reqs; i++ {
4510 wg.Add(1)
4511 go func() {
4512 defer wg.Done()
4513 res, err := cst.c.Get(cst.ts.URL)
4514 if err != nil {
4515
4516
4517 time.Sleep(10 * time.Millisecond)
4518 res, err = cst.c.Get(cst.ts.URL)
4519 if err != nil {
4520 t.Error(err)
4521 return
4522 }
4523 }
4524 defer res.Body.Close()
4525 _, err = io.Copy(io.Discard, res.Body)
4526 if err != nil {
4527 t.Error(err)
4528 return
4529 }
4530 }()
4531 }
4532 wg.Wait()
4533 if got := atomic.LoadInt32(&n); got != reqs {
4534 t.Errorf("handler ran %d times; want %d", got, reqs)
4535 }
4536 }
4537
4538 func TestServerConnStateNew(t *testing.T) {
4539 sawNew := false
4540 srv := &Server{
4541 ConnState: func(c net.Conn, state ConnState) {
4542 if state == StateNew {
4543 sawNew = true
4544 }
4545 },
4546 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4547 }
4548 srv.Serve(&oneConnListener{
4549 conn: &rwTestConn{
4550 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4551 Writer: io.Discard,
4552 },
4553 })
4554 if !sawNew {
4555 t.Error("StateNew not seen")
4556 }
4557 }
4558
4559 type closeWriteTestConn struct {
4560 rwTestConn
4561 didCloseWrite bool
4562 }
4563
4564 func (c *closeWriteTestConn) CloseWrite() error {
4565 c.didCloseWrite = true
4566 return nil
4567 }
4568
4569 func TestCloseWrite(t *testing.T) {
4570 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4571
4572 var srv Server
4573 var testConn closeWriteTestConn
4574 c := ExportServerNewConn(&srv, &testConn)
4575 ExportCloseWriteAndWait(c)
4576 if !testConn.didCloseWrite {
4577 t.Error("didn't see CloseWrite call")
4578 }
4579 }
4580
4581
4582
4583
4584
4585
4586
4587
4588 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4589 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4590 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4591 io.WriteString(w, "Hello, ")
4592 w.(Flusher).Flush()
4593 conn, buf, _ := w.(Hijacker).Hijack()
4594 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4595 if err := buf.Flush(); err != nil {
4596 t.Error(err)
4597 }
4598 if err := conn.Close(); err != nil {
4599 t.Error(err)
4600 }
4601 })).ts
4602 res, err := Get(ts.URL)
4603 if err != nil {
4604 t.Fatal(err)
4605 }
4606 defer res.Body.Close()
4607 all, err := io.ReadAll(res.Body)
4608 if err != nil {
4609 t.Fatal(err)
4610 }
4611 if want := "Hello, world!"; string(all) != want {
4612 t.Errorf("Got %q; want %q", all, want)
4613 }
4614 }
4615
4616
4617
4618
4619
4620
4621
4622 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4623 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4624 }
4625 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4626 if testing.Short() {
4627 t.Skip("skipping in -short mode")
4628 }
4629 const numReq = 3
4630 addrc := make(chan string, numReq)
4631 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4632 addrc <- r.RemoteAddr
4633 time.Sleep(500 * time.Millisecond)
4634 w.(Flusher).Flush()
4635 }), func(ts *httptest.Server) {
4636 ts.Config.WriteTimeout = 250 * time.Millisecond
4637 }).ts
4638
4639 errc := make(chan error, numReq)
4640 go func() {
4641 defer close(errc)
4642 for i := 0; i < numReq; i++ {
4643 res, err := Get(ts.URL)
4644 if res != nil {
4645 res.Body.Close()
4646 }
4647 errc <- err
4648 }
4649 }()
4650
4651 addrSeen := map[string]bool{}
4652 numOkay := 0
4653 for {
4654 select {
4655 case v := <-addrc:
4656 addrSeen[v] = true
4657 case err, ok := <-errc:
4658 if !ok {
4659 if len(addrSeen) != numReq {
4660 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4661 }
4662 if numOkay != 0 {
4663 t.Errorf("got %d successful client requests; want 0", numOkay)
4664 }
4665 return
4666 }
4667 if err == nil {
4668 numOkay++
4669 }
4670 }
4671 }
4672 }
4673
4674
4675
4676 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4677 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4678 }
4679 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4680 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4681 w.Header().Set("Transfer-Encoding", "foo")
4682 io.WriteString(w, "<html>")
4683 })).ts
4684 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4685 if err != nil {
4686 t.Fatalf("Dial: %v", err)
4687 }
4688 defer c.Close()
4689 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4690 t.Fatal(err)
4691 }
4692 bs := bufio.NewScanner(c)
4693 var got strings.Builder
4694 for bs.Scan() {
4695 if strings.TrimSpace(bs.Text()) == "" {
4696 break
4697 }
4698 got.WriteString(bs.Text())
4699 got.WriteByte('\n')
4700 }
4701 if err := bs.Err(); err != nil {
4702 t.Fatal(err)
4703 }
4704 if strings.Contains(got.String(), "Content-Length") {
4705 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4706 }
4707 if strings.Contains(got.String(), "Content-Type") {
4708 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4709 }
4710 }
4711
4712
4713
4714 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4715 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4716 "\r\n\r\n" +
4717 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4718 var buf bytes.Buffer
4719 conn := &rwTestConn{
4720 Reader: bytes.NewReader(req),
4721 Writer: &buf,
4722 closec: make(chan bool, 1),
4723 }
4724 ln := &oneConnListener{conn: conn}
4725 numReq := 0
4726 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4727 numReq++
4728 }))
4729 <-conn.closec
4730 if numReq != 2 {
4731 t.Errorf("num requests = %d; want 2", numReq)
4732 t.Logf("Res: %s", buf.Bytes())
4733 }
4734 }
4735
4736 func TestIssue13893_Expect100(t *testing.T) {
4737
4738 req := reqBytes(`PUT /readbody HTTP/1.1
4739 User-Agent: PycURL/7.22.0
4740 Host: 127.0.0.1:9000
4741 Accept: */*
4742 Expect: 100-continue
4743 Content-Length: 10
4744
4745 HelloWorld
4746
4747 `)
4748 var buf bytes.Buffer
4749 conn := &rwTestConn{
4750 Reader: bytes.NewReader(req),
4751 Writer: &buf,
4752 closec: make(chan bool, 1),
4753 }
4754 ln := &oneConnListener{conn: conn}
4755 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4756 if _, ok := r.Header["Expect"]; !ok {
4757 t.Error("Expect header should not be filtered out")
4758 }
4759 }))
4760 <-conn.closec
4761 }
4762
4763 func TestIssue11549_Expect100(t *testing.T) {
4764 req := reqBytes(`PUT /readbody HTTP/1.1
4765 User-Agent: PycURL/7.22.0
4766 Host: 127.0.0.1:9000
4767 Accept: */*
4768 Expect: 100-continue
4769 Content-Length: 10
4770
4771 HelloWorldPUT /noreadbody HTTP/1.1
4772 User-Agent: PycURL/7.22.0
4773 Host: 127.0.0.1:9000
4774 Accept: */*
4775 Expect: 100-continue
4776 Content-Length: 10
4777
4778 GET /should-be-ignored HTTP/1.1
4779 Host: foo
4780
4781 `)
4782 var buf strings.Builder
4783 conn := &rwTestConn{
4784 Reader: bytes.NewReader(req),
4785 Writer: &buf,
4786 closec: make(chan bool, 1),
4787 }
4788 ln := &oneConnListener{conn: conn}
4789 numReq := 0
4790 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4791 numReq++
4792 if r.URL.Path == "/readbody" {
4793 io.ReadAll(r.Body)
4794 }
4795 io.WriteString(w, "Hello world!")
4796 }))
4797 <-conn.closec
4798 if numReq != 2 {
4799 t.Errorf("num requests = %d; want 2", numReq)
4800 }
4801 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4802 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4803 }
4804 }
4805
4806
4807
4808 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4809 setParallel(t)
4810 conn := newTestConn()
4811 conn.readBuf.WriteString(
4812 "POST / HTTP/1.1\r\n" +
4813 "Host: test\r\n" +
4814 "Content-Length: 9999999999\r\n" +
4815 "\r\n" + strings.Repeat("a", 1<<20))
4816
4817 ls := &oneConnListener{conn}
4818 var inHandlerLen int
4819 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4820 inHandlerLen = conn.readBuf.Len()
4821 rw.WriteHeader(404)
4822 }))
4823 <-conn.closec
4824 afterHandlerLen := conn.readBuf.Len()
4825
4826 if afterHandlerLen != inHandlerLen {
4827 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4828 }
4829 }
4830
4831 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4832 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4833 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4834 r.Body = nil
4835 fmt.Fprintf(w, "%v", r.RemoteAddr)
4836 }))
4837 get := func() string {
4838 res, err := cst.c.Get(cst.ts.URL)
4839 if err != nil {
4840 t.Fatal(err)
4841 }
4842 defer res.Body.Close()
4843 slurp, err := io.ReadAll(res.Body)
4844 if err != nil {
4845 t.Fatal(err)
4846 }
4847 return string(slurp)
4848 }
4849 a, b := get(), get()
4850 if a != b {
4851 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4852 }
4853 }
4854
4855
4856
4857 func TestServerValidatesHostHeader(t *testing.T) {
4858 tests := []struct {
4859 proto string
4860 host string
4861 want int
4862 }{
4863 {"HTTP/0.9", "", 505},
4864
4865 {"HTTP/1.1", "", 400},
4866 {"HTTP/1.1", "Host: \r\n", 200},
4867 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4868 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4869 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4870 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4871 {"HTTP/1.1", "Host: ::1\r\n", 200},
4872 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4873 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4874 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4875 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4876 {"HTTP/1.1", "Host: \x06\r\n", 400},
4877 {"HTTP/1.1", "Host: \xff\r\n", 400},
4878 {"HTTP/1.1", "Host: {\r\n", 400},
4879 {"HTTP/1.1", "Host: }\r\n", 400},
4880 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4881
4882
4883
4884 {"HTTP/1.0", "", 200},
4885 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4886 {"HTTP/1.0", "Host: \xff\r\n", 400},
4887
4888
4889 {"PRI * HTTP/2.0", "", 200},
4890
4891
4892 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4893
4894
4895 {"PRI / HTTP/2.0", "", 505},
4896 {"GET / HTTP/2.0", "", 505},
4897 {"GET / HTTP/3.0", "", 505},
4898 }
4899 for _, tt := range tests {
4900 conn := newTestConn()
4901 methodTarget := "GET / "
4902 if !strings.HasPrefix(tt.proto, "HTTP/") {
4903 methodTarget = ""
4904 }
4905 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4906
4907 ln := &oneConnListener{conn}
4908 srv := Server{
4909 ErrorLog: quietLog,
4910 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4911 }
4912 go srv.Serve(ln)
4913 <-conn.closec
4914 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4915 if err != nil {
4916 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4917 continue
4918 }
4919 if res.StatusCode != tt.want {
4920 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4921 }
4922 }
4923 }
4924
4925 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4926 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4927 }
4928 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
4929 const upgradeResponse = "upgrade here"
4930 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4931 conn, br, err := w.(Hijacker).Hijack()
4932 if err != nil {
4933 t.Error(err)
4934 return
4935 }
4936 defer conn.Close()
4937 if r.Method != "PRI" || r.RequestURI != "*" {
4938 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4939 return
4940 }
4941 if !r.Close {
4942 t.Errorf("Request.Close = true; want false")
4943 }
4944 const want = "SM\r\n\r\n"
4945 buf := make([]byte, len(want))
4946 n, err := io.ReadFull(br, buf)
4947 if err != nil || string(buf[:n]) != want {
4948 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4949 return
4950 }
4951 io.WriteString(conn, upgradeResponse)
4952 })).ts
4953
4954 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4955 if err != nil {
4956 t.Fatalf("Dial: %v", err)
4957 }
4958 defer c.Close()
4959 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4960 slurp, err := io.ReadAll(c)
4961 if err != nil {
4962 t.Fatal(err)
4963 }
4964 if string(slurp) != upgradeResponse {
4965 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
4966 }
4967 }
4968
4969
4970
4971 func TestServerValidatesHeaders(t *testing.T) {
4972 setParallel(t)
4973 tests := []struct {
4974 header string
4975 want int
4976 }{
4977 {"", 200},
4978 {"Foo: bar\r\n", 200},
4979 {"X-Foo: bar\r\n", 200},
4980 {"Foo: a space\r\n", 200},
4981
4982 {"A space: foo\r\n", 400},
4983 {"foo\xffbar: foo\r\n", 400},
4984 {"foo\x00bar: foo\r\n", 400},
4985 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
4986
4987
4988 {"Foo : bar\r\n", 400},
4989 {"Foo\t: bar\r\n", 400},
4990
4991
4992
4993 {": empty key\r\n", 400},
4994
4995
4996
4997
4998 {"Content-Length: notdigits\r\n", 400},
4999 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
5000
5001 {"foo: foo foo\r\n", 200},
5002 {"foo: foo\tfoo\r\n", 200},
5003 {"foo: foo\x00foo\r\n", 400},
5004 {"foo: foo\x7ffoo\r\n", 400},
5005 {"foo: foo\xfffoo\r\n", 200},
5006 }
5007 for _, tt := range tests {
5008 conn := newTestConn()
5009 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
5010
5011 ln := &oneConnListener{conn}
5012 srv := Server{
5013 ErrorLog: quietLog,
5014 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5015 }
5016 go srv.Serve(ln)
5017 <-conn.closec
5018 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5019 if err != nil {
5020 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
5021 continue
5022 }
5023 if res.StatusCode != tt.want {
5024 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
5025 }
5026 }
5027 }
5028
5029 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
5030 run(t, testServerRequestContextCancel_ServeHTTPDone)
5031 }
5032 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5033 ctxc := make(chan context.Context, 1)
5034 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5035 ctx := r.Context()
5036 select {
5037 case <-ctx.Done():
5038 t.Error("should not be Done in ServeHTTP")
5039 default:
5040 }
5041 ctxc <- ctx
5042 }))
5043 res, err := cst.c.Get(cst.ts.URL)
5044 if err != nil {
5045 t.Fatal(err)
5046 }
5047 res.Body.Close()
5048 ctx := <-ctxc
5049 select {
5050 case <-ctx.Done():
5051 default:
5052 t.Error("context should be done after ServeHTTP completes")
5053 }
5054 }
5055
5056
5057
5058
5059
5060 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5061 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5062 }
5063 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5064 inHandler := make(chan struct{})
5065 handlerDone := make(chan struct{})
5066 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5067 close(inHandler)
5068 <-r.Context().Done()
5069 close(handlerDone)
5070 })).ts
5071 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5072 if err != nil {
5073 t.Fatal(err)
5074 }
5075 defer c.Close()
5076 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5077 <-inHandler
5078 c.Close()
5079 <-handlerDone
5080 }
5081
5082 func TestServerContext_ServerContextKey(t *testing.T) {
5083 run(t, testServerContext_ServerContextKey)
5084 }
5085 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5086 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5087 ctx := r.Context()
5088 got := ctx.Value(ServerContextKey)
5089 if _, ok := got.(*Server); !ok {
5090 t.Errorf("context value = %T; want *http.Server", got)
5091 }
5092 }))
5093 res, err := cst.c.Get(cst.ts.URL)
5094 if err != nil {
5095 t.Fatal(err)
5096 }
5097 res.Body.Close()
5098 }
5099
5100 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5101 run(t, testServerContext_LocalAddrContextKey)
5102 }
5103 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5104 ch := make(chan any, 1)
5105 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5106 ch <- r.Context().Value(LocalAddrContextKey)
5107 }))
5108 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5109 t.Fatal(err)
5110 }
5111
5112 host := cst.ts.Listener.Addr().String()
5113 got := <-ch
5114 if addr, ok := got.(net.Addr); !ok {
5115 t.Errorf("local addr value = %T; want net.Addr", got)
5116 } else if fmt.Sprint(addr) != host {
5117 t.Errorf("local addr = %v; want %v", addr, host)
5118 }
5119 }
5120
5121
5122 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5123 setParallel(t)
5124 defer afterTest(t)
5125 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5126 w.Header().Set("Transfer-Encoding", "chunked")
5127 w.Write([]byte("hello"))
5128 }))
5129 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5130 const hdr = "Transfer-Encoding: chunked"
5131 if n := strings.Count(resp, hdr); n != 1 {
5132 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5133 }
5134 }
5135
5136
5137 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5138 setParallel(t)
5139 defer afterTest(t)
5140 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5141 w.Header().Set("Transfer-Encoding", "gzip")
5142 gz := gzip.NewWriter(w)
5143 gz.Write([]byte("hello"))
5144 gz.Close()
5145 }))
5146 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5147 for _, v := range []string{"gzip", "chunked"} {
5148 hdr := "Transfer-Encoding: " + v
5149 if n := strings.Count(resp, hdr); n != 1 {
5150 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5151 }
5152 }
5153 }
5154
5155 func BenchmarkClientServer(b *testing.B) {
5156 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5157 }
5158 func benchmarkClientServer(b *testing.B, mode testMode) {
5159 b.ReportAllocs()
5160 b.StopTimer()
5161 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5162 fmt.Fprintf(rw, "Hello world.\n")
5163 })).ts
5164 b.StartTimer()
5165
5166 c := ts.Client()
5167 for i := 0; i < b.N; i++ {
5168 res, err := c.Get(ts.URL)
5169 if err != nil {
5170 b.Fatal("Get:", err)
5171 }
5172 all, err := io.ReadAll(res.Body)
5173 res.Body.Close()
5174 if err != nil {
5175 b.Fatal("ReadAll:", err)
5176 }
5177 body := string(all)
5178 if body != "Hello world.\n" {
5179 b.Fatal("Got body:", body)
5180 }
5181 }
5182
5183 b.StopTimer()
5184 }
5185
5186 func BenchmarkClientServerParallel(b *testing.B) {
5187 for _, parallelism := range []int{4, 64} {
5188 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5189 run(b, func(b *testing.B, mode testMode) {
5190 benchmarkClientServerParallel(b, parallelism, mode)
5191 }, []testMode{http1Mode, https1Mode, http2Mode})
5192 })
5193 }
5194 }
5195
5196 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5197 b.ReportAllocs()
5198 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5199 fmt.Fprintf(rw, "Hello world.\n")
5200 })).ts
5201 b.ResetTimer()
5202 b.SetParallelism(parallelism)
5203 b.RunParallel(func(pb *testing.PB) {
5204 c := ts.Client()
5205 for pb.Next() {
5206 res, err := c.Get(ts.URL)
5207 if err != nil {
5208 b.Logf("Get: %v", err)
5209 continue
5210 }
5211 all, err := io.ReadAll(res.Body)
5212 res.Body.Close()
5213 if err != nil {
5214 b.Logf("ReadAll: %v", err)
5215 continue
5216 }
5217 body := string(all)
5218 if body != "Hello world.\n" {
5219 panic("Got body: " + body)
5220 }
5221 }
5222 })
5223 }
5224
5225
5226
5227
5228
5229
5230
5231
5232
5233
5234 func BenchmarkServer(b *testing.B) {
5235 b.ReportAllocs()
5236
5237 if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
5238 n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
5239 if err != nil {
5240 panic(err)
5241 }
5242 for i := 0; i < n; i++ {
5243 res, err := Get(url)
5244 if err != nil {
5245 log.Panicf("Get: %v", err)
5246 }
5247 all, err := io.ReadAll(res.Body)
5248 res.Body.Close()
5249 if err != nil {
5250 log.Panicf("ReadAll: %v", err)
5251 }
5252 body := string(all)
5253 if body != "Hello world.\n" {
5254 log.Panicf("Got body: %q", body)
5255 }
5256 }
5257 os.Exit(0)
5258 return
5259 }
5260
5261 var res = []byte("Hello world.\n")
5262 b.StopTimer()
5263 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5264 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5265 rw.Write(res)
5266 }))
5267 defer ts.Close()
5268 b.StartTimer()
5269
5270 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5271 cmd.Env = append([]string{
5272 fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
5273 fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
5274 }, os.Environ()...)
5275 out, err := cmd.CombinedOutput()
5276 if err != nil {
5277 b.Errorf("Test failure: %v, with output: %s", err, out)
5278 }
5279 }
5280
5281
// getNoBody GETs urlStr with the default client and closes the body without
// reading it. On success it returns the response, whose status and headers
// remain usable; on failure it returns a nil response and the error.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5290
5291
5292
5293 func BenchmarkClient(b *testing.B) {
5294 b.ReportAllocs()
5295 b.StopTimer()
5296 defer afterTest(b)
5297
5298 var data = []byte("Hello world.\n")
5299 if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
5300
5301 port := os.Getenv("TEST_BENCH_SERVER_PORT")
5302 if port == "" {
5303 port = "0"
5304 }
5305 ln, err := net.Listen("tcp", "localhost:"+port)
5306 if err != nil {
5307 fmt.Fprintln(os.Stderr, err.Error())
5308 os.Exit(1)
5309 }
5310 fmt.Println(ln.Addr().String())
5311 HandleFunc("/", func(w ResponseWriter, r *Request) {
5312 r.ParseForm()
5313 if r.Form.Get("stop") != "" {
5314 os.Exit(0)
5315 }
5316 w.Header().Set("Content-Type", "text/html; charset=utf-8")
5317 w.Write(data)
5318 })
5319 var srv Server
5320 log.Fatal(srv.Serve(ln))
5321 }
5322
5323
5324 ctx, cancel := context.WithCancel(context.Background())
5325 cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
5326 cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
5327 cmd.Stderr = os.Stderr
5328 stdout, err := cmd.StdoutPipe()
5329 if err != nil {
5330 b.Fatal(err)
5331 }
5332 if err := cmd.Start(); err != nil {
5333 b.Fatalf("subprocess failed to start: %v", err)
5334 }
5335
5336 done := make(chan error, 1)
5337 go func() {
5338 done <- cmd.Wait()
5339 close(done)
5340 }()
5341 defer func() {
5342 cancel()
5343 <-done
5344 }()
5345
5346
5347
5348 bs := bufio.NewScanner(stdout)
5349 if !bs.Scan() {
5350 b.Fatalf("failed to read listening URL from child: %v", bs.Err())
5351 }
5352 url := "http://" + strings.TrimSpace(bs.Text()) + "/"
5353 if _, err := getNoBody(url); err != nil {
5354 b.Fatalf("initial probe of child process failed: %v", err)
5355 }
5356
5357
5358 b.StartTimer()
5359 for i := 0; i < b.N; i++ {
5360 res, err := Get(url)
5361 if err != nil {
5362 b.Fatalf("Get: %v", err)
5363 }
5364 body, err := io.ReadAll(res.Body)
5365 res.Body.Close()
5366 if err != nil {
5367 b.Fatalf("ReadAll: %v", err)
5368 }
5369 if !bytes.Equal(body, data) {
5370 b.Fatalf("Got body: %q", body)
5371 }
5372 }
5373 b.StopTimer()
5374
5375
5376 getNoBody(url + "?stop=yes")
5377 if err := <-done; err != nil {
5378 b.Fatalf("subprocess failed: %v", err)
5379 }
5380 }
5381
5382 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5383 b.ReportAllocs()
5384 req := reqBytes(`GET / HTTP/1.0
5385 Host: golang.org
5386 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5387 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5388 Accept-Encoding: gzip,deflate,sdch
5389 Accept-Language: en-US,en;q=0.8
5390 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5391 `)
5392 res := []byte("Hello world!\n")
5393
5394 conn := newTestConn()
5395 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5396 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5397 rw.Write(res)
5398 })
5399 ln := new(oneConnListener)
5400 for i := 0; i < b.N; i++ {
5401 conn.readBuf.Reset()
5402 conn.writeBuf.Reset()
5403 conn.readBuf.Write(req)
5404 ln.conn = conn
5405 Serve(ln, handler)
5406 <-conn.closec
5407 }
5408 }
5409
5410
// repeatReader is an io.Reader that yields content count times in a row and
// then returns io.EOF. off is the read position within the current repetition.
type repeatReader struct {
	content []byte
	count   int
	off     int
}

// Read copies as much of the current repetition as fits into p. Once a
// repetition is fully consumed, count is decremented and the next Read
// restarts at the beginning of content.
func (r *repeatReader) Read(p []byte) (int, error) {
	if r.count <= 0 {
		return 0, io.EOF
	}
	n := copy(p, r.content[r.off:])
	r.off += n
	if r.off < len(r.content) {
		return n, nil
	}
	// One full repetition has been delivered; rewind for the next.
	r.count--
	r.off = 0
	return n, nil
}
5429
5430 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5431 b.ReportAllocs()
5432
5433 req := reqBytes(`GET / HTTP/1.1
5434 Host: golang.org
5435 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5436 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5437 Accept-Encoding: gzip,deflate,sdch
5438 Accept-Language: en-US,en;q=0.8
5439 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5440 `)
5441 res := []byte("Hello world!\n")
5442
5443 conn := &rwTestConn{
5444 Reader: &repeatReader{content: req, count: b.N},
5445 Writer: io.Discard,
5446 closec: make(chan bool, 1),
5447 }
5448 handled := 0
5449 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5450 handled++
5451 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5452 rw.Write(res)
5453 })
5454 ln := &oneConnListener{conn: conn}
5455 go Serve(ln, handler)
5456 <-conn.closec
5457 if b.N != handled {
5458 b.Errorf("b.N=%d but handled %d", b.N, handled)
5459 }
5460 }
5461
5462
5463
5464 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5465 b.ReportAllocs()
5466
5467 req := reqBytes(`GET / HTTP/1.1
5468 Host: golang.org
5469 `)
5470 res := []byte("Hello world!\n")
5471
5472 conn := &rwTestConn{
5473 Reader: &repeatReader{content: req, count: b.N},
5474 Writer: io.Discard,
5475 closec: make(chan bool, 1),
5476 }
5477 handled := 0
5478 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5479 handled++
5480 rw.Write(res)
5481 })
5482 ln := &oneConnListener{conn: conn}
5483 go Serve(ln, handler)
5484 <-conn.closec
5485 if b.N != handled {
5486 b.Errorf("b.N=%d but handled %d", b.N, handled)
5487 }
5488 }
5489
// someResponse is the HTML fragment that response is built from.
const someResponse = "<html>some response</html>"

// response is someResponse repeated as many whole times as fit within
// 2<<10 (2048) bytes.
var response = bytes.Repeat([]byte(someResponse), (2<<10)/len(someResponse))
5494
5495
5496 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5497 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5498 w.Header().Set("Content-Type", "text/html")
5499 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5500 w.Write(response)
5501 }))
5502 }
5503
5504
5505 func BenchmarkServerHandlerNoLen(b *testing.B) {
5506 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5507 w.Header().Set("Content-Type", "text/html")
5508 w.Write(response)
5509 }))
5510 }
5511
5512
5513 func BenchmarkServerHandlerNoType(b *testing.B) {
5514 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5515 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5516 w.Write(response)
5517 }))
5518 }
5519
5520
5521 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5522 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5523 w.Write(response)
5524 }))
5525 }
5526
5527 func benchmarkHandler(b *testing.B, h Handler) {
5528 b.ReportAllocs()
5529 req := reqBytes(`GET / HTTP/1.1
5530 Host: golang.org
5531 `)
5532 conn := &rwTestConn{
5533 Reader: &repeatReader{content: req, count: b.N},
5534 Writer: io.Discard,
5535 closec: make(chan bool, 1),
5536 }
5537 handled := 0
5538 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5539 handled++
5540 h.ServeHTTP(rw, r)
5541 })
5542 ln := &oneConnListener{conn: conn}
5543 go Serve(ln, handler)
5544 <-conn.closec
5545 if b.N != handled {
5546 b.Errorf("b.N=%d but handled %d", b.N, handled)
5547 }
5548 }
5549
5550 func BenchmarkServerHijack(b *testing.B) {
5551 b.ReportAllocs()
5552 req := reqBytes(`GET / HTTP/1.1
5553 Host: golang.org
5554 `)
5555 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5556 conn, _, err := w.(Hijacker).Hijack()
5557 if err != nil {
5558 panic(err)
5559 }
5560 conn.Close()
5561 })
5562 conn := &rwTestConn{
5563 Writer: io.Discard,
5564 closec: make(chan bool, 1),
5565 }
5566 ln := &oneConnListener{conn: conn}
5567 for i := 0; i < b.N; i++ {
5568 conn.Reader = bytes.NewReader(req)
5569 ln.conn = conn
5570 Serve(ln, h)
5571 <-conn.closec
5572 }
5573 }
5574
5575 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5576 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5577 b.ReportAllocs()
5578 b.StopTimer()
5579 sawClose := make(chan bool)
5580 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5581 <-rw.(CloseNotifier).CloseNotify()
5582 sawClose <- true
5583 })).ts
5584 b.StartTimer()
5585 for i := 0; i < b.N; i++ {
5586 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5587 if err != nil {
5588 b.Fatalf("error dialing: %v", err)
5589 }
5590 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5591 if err != nil {
5592 b.Fatal(err)
5593 }
5594 conn.Close()
5595 <-sawClose
5596 }
5597 b.StopTimer()
5598 }
5599
5600
5601 func TestConcurrentServerServe(t *testing.T) {
5602 setParallel(t)
5603 for i := 0; i < 100; i++ {
5604 ln1 := &oneConnListener{conn: nil}
5605 ln2 := &oneConnListener{conn: nil}
5606 srv := Server{}
5607 go func() { srv.Serve(ln1) }()
5608 go func() { srv.Serve(ln2) }()
5609 }
5610 }
5611
5612 func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
5613 func testServerIdleTimeout(t *testing.T, mode testMode) {
5614 if testing.Short() {
5615 t.Skip("skipping in short mode")
5616 }
5617 runTimeSensitiveTest(t, []time.Duration{
5618 10 * time.Millisecond,
5619 100 * time.Millisecond,
5620 1 * time.Second,
5621 10 * time.Second,
5622 }, func(t *testing.T, readHeaderTimeout time.Duration) error {
5623 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5624 io.Copy(io.Discard, r.Body)
5625 io.WriteString(w, r.RemoteAddr)
5626 }), func(ts *httptest.Server) {
5627 ts.Config.ReadHeaderTimeout = readHeaderTimeout
5628 ts.Config.IdleTimeout = 2 * readHeaderTimeout
5629 })
5630 defer cst.close()
5631 ts := cst.ts
5632 t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
5633 t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
5634 c := ts.Client()
5635
5636 get := func() (string, error) {
5637 res, err := c.Get(ts.URL)
5638 if err != nil {
5639 return "", err
5640 }
5641 defer res.Body.Close()
5642 slurp, err := io.ReadAll(res.Body)
5643 if err != nil {
5644
5645
5646
5647 t.Fatal(err)
5648 }
5649 return string(slurp), nil
5650 }
5651
5652 a1, err := get()
5653 if err != nil {
5654 return err
5655 }
5656 a2, err := get()
5657 if err != nil {
5658 return err
5659 }
5660 if a1 != a2 {
5661 return fmt.Errorf("did requests on different connections")
5662 }
5663 time.Sleep(ts.Config.IdleTimeout * 3 / 2)
5664 a3, err := get()
5665 if err != nil {
5666 return err
5667 }
5668 if a2 == a3 {
5669 return fmt.Errorf("request three unexpectedly on same connection")
5670 }
5671
5672
5673 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5674 if err != nil {
5675 return err
5676 }
5677 defer conn.Close()
5678 conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
5679 time.Sleep(ts.Config.ReadHeaderTimeout * 2)
5680 if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
5681 return fmt.Errorf("copy byte succeeded; want err")
5682 }
5683
5684 return nil
5685 })
5686 }
5687
// get GETs url with c and returns the whole response body as a string. Any
// error fails the test immediately via t.Fatal.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	var body strings.Builder
	if _, err := io.Copy(&body, res.Body); err != nil {
		t.Fatal(err)
	}
	return body.String()
}
5700
5701
5702
5703 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5704 run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
5705 }
5706 func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
5707 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5708 io.WriteString(w, r.RemoteAddr)
5709 })).ts
5710
5711 c := ts.Client()
5712 tr := c.Transport.(*Transport)
5713
5714 get := func() string { return get(t, c, ts.URL) }
5715
5716 a1, a2 := get(), get()
5717 if a1 == a2 {
5718 t.Logf("made two requests from a single conn %q (as expected)", a1)
5719 } else {
5720 t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
5721 }
5722
5723
5724
5725
5726
5727 if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
5728 t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
5729 }
5730
5731
5732 ts.Config.SetKeepAlivesEnabled(false)
5733
5734 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5735 if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
5736 if d > 0 {
5737 t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
5738 }
5739 return false
5740 }
5741 return true
5742 })
5743
5744
5745
5746
5747 }
5748
5749 func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
5750 func testServerShutdown(t *testing.T, mode testMode) {
5751 var cst *clientServerTest
5752
5753 var once sync.Once
5754 statesRes := make(chan map[ConnState]int, 1)
5755 shutdownRes := make(chan error, 1)
5756 gotOnShutdown := make(chan struct{})
5757 handler := HandlerFunc(func(w ResponseWriter, r *Request) {
5758 first := false
5759 once.Do(func() {
5760 statesRes <- cst.ts.Config.ExportAllConnsByState()
5761 go func() {
5762 shutdownRes <- cst.ts.Config.Shutdown(context.Background())
5763 }()
5764 first = true
5765 })
5766
5767 if first {
5768
5769
5770
5771 <-gotOnShutdown
5772
5773
5774 for !t.Failed() {
5775 res, err := cst.c.Get(cst.ts.URL)
5776 if err != nil {
5777 break
5778 }
5779 out, _ := io.ReadAll(res.Body)
5780 res.Body.Close()
5781 if mode == http2Mode {
5782 t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
5783 t.Logf("Retrying to work around https://go.dev/issue/59038.")
5784 continue
5785 }
5786 t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
5787 }
5788 }
5789
5790 io.WriteString(w, r.RemoteAddr)
5791 })
5792
5793 cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
5794 srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
5795 })
5796
5797 out := get(t, cst.c, cst.ts.URL)
5798 t.Logf("%v: %q", cst.ts.URL, out)
5799
5800 if err := <-shutdownRes; err != nil {
5801 t.Fatalf("Shutdown: %v", err)
5802 }
5803 <-gotOnShutdown
5804
5805 if states := <-statesRes; states[StateActive] != 1 {
5806 t.Errorf("connection in wrong state, %v", states)
5807 }
5808 }
5809
5810 func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
5811 func testServerShutdownStateNew(t testing.TB, mode testMode) {
5812 if testing.Short() {
5813 t.Skip("test takes 5-6 seconds; skipping in short mode")
5814 }
5815
5816 listener := fakeNetListen()
5817 defer listener.Close()
5818
5819 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5820
5821 }), func(ts *httptest.Server) {
5822 ts.Listener.Close()
5823 ts.Listener = listener
5824
5825 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
5826 }).ts
5827
5828
5829 c := listener.connect()
5830 defer c.Close()
5831 synctest.Wait()
5832
5833 shutdownRes := runAsync(func() (struct{}, error) {
5834 return struct{}{}, ts.Config.Shutdown(context.Background())
5835 })
5836
5837
5838
5839
5840 const expectTimeout = 5 * time.Second
5841
5842
5843 time.Sleep(expectTimeout - 1)
5844 synctest.Wait()
5845 if shutdownRes.done() {
5846 t.Fatal("shutdown too soon")
5847 }
5848 if c.IsClosedByPeer() {
5849 t.Fatal("connection was closed by server too soon")
5850 }
5851
5852
5853
5854
5855
5856 time.Sleep(2 * time.Second)
5857 synctest.Wait()
5858 if _, err := shutdownRes.result(); err != nil {
5859 t.Fatalf("Shutdown() = %v, want complete", err)
5860 }
5861 if !c.IsClosedByPeer() {
5862 t.Fatalf("connection was not closed by server after shutdown")
5863 }
5864 }
5865
5866
// TestServerCloseDeadlock checks that calling Close twice on a zero-value
// Server returns instead of hanging.
func TestServerCloseDeadlock(t *testing.T) {
	var srv Server
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
5872
5873
5874
func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }

// testServerKeepAlivesEnabled turns keep-alives off on the server and checks
// that each of two consecutive requests is served on a brand-new connection:
// the client trace must see exactly one GotConn per request, with neither
// Reused nor WasIdle set.
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Use a 10ms HTTP/2 GOAWAY timeout for this test; restore() undoes it.
		restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
		defer restore()
	}

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)

	for attempt := 0; attempt < 2; attempt++ {
		// Block until the server reports that none of its conns are active.
		waitCondition(t, 10*time.Millisecond, func(waited time.Duration) bool {
			idle := srv.ExportAllConnsIdle()
			if !idle && waited > 0 {
				t.Logf("test server still has active conns after %v", waited)
			}
			return idle
		})

		var (
			gotConns int
			connInfo httptrace.GotConnInfo
		)
		trace := &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				gotConns++
				connInfo = info
			},
		}
		traced := httptrace.WithClientTrace(context.Background(), trace)
		req, err := NewRequestWithContext(traced, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()

		if gotConns != 1 {
			t.Fatalf("request %v: got %v conns, want 1", attempt, gotConns)
		}
		if connInfo.Reused || connInfo.WasIdle {
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", attempt, connInfo.Reused, connInfo.WasIdle)
		}
	}
}
5921
5922
5923
5924
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }

// testServerCancelsReadTimeoutWhenIdle sets Server.ReadTimeout and installs a
// handler that waits twice that long before answering "ok" (or echoes the
// context error if the request context ends first). The client must receive
// "ok". runTimeSensitiveTest retries with progressively longer timeouts.
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	timeouts := []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}
	runTimeSensitiveTest(t, timeouts, func(t *testing.T, timeout time.Duration) error {
		slowHandler := HandlerFunc(func(w ResponseWriter, r *Request) {
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		})
		cst := newClientServerTest(t, mode, slowHandler, func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})
		defer cst.close()
		ts := cst.ts

		// The transport may consult Proxy only once. Any further call (for
		// example, a transparent retry) returns an error, so such a retry
		// surfaces as a Get failure instead of hiding the problem.
		var proxyCalls atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if proxyCalls.Add(1) > 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		res, err := ts.Client().Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		switch {
		case err != nil:
			return fmt.Errorf("Body ReadAll: %v", err)
		case string(body) != "ok":
			return fmt.Errorf("got: %q, want ok", body)
		}
		return nil
	})
}
5973
5974
5975
5976
func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
	run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
}

// testServerCancelsReadHeaderTimeoutWhenIdle sends one request on a raw
// connection and then idles for 1.5x ReadHeaderTimeout (IdleTimeout is 0).
// It then sends a second request on the same connection and requires that a
// response comes back for it too.
func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = timeout
			ts.Config.IdleTimeout = 0
		})
		defer cst.close()

		conn, err := net.Dial("tcp", cst.ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("dial failed: %v", err)
		}
		br := bufio.NewReader(conn)
		defer conn.Close()

		// roundTrip writes one GET on conn and reads its response. ordinal
		// and phase are used only to label the returned error.
		roundTrip := func(ordinal, phase string) error {
			if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
				return fmt.Errorf("writing %s request failed: %v", ordinal, err)
			}
			if _, err := ReadResponse(br, nil); err != nil {
				return fmt.Errorf("%s response (%s timeout) failed: %v", ordinal, phase, err)
			}
			return nil
		}

		if err := roundTrip("first", "before"); err != nil {
			return err
		}
		// Stay idle past ReadHeaderTimeout before the next request.
		time.Sleep(timeout * 3 / 2)
		return roundTrip("second", "after")
	})
}
6027
6028
6029
// runTimeSensitiveTest calls test with each duration in order and stops at the
// first one that returns nil.
//
// A failing attempt is logged and the next duration is tried. The whole test
// fails via t.Fatalf if the last duration also fails, or if t has already been
// marked failed by the callback.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for i, d := range durations {
		err := test(t, d)
		switch {
		case err == nil:
			return
		case i == last, t.Failed():
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6042
6043
6044
func TestServerDuplicateBackgroundRead(t *testing.T) {
	run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
}

// testServerDuplicateBackgroundRead opens several raw TCP connections to the
// server and pipelines many identical GET requests on each. A companion
// goroutine per connection discards everything the server writes back.
func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	numConns, reqsPerConn := 5, 2000
	if testing.Short() {
		numConns, reqsPerConn = 3, 100
	}

	hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
	addr := hts.Listener.Addr().String()
	request := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup

	// pump dials one connection and writes reqsPerConn requests on it,
	// giving up early once the test has failed.
	pump := func() {
		defer wg.Done()
		cn, err := net.Dial("tcp", addr)
		if err != nil {
			t.Error(err)
			return
		}
		defer cn.Close()

		// Discard whatever the server sends on this connection.
		wg.Add(1)
		go func() {
			defer wg.Done()
			io.Copy(io.Discard, cn)
		}()

		for n := 0; n < reqsPerConn && !t.Failed(); n++ {
			if _, err := cn.Write(request); err != nil {
				t.Error(err)
				return
			}
		}
	}

	for i := 0; i < numConns; i++ {
		wg.Add(1)
		go pump()
	}
	wg.Wait()
}
6096
6097
6098
6099
6100
6101
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
}

// testServerHijackGetsBackgroundByte checks two things about bytes ("foo")
// that the client sends after its request headers, once the handler is
// running. They must be readable through the bufio.Reader returned by
// Hijack, and the request context must not be canceled at that point.
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	handlerDone := make(chan struct{})
	handlerStarted := make(chan bool, 1)

	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(handlerDone)

		// Tell the client it may now send the extra bytes.
		handlerStarted <- true

		conn, rw, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		if peek, err := rw.Reader.Peek(3); err != nil || string(peek) != "foo" {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})
	ts := newClientServerTest(t, mode, handler).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-handlerStarted
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-handlerDone
}
6154
6155
6156
6157
func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
}

// testServerHijackGetsBackgroundByte_big sends a GET followed immediately by
// size bytes of 'x', then half-closes the client's write side. The handler
// hijacks the connection and must read back exactly those size 'x' bytes from
// the bufio.Reader returned by Hijack.
func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	const size = 8 << 10
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		slurp, err := io.ReadAll(buf.Reader)
		if err != nil {
			// The failing call is io.ReadAll, so label the error accordingly.
			t.Errorf("ReadAll: %v", err)
		}
		switch {
		case len(slurp) != size:
			t.Errorf("read %d; want %d", len(slurp), size)
		case bytes.Count(slurp, []byte("x")) != size:
			// Length is right, but not every byte is 'x'.
			t.Errorf("read %q; want %d 'x'", slurp, size)
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
		strings.Repeat("x", size)); err != nil {
		t.Fatal(err)
	}
	// Half-close so the handler's ReadAll on the hijacked reader reaches EOF.
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}

	<-done
}
6208
6209
// TestServerValidatesMethod feeds one raw request to Serve over an in-memory
// connection and checks the status code of the reply. A well-formed method
// gets 200; a method containing '(' must be answered with 400.
func TestServerValidatesMethod(t *testing.T) {
	tests := []struct {
		method string
		want   int
	}{
		{"GET", 200},
		{"GE(T", 400},
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")

		ln := &oneConnListener{conn}
		go Serve(ln, serve(200))
		// Wait until the server closes the connection (presumably signaled on
		// conn.closec) before parsing what it wrote.
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res carries no information here.
			t.Errorf("For %s, ReadResponse: %v", tt.method, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
		}
	}
}
6235
6236
6237 type eofListenerNotComparable []int
6238
6239 func (eofListenerNotComparable) Accept() (net.Conn, error) { return nil, io.EOF }
6240 func (eofListenerNotComparable) Addr() net.Addr { return nil }
6241 func (eofListenerNotComparable) Close() error { return nil }
6242
6243
6244 func TestServerListenNotComparableListener(t *testing.T) {
6245 var s Server
6246 s.Serve(make(eofListenerNotComparable, 1))
6247 }
6248
6249
6250 type countCloseListener struct {
6251 net.Listener
6252 closes int32
6253 }
6254
6255 func (p *countCloseListener) Close() error {
6256 var err error
6257 if n := atomic.AddInt32(&p.closes, 1); n == 1 && p.Listener != nil {
6258 err = p.Listener.Close()
6259 }
6260 return err
6261 }
6262
6263
// TestServerCloseListenerOnce checks that when a Server is shut down while
// serving a listener, that listener's Close method is called exactly once.
func TestServerCloseListenerOnce(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	counting := &countCloseListener{Listener: ln}
	srv := &Server{}
	served := make(chan struct{})

	go func() {
		defer close(served)
		srv.Serve(counting)
	}()
	time.Sleep(10 * time.Millisecond)
	srv.Shutdown(context.Background())
	ln.Close()
	<-served

	if n := atomic.LoadInt32(&counting.closes); n != 1 {
		t.Errorf("Close calls = %v; want 1", n)
	}
}
6289
6290
6291 func TestServerShutdownThenServe(t *testing.T) {
6292 var srv Server
6293 cl := &countCloseListener{Listener: nil}
6294 srv.Shutdown(context.Background())
6295 got := srv.Serve(cl)
6296 if got != ErrServerClosed {
6297 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6298 }
6299 nclose := atomic.LoadInt32(&cl.closes)
6300 if nclose != 1 {
6301 t.Errorf("Close calls = %v; want 1", nclose)
6302 }
6303 }
6304
6305
// TestStripPortFromHost registers handlers for both "example.com/" and
// "example.com:9000/", then sends a request to http://example.com:9000/. It
// expects the mux to route that request to the port-less "example.com/"
// pattern.
func TestStripPortFromHost(t *testing.T) {
	reply := func(body string) func(ResponseWriter, *Request) {
		return func(w ResponseWriter, _ *Request) {
			fmt.Fprint(w, body)
		}
	}

	mux := NewServeMux()
	mux.HandleFunc("example.com/", reply("OK"))
	mux.HandleFunc("example.com:9000/", reply("uh-oh!"))

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6326
func TestServerContexts(t *testing.T) { run(t, testServerContexts) }

// testServerContexts checks three things:
//   - values attached by Server.BaseContext and Server.ConnContext both reach
//     the handler's request context;
//   - ConnContext is handed the context that BaseContext produced;
//   - BaseContext never receives a listener whose type name contains
//     "onceClose".
func testServerContexts(t *testing.T, mode testMode) {
	type baseKey struct{}
	type connKey struct{}
	reqCtx := make(chan context.Context, 1)

	configure := func(ts *httptest.Server) {
		ts.Config.BaseContext = func(ln net.Listener) context.Context {
			// NOTE(review): "onceClose" presumably names an internal listener
			// wrapper that should not be exposed to BaseContext.
			if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
				t.Errorf("unexpected onceClose listener type %T", ln)
			}
			return context.WithValue(context.Background(), baseKey{}, "base")
		}
		ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
			if got, want := ctx.Value(baseKey{}), "base"; got != want {
				t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
			}
			return context.WithValue(ctx, connKey{}, "conn")
		}
	}
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		reqCtx <- r.Context()
	})
	ts := newClientServerTest(t, mode, handler, configure).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	ctx := <-reqCtx
	// Each key must map to a value equal to its label.
	for _, want := range []struct {
		label string
		key   any
	}{
		{"base", baseKey{}},
		{"conn", connKey{}},
	} {
		if got := ctx.Value(want.key); got != want.label {
			t.Errorf("%s context key = %#v; want %q", want.label, got, want.label)
		}
	}
}
6361
6362
func TestConnContextNotModifyingAllContexts(t *testing.T) {
	run(t, testConnContextNotModifyingAllContexts)
}

// testConnContextNotModifyingAllContexts issues two requests, each on its own
// connection because the handler sets "Connection: close". It checks that the
// value ConnContext attaches for one connection is not already present in the
// context ConnContext receives for the next one.
func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
	type connKey struct{}
	closeEach := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Connection", "close")
	})
	ts := newClientServerTest(t, mode, closeEach, func(ts *httptest.Server) {
		ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
			if got := ctx.Value(connKey{}); got != nil {
				t.Errorf("in ConnContext, unexpected context key = %#v", got)
			}
			return context.WithValue(ctx, connKey{}, "conn")
		}
	}).ts

	client := ts.Client()
	for i := 0; i < 2; i++ {
		res, err := client.Get(ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
}
6394
6395
6396
6397 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6398 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6399 }
6400 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6401 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6402 w.Write([]byte("Hello, World!"))
6403 })).ts
6404
6405 serverURL, err := url.Parse(cst.URL)
6406 if err != nil {
6407 t.Fatalf("Failed to parse server URL: %v", err)
6408 }
6409
6410 unsupportedTEs := []string{
6411 "fugazi",
6412 "foo-bar",
6413 "unknown",
6414 `" chunked"`,
6415 }
6416
6417 for _, badTE := range unsupportedTEs {
6418 http1ReqBody := fmt.Sprintf(""+
6419 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6420 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6421
6422 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6423 if err != nil {
6424 t.Errorf("%q. unexpected error: %v", badTE, err)
6425 continue
6426 }
6427
6428 wantBody := fmt.Sprintf("" +
6429 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6430 "Connection: close\r\n\r\nUnsupported transfer encoding")
6431
6432 if string(gotBody) != wantBody {
6433 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6434 }
6435 }
6436 }
6437
6438
func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }

// testContentEncodingNoSniffing checks which Content-Type the server reports
// when a handler writes a body without setting Content-Type. The expected
// result depends on whether the handler set Content-Encoding, and to what.
func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
	type setting struct {
		name string
		body []byte

		// contentEncoding is the Content-Encoding value the handler sets.
		// nil means the header is left unset; any string (including "")
		// is set as-is.
		contentEncoding any
		wantContentType string
	}

	const page = "doctype html><p>Hello</p>"
	gzipped := func() []byte {
		buf := new(bytes.Buffer)
		zw := gzip.NewWriter(buf)
		zw.Write([]byte(page))
		zw.Close()
		return buf.Bytes()
	}
	zlibbed := func() []byte {
		buf := new(bytes.Buffer)
		zw := zlib.NewWriter(buf)
		zw.Write([]byte(page))
		zw.Close()
		return buf.Bytes()
	}

	settings := []*setting{
		{
			name:            "gzip content-encoding, gzipped",
			contentEncoding: "application/gzip",
			wantContentType: "",
			body:            gzipped(),
		},
		{
			name:            "zlib content-encoding, zlibbed",
			contentEncoding: "application/zlib",
			wantContentType: "",
			body:            zlibbed(),
		},
		{
			name:            "no content-encoding",
			wantContentType: "application/x-gzip",
			body:            gzipped(),
		},
		{
			name:            "phony content-encoding",
			contentEncoding: "foo/bar",
			body:            []byte(page),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	for _, tt := range settings {
		t.Run(tt.name, func(t *testing.T) {
			cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
				if tt.contentEncoding != nil {
					rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				rw.Write(tt.body)
			}))

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()

			gotCE, wantCE := res.Header.Get("Content-Encoding"), tt.contentEncoding
			switch {
			case wantCE == nil:
				if gotCE != "" {
					t.Errorf("Unexpected Content-Encoding %q", gotCE)
				}
			case gotCE != wantCE:
				t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", gotCE, wantCE)
			}

			if got, want := res.Header.Get("Content-Type"), tt.wantContentType; got != want {
				t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", got, want)
			}
		})
	}
}
6530
6531
6532
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}

// testTimeoutHandlerSuperfluousLogs wraps, in TimeoutHandler, a handler that
// calls WriteHeader(404) four times. For two cases (handler returns before vs.
// after the deadline) it checks two things:
//   - the full response the client receives;
//   - the server logs exactly three "superfluous response.WriteHeader call"
//     entries, each naming this file and the exact line of one of the
//     redundant calls.
//
// NOTE: in the inner handler, the four WriteHeader calls and the following
// runtime.Caller(0) line must remain on consecutive source lines. The
// expected log line numbers are derived from that layout.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// This file's base name and this function's full symbol name; both are
	// expected to appear in each log entry.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name string
		// mustTimeout keeps the handler blocked past the 20ms
		// TimeoutHandler deadline instead of letting it return at once.
		mustTimeout bool
		// wantResp is the dumped response, minus Date and Content-Type.
		wantResp string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// The handler blocks on exitHandler. It is pre-filled when the
			// handler should return promptly, and otherwise released only
			// by the deferred close when the subtest ends.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				// The test expects the last three of these four calls to be
				// logged as superfluous.
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			if !tt.mustTimeout {
				exitHandler <- true
			}

			// Capture the server's error log so its entries can be checked.
			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Presumably generous so the non-timeout case cannot hit the
				// deadline by accident on a slow machine.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// wantResp omits these headers, so drop them before comparing.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the whole response, status line and body included.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Exactly one log line per superfluous WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastSpuriousLine is actually the runtime.Caller line, one past
			// the final WriteHeader. The three superfluous calls are
			// therefore on the three lines starting at firstSpuriousLine.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each entry must name this function's handler closure
			// (testFuncName.funcN.M) and file:line of its WriteHeader call.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6637
6638
6639
6640
// fetchWireResponse dials host over TCP and writes http1ReqBody verbatim.
// It then returns every byte the peer sends back until EOF or a read error
// (on error, the bytes read so far are returned together with the error).
// The connection is always closed before returning.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	defer c.Close()

	_, err = c.Write(http1ReqBody)
	if err != nil {
		return nil, err
	}
	return io.ReadAll(c)
}
6653
// BenchmarkResponseStatusLine measures Export_writeStatusLine for status 200
// into a buffered writer that discards its output. It runs in parallel and
// reports allocations. Each goroutine reuses its own 3-byte scratch buffer.
func BenchmarkResponseStatusLine(b *testing.B) {
	b.ReportAllocs()
	b.RunParallel(func(pb *testing.PB) {
		var scratch [3]byte
		w := bufio.NewWriter(io.Discard)
		for pb.Next() {
			Export_writeStatusLine(w, true, 200, scratch[:])
		}
	})
}
6664
6665 func TestDisableKeepAliveUpgrade(t *testing.T) {
6666 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6667 }
6668 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6669 if testing.Short() {
6670 t.Skip("skipping in short mode")
6671 }
6672
6673 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6674 w.Header().Set("Connection", "Upgrade")
6675 w.Header().Set("Upgrade", "someProto")
6676 w.WriteHeader(StatusSwitchingProtocols)
6677 c, buf, err := w.(Hijacker).Hijack()
6678 if err != nil {
6679 return
6680 }
6681 defer c.Close()
6682
6683
6684
6685 io.Copy(c, buf)
6686 }), func(ts *httptest.Server) {
6687 ts.Config.SetKeepAlivesEnabled(false)
6688 }).ts
6689
6690 cl := s.Client()
6691 cl.Transport.(*Transport).DisableKeepAlives = true
6692
6693 resp, err := cl.Get(s.URL)
6694 if err != nil {
6695 t.Fatalf("failed to perform request: %v", err)
6696 }
6697 defer resp.Body.Close()
6698
6699 if resp.StatusCode != StatusSwitchingProtocols {
6700 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6701 }
6702
6703 rwc, ok := resp.Body.(io.ReadWriteCloser)
6704 if !ok {
6705 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6706 }
6707
6708 _, err = rwc.Write([]byte("hello"))
6709 if err != nil {
6710 t.Fatalf("failed to write to body: %v", err)
6711 }
6712
6713 b := make([]byte, 5)
6714 _, err = io.ReadFull(rwc, b)
6715 if err != nil {
6716 t.Fatalf("failed to read from body: %v", err)
6717 }
6718
6719 if string(b) != "hello" {
6720 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6721 }
6722 }
6723
6724 type tlogWriter struct{ t *testing.T }
6725
6726 func (w tlogWriter) Write(p []byte) (int, error) {
6727 w.t.Log(string(p))
6728 return len(p), nil
6729 }
6730
func TestWriteHeaderSwitchingProtocols(t *testing.T) {
	run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
}

// testWriteHeaderSwitchingProtocols has a handler send and flush a 101
// Switching Protocols response, then try WriteHeader(200) plus a body write.
// That write must fail. The handler then hijacks the connection and writes
// wantBody directly. The raw client must see a 101 carrying the Upgrade
// header, followed by exactly wantBody.
func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
	const wantBody = "want"
	const wantUpgrade = "someProto"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", wantUpgrade)
		w.WriteHeader(StatusSwitchingProtocols)
		NewResponseController(w).Flush()

		// Once the 101 is out, a further WriteHeader followed by a body
		// write must fail.
		w.WriteHeader(200)
		if _, err := w.Write([]byte("x")); err == nil {
			t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
		}

		c, _, err := NewResponseController(w).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		defer c.Close()
		if _, err := c.Write([]byte(wantBody)); err != nil {
			t.Errorf("Write to hijacked body: %v", err)
		}
	}), func(ts *httptest.Server) {
		// Route server error logs into the test log.
		ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
	}).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial: %v", err)
	}
	// Register the close immediately so the connection is released on every
	// exit path, including the t.Fatalf below if the write fails.
	defer conn.Close()
	if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")); err != nil {
		t.Fatalf("conn.Write: %v", err)
	}

	r := bufio.NewReader(conn)
	res, err := ReadResponse(r, &Request{Method: "GET"})
	if err != nil {
		t.Fatal("ReadResponse error:", err)
	}
	if res.StatusCode != StatusSwitchingProtocols {
		t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
	}
	if got := res.Header.Get("Upgrade"); got != wantUpgrade {
		t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
	}
	// Everything after the 101 headers is what the handler wrote to the
	// hijacked connection before closing it.
	body, err := io.ReadAll(r)
	if err != nil {
		t.Error(err)
	}
	if string(body) != wantBody {
		t.Errorf("Response body = %q, want %q", string(body), wantBody)
	}
}
6792
// TestMuxRedirectRelative parses a request whose target is the absolute URL
// "http://example.com" (empty path). It checks that an empty ServeMux replies
// with 301 Moved Permanently and a relative Location of "/".
func TestMuxRedirectRelative(t *testing.T) {
	setParallel(t)
	req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
	if err != nil {
		// Stop here: req must not be handed to ServeHTTP when parsing failed.
		t.Fatalf("%s", err)
	}
	mux := NewServeMux()
	resp := httptest.NewRecorder()
	mux.ServeHTTP(resp, req)
	if got, want := resp.Header().Get("Location"), "/"; got != want {
		t.Errorf("Location header expected %q; got %q", want, got)
	}
	if got, want := resp.Code, StatusMovedPermanently; got != want {
		t.Errorf("Expected response code %d; got %d", want, got)
	}
}
6809
6810
// TestQuerySemicolon runs testQuerySemicolon for each query twice, once with
// default handling and once with the handler wrapped in AllowQuerySemicolons.
//
// Each table row gives the value of "x" expected in each case, and whether
// ParseForm should fail in the default case. With AllowQuerySemicolons,
// ParseForm is always expected to succeed.
func TestQuerySemicolon(t *testing.T) {
	t.Cleanup(func() { afterTest(t) })

	tests := []struct {
		query              string
		xNoSemicolons      string
		xWithSemicolons    string
		expectParseFormErr bool
	}{
		{"?a=1;x=bad&x=good", "good", "bad", true},
		{"?a=1;b=bad&x=good", "good", "good", true},
		{"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
		{"?a=1;x=good;x=bad", "", "good", true},
	}

	run(t, func(t *testing.T, mode testMode) {
		for _, tt := range tests {
			t.Run(tt.query+"/allow=false", func(t *testing.T) {
				testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons,
					false /* allowSemicolons */, tt.expectParseFormErr)
			})
			t.Run(tt.query+"/allow=true", func(t *testing.T) {
				testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons,
					true /* allowSemicolons */, false /* expectParseFormErr */)
			})
		}
	})
}
6839
// testQuerySemicolon requests ts.URL+query and checks what the handler sees
// for the "x" query parameter.
//
// The handler:
//   - reads x via r.URL.Query;
//   - requires ParseForm to fail with an error mentioning "semicolon" when
//     expectParseFormErr is set, and to succeed otherwise;
//   - requires FormValue("x") to agree with URL.Query;
//   - echoes x back.
//
// The client expects a 200 whose body is wantX. When allowSemicolons is true,
// the handler is wrapped in AllowQuerySemicolons.
func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
	writeBackX := func(w ResponseWriter, r *Request) {
		x := r.URL.Query().Get("x")
		if expectParseFormErr {
			if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
				t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
			}
		} else {
			if err := r.ParseForm(); err != nil {
				t.Errorf("expected no error from ParseForm, got %v", err)
			}
		}
		if got := r.FormValue("x"); x != got {
			t.Errorf("got %q from FormValue, want %q", got, x)
		}
		fmt.Fprintf(w, "%s", x)
	}

	h := Handler(HandlerFunc(writeBackX))
	if allowSemicolons {
		h = AllowQuerySemicolons(h)
	}

	// Server error logs go to logBuf, which this test does not inspect.
	logBuf := &strings.Builder{}
	ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(logBuf, "", 0)
	}).ts

	// Surface setup and read failures directly instead of discarding them
	// and reporting a confusing body mismatch later.
	req, err := NewRequest("GET", ts.URL+query, nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	res, err := ts.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	slurp, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("Status = %d; want = %d", got, want)
	}
	if got, want := string(slurp), wantX; got != want {
		t.Errorf("Body = %q; want = %q", got, want)
	}
}
6882
// TestMaxBytesHandler runs testMaxBytesHandler for every combination of
// handler limit and request body size. Each combination is a named subtest
// and runs non-parallel in every mode selected by run.
func TestMaxBytesHandler(t *testing.T) {
	defer afterTest(t)

	sizes := []int64{100, 1_000, 1_000_000}
	for _, maxSize := range sizes {
		for _, requestSize := range sizes {
			name := fmt.Sprintf("max size %d request size %d", maxSize, requestSize)
			t.Run(name, func(t *testing.T) {
				run(t, func(t *testing.T, mode testMode) {
					testMaxBytesHandler(t, mode, maxSize, requestSize)
				}, testNotParallel)
			})
		}
	}
}
6898
// testMaxBytesHandler POSTs a requestSize-byte body through
// MaxBytesHandler(echo, maxSize), where echo copies the request body it could
// read back into the response. It checks that:
//   - the handler reads at most maxSize bytes;
//   - the handler sees a read error exactly when requestSize > maxSize;
//   - the client receives exactly the bytes the handler read.
//
// runTimeSensitiveTest retries the whole exchange with increasing
// RST-avoidance delays when the client sees a transport/read error.
func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		var (
			handlerN   int64
			handlerErr error
		)
		echo := HandlerFunc(func(w ResponseWriter, r *Request) {
			var buf bytes.Buffer
			handlerN, handlerErr = io.Copy(&buf, r.Body)
			io.Copy(w, &buf)
		})

		cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
		defer cst.close()
		ts := cst.ts
		c := ts.Client()

		body := strings.Repeat("a", int(requestSize))

		// Every request body handed to the transport (the first one and any
		// produced via GetBody) does wg.Add(1) and is wrapped in a
		// wgReadCloser, which presumably calls wg.Done on Close. The
		// deferred Wait therefore waits until all of them are closed.
		var wg sync.WaitGroup
		defer wg.Wait()
		getBody := func() (io.ReadCloser, error) {
			wg.Add(1)
			rc := &wgReadCloser{
				Reader: strings.NewReader(body),
				wg:     &wg,
			}
			return rc, nil
		}
		reqBody, _ := getBody() // getBody never returns an error
		req, err := NewRequest("POST", ts.URL, reqBody)
		if err != nil {
			reqBody.Close()
			t.Fatal(err)
		}
		req.ContentLength = int64(len(body))
		req.GetBody = getBody
		req.Header.Set("Content-Type", "text/plain")

		res, err := c.Do(req)
		if err != nil {
			return fmt.Errorf("unexpected connection error: %v", err)
		}
		var buf strings.Builder
		_, err = io.Copy(&buf, res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("unexpected read error: %v", err)
		}

		if handlerN > maxSize {
			t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
		}
		if requestSize > maxSize && handlerErr == nil {
			t.Error("expected error on handler side; got nil")
		}
		if requestSize <= maxSize {
			if handlerErr != nil {
				t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
			}
			if handlerN != requestSize {
				t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
			}
		}
		if buf.Len() != int(handlerN) {
			t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
		}

		return nil
	})
}
6987
6988 func TestEarlyHints(t *testing.T) {
6989 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
6990 h := w.Header()
6991 h.Add("Link", "</style.css>; rel=preload; as=style")
6992 h.Add("Link", "</script.js>; rel=preload; as=script")
6993 w.WriteHeader(StatusEarlyHints)
6994
6995 h.Add("Link", "</foo.js>; rel=preload; as=script")
6996 w.WriteHeader(StatusEarlyHints)
6997
6998 w.Write([]byte("stuff"))
6999 }))
7000
7001 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7002 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7003 if !strings.Contains(got, expected) {
7004 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7005 }
7006 }
7007 func TestProcessing(t *testing.T) {
7008 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7009 w.WriteHeader(StatusProcessing)
7010 w.Write([]byte("stuff"))
7011 }))
7012
7013 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7014 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7015 if !strings.Contains(got, expected) {
7016 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7017 }
7018 }
7019
7020 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7021 func testParseFormCleanup(t *testing.T, mode testMode) {
7022 if mode == http2Mode {
7023 t.Skip("https://go.dev/issue/20253")
7024 }
7025
7026 const maxMemory = 1024
7027 const key = "file"
7028
7029 if runtime.GOOS == "windows" {
7030
7031 t.Skip("https://go.dev/issue/25965")
7032 }
7033
7034 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7035 r.ParseMultipartForm(maxMemory)
7036 f, _, err := r.FormFile(key)
7037 if err != nil {
7038 t.Errorf("r.FormFile(%q) = %v", key, err)
7039 return
7040 }
7041 of, ok := f.(*os.File)
7042 if !ok {
7043 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7044 return
7045 }
7046 w.Write([]byte(of.Name()))
7047 }))
7048
7049 fBuf := new(bytes.Buffer)
7050 mw := multipart.NewWriter(fBuf)
7051 mf, err := mw.CreateFormFile(key, "myfile.txt")
7052 if err != nil {
7053 t.Fatal(err)
7054 }
7055 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7056 t.Fatal(err)
7057 }
7058 if err := mw.Close(); err != nil {
7059 t.Fatal(err)
7060 }
7061 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7062 if err != nil {
7063 t.Fatal(err)
7064 }
7065 req.Header.Set("Content-Type", mw.FormDataContentType())
7066 res, err := cst.c.Do(req)
7067 if err != nil {
7068 t.Fatal(err)
7069 }
7070 defer res.Body.Close()
7071 fname, err := io.ReadAll(res.Body)
7072 if err != nil {
7073 t.Fatal(err)
7074 }
7075 cst.close()
7076 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7077 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7078 }
7079 }
7080
7081 func TestHeadBody(t *testing.T) {
7082 const identityMode = false
7083 const chunkedMode = true
7084 run(t, func(t *testing.T, mode testMode) {
7085 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7086 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7087 })
7088 }
7089
7090 func TestGetBody(t *testing.T) {
7091 const identityMode = false
7092 const chunkedMode = true
7093 run(t, func(t *testing.T, mode testMode) {
7094 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7095 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7096 })
7097 }
7098
7099 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7100 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7101 b, err := io.ReadAll(r.Body)
7102 if err != nil {
7103 t.Errorf("server reading body: %v", err)
7104 return
7105 }
7106 w.Header().Set("X-Request-Body", string(b))
7107 w.Header().Set("Content-Length", "0")
7108 }))
7109 defer cst.close()
7110 for _, reqBody := range []string{
7111 "",
7112 "",
7113 "request_body",
7114 "",
7115 } {
7116 var bodyReader io.Reader
7117 if reqBody != "" {
7118 bodyReader = strings.NewReader(reqBody)
7119 if chunked {
7120 bodyReader = bufio.NewReader(bodyReader)
7121 }
7122 }
7123 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7124 if err != nil {
7125 t.Fatal(err)
7126 }
7127 res, err := cst.c.Do(req)
7128 if err != nil {
7129 t.Fatal(err)
7130 }
7131 res.Body.Close()
7132 if got, want := res.StatusCode, 200; got != want {
7133 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7134 }
7135 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7136 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7137 }
7138 }
7139 }
7140
7141
7142
7143 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7144 func testDisableContentLength(t *testing.T, mode testMode) {
7145 if mode == http2Mode {
7146 t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
7147 }
7148
7149 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7150 w.Header()["Content-Length"] = nil
7151 fmt.Fprintf(w, "OK")
7152 }))
7153
7154 res, err := noCL.c.Get(noCL.ts.URL)
7155 if err != nil {
7156 t.Fatal(err)
7157 }
7158 if got, haveCL := res.Header["Content-Length"]; haveCL {
7159 t.Errorf("Unexpected Content-Length: %q", got)
7160 }
7161 if err := res.Body.Close(); err != nil {
7162 t.Fatal(err)
7163 }
7164
7165 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7166 fmt.Fprintf(w, "OK")
7167 }))
7168
7169 res, err = withCL.c.Get(withCL.ts.URL)
7170 if err != nil {
7171 t.Fatal(err)
7172 }
7173 if got := res.Header.Get("Content-Length"); got != "2" {
7174 t.Errorf("Content-Length: %q; want 2", got)
7175 }
7176 if err := res.Body.Close(); err != nil {
7177 t.Fatal(err)
7178 }
7179 }
7180
7181 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7182 func testErrorContentLength(t *testing.T, mode testMode) {
7183 const errorBody = "an error occurred"
7184 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7185 w.Header().Set("Content-Length", "1000")
7186 Error(w, errorBody, 400)
7187 }))
7188 res, err := cst.c.Get(cst.ts.URL)
7189 if err != nil {
7190 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7191 }
7192 defer res.Body.Close()
7193 body, err := io.ReadAll(res.Body)
7194 if err != nil {
7195 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7196 }
7197 if string(body) != errorBody+"\n" {
7198 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7199 }
7200 }
7201
// TestError checks the headers, status code, and body that Error produces.
// Error must remove a caller-set Content-Length, force a plain-text
// Content-Type and "X-Content-Type-Options: nosniff", and leave unrelated
// headers untouched. It must also write the given status code and the
// message followed by a newline.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	h := w.Header()
	if v, ok := h["Content-Length"]; ok {
		t.Errorf("Content-Length: %q, want not present", v)
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
	// "Other" was set above precisely to verify Error does not clobber it.
	if v := h.Get("Other"); v != "foo" {
		t.Errorf("Other: %q, want %q", v, "foo")
	}
	if w.Code != 432 {
		t.Errorf("status code = %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("body = %q, want %q", got, want)
	}
}
7222
7223 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7224 run(t, testServerReadAfterWriteHeader100Continue)
7225 }
7226 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7227 t.Skip("https://go.dev/issue/67555")
7228 body := []byte("body")
7229 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7230 w.WriteHeader(200)
7231 NewResponseController(w).Flush()
7232 io.ReadAll(r.Body)
7233 w.Write(body)
7234 }), func(tr *Transport) {
7235 tr.ExpectContinueTimeout = 24 * time.Hour
7236 })
7237
7238 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7239 req.Header.Set("Expect", "100-continue")
7240 res, err := cst.c.Do(req)
7241 if err != nil {
7242 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7243 }
7244 defer res.Body.Close()
7245 got, err := io.ReadAll(res.Body)
7246 if err != nil {
7247 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7248 }
7249 if !bytes.Equal(got, body) {
7250 t.Fatalf("response body = %q, want %q", got, body)
7251 }
7252 }
7253
7254 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7255 run(t, testServerReadAfterHandlerDone100Continue)
7256 }
7257 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7258 t.Skip("https://go.dev/issue/67555")
7259 readyc := make(chan struct{})
7260 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7261 go func() {
7262 <-readyc
7263 io.ReadAll(r.Body)
7264 <-readyc
7265 }()
7266 }), func(tr *Transport) {
7267 tr.ExpectContinueTimeout = 24 * time.Hour
7268 })
7269
7270 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7271 req.Header.Set("Expect", "100-continue")
7272 res, err := cst.c.Do(req)
7273 if err != nil {
7274 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7275 }
7276 res.Body.Close()
7277 readyc <- struct{}{}
7278 readyc <- struct{}{}
7279 }
7280
7281 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7282 run(t, testServerReadAfterHandlerAbort100Continue)
7283 }
7284 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7285 t.Skip("https://go.dev/issue/67555")
7286 readyc := make(chan struct{})
7287 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7288 go func() {
7289 <-readyc
7290 io.ReadAll(r.Body)
7291 <-readyc
7292 }()
7293 panic(ErrAbortHandler)
7294 }), func(tr *Transport) {
7295 tr.ExpectContinueTimeout = 24 * time.Hour
7296 })
7297
7298 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7299 req.Header.Set("Expect", "100-continue")
7300 res, err := cst.c.Do(req)
7301 if err == nil {
7302 res.Body.Close()
7303 }
7304 readyc <- struct{}{}
7305 readyc <- struct{}{}
7306 }
7307
7308
7309 func TestServerTLSNextProtos(t *testing.T) {
7310 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7311 }
7312 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7313 CondSkipHTTP2(t)
7314
7315 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7316 if err != nil {
7317 t.Fatal(err)
7318 }
7319 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7320 if err != nil {
7321 t.Fatal(err)
7322 }
7323 certpool := x509.NewCertPool()
7324 certpool.AddCert(leafCert)
7325
7326 protos := new(Protocols)
7327 switch mode {
7328 case https1Mode:
7329 protos.SetHTTP1(true)
7330 case http2Mode:
7331 protos.SetHTTP2(true)
7332 }
7333
7334 wantNextProtos := []string{"http/1.1", "h2", "other"}
7335 nextProtos := slices.Clone(wantNextProtos)
7336
7337
7338 srv := &Server{
7339 TLSConfig: &tls.Config{
7340 Certificates: []tls.Certificate{cert},
7341 NextProtos: nextProtos,
7342 },
7343 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7344 Protocols: protos,
7345 }
7346 tr := &Transport{
7347 TLSClientConfig: &tls.Config{
7348 RootCAs: certpool,
7349 NextProtos: nextProtos,
7350 },
7351 Protocols: protos,
7352 }
7353
7354 listener := newLocalListener(t)
7355 srvc := make(chan error, 1)
7356 go func() {
7357 srvc <- srv.ServeTLS(listener, "", "")
7358 }()
7359 t.Cleanup(func() {
7360 srv.Close()
7361 <-srvc
7362 })
7363
7364 client := &Client{Transport: tr}
7365 resp, err := client.Get("https://" + listener.Addr().String())
7366 if err != nil {
7367 t.Fatal(err)
7368 }
7369 resp.Body.Close()
7370
7371 if !slices.Equal(nextProtos, wantNextProtos) {
7372 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7373 }
7374 }
7375
7376 func TestInvalidChunkedBodies(t *testing.T) {
7377 for _, test := range []struct {
7378 name string
7379 b string
7380 }{{
7381 name: "bare LF in chunk size",
7382 b: "1\na\r\n0\r\n\r\n",
7383 }, {
7384 name: "bare LF at body end",
7385 b: "1\r\na\r\n0\r\n\n",
7386 }} {
7387 t.Run(test.name, func(t *testing.T) {
7388 reqc := make(chan error)
7389 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7390 got, err := io.ReadAll(r.Body)
7391 if err == nil {
7392 t.Logf("read body: %q", got)
7393 }
7394 reqc <- err
7395 })).ts
7396
7397 serverURL, err := url.Parse(ts.URL)
7398 if err != nil {
7399 t.Fatal(err)
7400 }
7401
7402 conn, err := net.Dial("tcp", serverURL.Host)
7403 if err != nil {
7404 t.Fatal(err)
7405 }
7406
7407 if _, err := conn.Write([]byte(
7408 "POST / HTTP/1.1\r\n" +
7409 "Host: localhost\r\n" +
7410 "Transfer-Encoding: chunked\r\n" +
7411 "Connection: close\r\n" +
7412 "\r\n" +
7413 test.b)); err != nil {
7414 t.Fatal(err)
7415 }
7416 conn.(*net.TCPConn).CloseWrite()
7417
7418 if err := <-reqc; err == nil {
7419 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7420 }
7421 })
7422 }
7423 }
7424
View as plain text