Source file src/net/http/httputil/reverseproxy_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Reverse proxy tests.
     6  
     7  package httputil
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"net/http/httptrace"
    21  	"net/http/internal/ascii"
    22  	"net/textproto"
    23  	"net/url"
    24  	"os"
    25  	"reflect"
    26  	"runtime"
    27  	"slices"
    28  	"strconv"
    29  	"strings"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  )
    34  
    35  const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
    36  
    37  func init() {
    38  	inOurTests = true
    39  	hopHeaders = append(hopHeaders, fakeHopHeader)
    40  }
    41  
// TestReverseProxy exercises a basic GET through NewSingleHostReverseProxy.
//
// The request side checks that:
//   - hop-by-hop request headers (Connection, Upgrade, Proxy-Connection) are
//     not forwarded, while the "trailers" token of Te is;
//   - the client's Host value reaches the backend.
//
// The response side checks that:
//   - end-to-end headers, multi-valued headers, cookies and trailers (both
//     announced and unannounced) reach the client;
//   - the hop-by-hop fakeHopHeader does not.
//
// Finally, a backend that hangs up without responding must yield
// 502 Bad Gateway.
func TestReverseProxy(t *testing.T) {
	const backendResponse = "I am the backend"
	const backendStatus = 404
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// mode=hangup: drop the connection without writing any response.
		if r.Method == "GET" && r.FormValue("mode") == "hangup" {
			c, _, _ := w.(http.Hijacker).Hijack()
			c.Close()
			return
		}
		if len(r.TransferEncoding) > 0 {
			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
		}
		if r.Header.Get("X-Forwarded-For") == "" {
			t.Errorf("didn't get X-Forwarded-For header")
		}
		if c := r.Header.Get("Connection"); c != "" {
			t.Errorf("handler got Connection header value %q", c)
		}
		// The client sent Te "foo" and "bar, trailers"; only "trailers" may
		// be forwarded.
		if c := r.Header.Get("Te"); c != "trailers" {
			t.Errorf("handler got Te header value %q; want 'trailers'", c)
		}
		if c := r.Header.Get("Upgrade"); c != "" {
			t.Errorf("handler got Upgrade header value %q", c)
		}
		if c := r.Header.Get("Proxy-Connection"); c != "" {
			t.Errorf("handler got Proxy-Connection header value %q", c)
		}
		if g, e := r.Host, "some-name"; g != e {
			t.Errorf("backend got Host header %q, want %q", g, e)
		}
		// "Trailers" (plural) is an ordinary header and must pass through.
		w.Header().Set("Trailers", "not a special header field name")
		w.Header().Set("Trailer", "X-Trailer")
		w.Header().Set("X-Foo", "bar")
		w.Header().Set("Upgrade", "foo")
		w.Header().Set(fakeHopHeader, "foo")
		w.Header().Add("X-Multi-Value", "foo")
		w.Header().Add("X-Multi-Value", "bar")
		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
		w.WriteHeader(backendStatus)
		w.Write([]byte(backendResponse))
		// Headers set after the body is written are sent as trailers:
		// X-Trailer was announced via the Trailer header above, while the
		// TrailerPrefix form adds a trailer that was never announced.
		w.Header().Set("X-Trailer", "trailer_value")
		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
	}))
	defer backend.Close()
	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatal(err)
	}
	proxyHandler := NewSingleHostReverseProxy(backendURL)
	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	frontend := httptest.NewServer(proxyHandler)
	defer frontend.Close()
	frontendClient := frontend.Client()

	// Build a request carrying several hop-by-hop headers; Connection also
	// names TE as hop-by-hop.
	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
	getReq.Host = "some-name"
	getReq.Header.Set("Connection", "close, TE")
	getReq.Header.Add("Te", "foo")
	getReq.Header.Add("Te", "bar, trailers")
	getReq.Header.Set("Proxy-Connection", "should be deleted")
	getReq.Header.Set("Upgrade", "foo")
	getReq.Close = true
	res, err := frontendClient.Do(getReq)
	if err != nil {
		t.Fatalf("Get: %v", err)
	}
	if g, e := res.StatusCode, backendStatus; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
		t.Errorf("got X-Foo %q; expected %q", g, e)
	}
	if c := res.Header.Get(fakeHopHeader); c != "" {
		t.Errorf("got %s header value %q", fakeHopHeader, c)
	}
	if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
		t.Errorf("header Trailers = %q; want %q", g, e)
	}
	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
	}
	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
		t.Fatalf("got %d SetCookies, want %d", g, e)
	}
	// Before the body is read, Trailer holds only the announced key, with no
	// value yet.
	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
	}
	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
		t.Errorf("unexpected cookie %q", cookie.Name)
	}
	bodyBytes, _ := io.ReadAll(res.Body)
	if g, e := string(bodyBytes), backendResponse; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	// After the body has been consumed, both trailers carry their values.
	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
	}
	if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
		t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
	}
	res.Body.Close()

	// Test that a backend failing to be reached or one which doesn't return
	// a response results in a StatusBadGateway.
	getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
	getReq.Close = true
	res, err = frontendClient.Do(getReq)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if res.StatusCode != http.StatusBadGateway {
		t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
	}

}
   158  
   159  // Issue 16875: remove any proxied headers mentioned in the "Connection"
   160  // header value.
   161  func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
   162  	const fakeConnectionToken = "X-Fake-Connection-Token"
   163  	const backendResponse = "I am the backend"
   164  
   165  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   166  	// in the Request's Connection header.
   167  	const someConnHeader = "X-Some-Conn-Header"
   168  
   169  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   170  		if c := r.Header.Get("Connection"); c != "" {
   171  			t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   172  		}
   173  		if c := r.Header.Get(fakeConnectionToken); c != "" {
   174  			t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   175  		}
   176  		if c := r.Header.Get(someConnHeader); c != "" {
   177  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   178  		}
   179  		w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
   180  		w.Header().Add("Connection", someConnHeader)
   181  		w.Header().Set(someConnHeader, "should be deleted")
   182  		w.Header().Set(fakeConnectionToken, "should be deleted")
   183  		io.WriteString(w, backendResponse)
   184  	}))
   185  	defer backend.Close()
   186  	backendURL, err := url.Parse(backend.URL)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   191  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   192  		proxyHandler.ServeHTTP(w, r)
   193  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   194  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   195  		}
   196  		if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
   197  			t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
   198  		}
   199  		c := r.Header["Connection"]
   200  		var cf []string
   201  		for _, f := range c {
   202  			for sf := range strings.SplitSeq(f, ",") {
   203  				if sf = strings.TrimSpace(sf); sf != "" {
   204  					cf = append(cf, sf)
   205  				}
   206  			}
   207  		}
   208  		slices.Sort(cf)
   209  		expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
   210  		slices.Sort(expectedValues)
   211  		if !slices.Equal(cf, expectedValues) {
   212  			t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
   213  		}
   214  	}))
   215  	defer frontend.Close()
   216  
   217  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   218  	getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
   219  	getReq.Header.Add("Connection", someConnHeader)
   220  	getReq.Header.Set(someConnHeader, "should be deleted")
   221  	getReq.Header.Set(fakeConnectionToken, "should be deleted")
   222  	res, err := frontend.Client().Do(getReq)
   223  	if err != nil {
   224  		t.Fatalf("Get: %v", err)
   225  	}
   226  	defer res.Body.Close()
   227  	bodyBytes, err := io.ReadAll(res.Body)
   228  	if err != nil {
   229  		t.Fatalf("reading body: %v", err)
   230  	}
   231  	if got, want := string(bodyBytes), backendResponse; got != want {
   232  		t.Errorf("got body %q; want %q", got, want)
   233  	}
   234  	if c := res.Header.Get("Connection"); c != "" {
   235  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   236  	}
   237  	if c := res.Header.Get(someConnHeader); c != "" {
   238  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   239  	}
   240  	if c := res.Header.Get(fakeConnectionToken); c != "" {
   241  		t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
   242  	}
   243  }
   244  
   245  func TestReverseProxyStripEmptyConnection(t *testing.T) {
   246  	// See Issue 46313.
   247  	const backendResponse = "I am the backend"
   248  
   249  	// someConnHeader is some arbitrary header to be declared as a hop-by-hop header
   250  	// in the Request's Connection header.
   251  	const someConnHeader = "X-Some-Conn-Header"
   252  
   253  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   254  		if c := r.Header.Values("Connection"); len(c) != 0 {
   255  			t.Errorf("handler got header %q = %v; want empty", "Connection", c)
   256  		}
   257  		if c := r.Header.Get(someConnHeader); c != "" {
   258  			t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   259  		}
   260  		w.Header().Add("Connection", "")
   261  		w.Header().Add("Connection", someConnHeader)
   262  		w.Header().Set(someConnHeader, "should be deleted")
   263  		io.WriteString(w, backendResponse)
   264  	}))
   265  	defer backend.Close()
   266  	backendURL, err := url.Parse(backend.URL)
   267  	if err != nil {
   268  		t.Fatal(err)
   269  	}
   270  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   271  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   272  		proxyHandler.ServeHTTP(w, r)
   273  		if c := r.Header.Get(someConnHeader); c != "should be deleted" {
   274  			t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
   275  		}
   276  	}))
   277  	defer frontend.Close()
   278  
   279  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   280  	getReq.Header.Add("Connection", "")
   281  	getReq.Header.Add("Connection", someConnHeader)
   282  	getReq.Header.Set(someConnHeader, "should be deleted")
   283  	res, err := frontend.Client().Do(getReq)
   284  	if err != nil {
   285  		t.Fatalf("Get: %v", err)
   286  	}
   287  	defer res.Body.Close()
   288  	bodyBytes, err := io.ReadAll(res.Body)
   289  	if err != nil {
   290  		t.Fatalf("reading body: %v", err)
   291  	}
   292  	if got, want := string(bodyBytes), backendResponse; got != want {
   293  		t.Errorf("got body %q; want %q", got, want)
   294  	}
   295  	if c := res.Header.Get("Connection"); c != "" {
   296  		t.Errorf("handler got header %q = %q; want empty", "Connection", c)
   297  	}
   298  	if c := res.Header.Get(someConnHeader); c != "" {
   299  		t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
   300  	}
   301  }
   302  
   303  func TestXForwardedFor(t *testing.T) {
   304  	const prevForwardedFor = "client ip"
   305  	const backendResponse = "I am the backend"
   306  	const backendStatus = 404
   307  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   308  		if r.Header.Get("X-Forwarded-For") == "" {
   309  			t.Errorf("didn't get X-Forwarded-For header")
   310  		}
   311  		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
   312  			t.Errorf("X-Forwarded-For didn't contain prior data")
   313  		}
   314  		w.WriteHeader(backendStatus)
   315  		w.Write([]byte(backendResponse))
   316  	}))
   317  	defer backend.Close()
   318  	backendURL, err := url.Parse(backend.URL)
   319  	if err != nil {
   320  		t.Fatal(err)
   321  	}
   322  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   323  	frontend := httptest.NewServer(proxyHandler)
   324  	defer frontend.Close()
   325  
   326  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   327  	getReq.Header.Set("Connection", "close")
   328  	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
   329  	getReq.Close = true
   330  	res, err := frontend.Client().Do(getReq)
   331  	if err != nil {
   332  		t.Fatalf("Get: %v", err)
   333  	}
   334  	defer res.Body.Close()
   335  	if g, e := res.StatusCode, backendStatus; g != e {
   336  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   337  	}
   338  	bodyBytes, _ := io.ReadAll(res.Body)
   339  	if g, e := string(bodyBytes), backendResponse; g != e {
   340  		t.Errorf("got body %q; expected %q", g, e)
   341  	}
   342  }
   343  
   344  // Issue 38079: don't append to X-Forwarded-For if it's present but nil
   345  func TestXForwardedFor_Omit(t *testing.T) {
   346  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   347  		if v := r.Header.Get("X-Forwarded-For"); v != "" {
   348  			t.Errorf("got X-Forwarded-For header: %q", v)
   349  		}
   350  		w.Write([]byte("hi"))
   351  	}))
   352  	defer backend.Close()
   353  	backendURL, err := url.Parse(backend.URL)
   354  	if err != nil {
   355  		t.Fatal(err)
   356  	}
   357  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   358  	frontend := httptest.NewServer(proxyHandler)
   359  	defer frontend.Close()
   360  
   361  	oldDirector := proxyHandler.Director
   362  	proxyHandler.Director = func(r *http.Request) {
   363  		r.Header["X-Forwarded-For"] = nil
   364  		oldDirector(r)
   365  	}
   366  
   367  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   368  	getReq.Host = "some-name"
   369  	getReq.Close = true
   370  	res, err := frontend.Client().Do(getReq)
   371  	if err != nil {
   372  		t.Fatalf("Get: %v", err)
   373  	}
   374  	res.Body.Close()
   375  }
   376  
   377  func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
   378  	headers := []string{
   379  		"Forwarded",
   380  		"X-Forwarded-For",
   381  		"X-Forwarded-Host",
   382  		"X-Forwarded-Proto",
   383  	}
   384  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   385  		for _, h := range headers {
   386  			if v := r.Header.Get(h); v != "" {
   387  				t.Errorf("got %v header: %q", h, v)
   388  			}
   389  		}
   390  	}))
   391  	defer backend.Close()
   392  	backendURL, err := url.Parse(backend.URL)
   393  	if err != nil {
   394  		t.Fatal(err)
   395  	}
   396  	proxyHandler := &ReverseProxy{
   397  		Rewrite: func(r *ProxyRequest) {
   398  			r.SetURL(backendURL)
   399  		},
   400  	}
   401  	frontend := httptest.NewServer(proxyHandler)
   402  	defer frontend.Close()
   403  
   404  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   405  	getReq.Host = "some-name"
   406  	getReq.Close = true
   407  	for _, h := range headers {
   408  		getReq.Header.Set(h, "x")
   409  	}
   410  	res, err := frontend.Client().Do(getReq)
   411  	if err != nil {
   412  		t.Fatalf("Get: %v", err)
   413  	}
   414  	res.Body.Close()
   415  }
   416  
   417  var proxyQueryTests = []struct {
   418  	baseSuffix string // suffix to add to backend URL
   419  	reqSuffix  string // suffix to add to frontend's request URL
   420  	want       string // what backend should see for final request URL (without ?)
   421  }{
   422  	{"", "", ""},
   423  	{"?sta=tic", "?us=er", "sta=tic&us=er"},
   424  	{"", "?us=er", "us=er"},
   425  	{"?sta=tic", "", "sta=tic"},
   426  }
   427  
   428  func TestReverseProxyQuery(t *testing.T) {
   429  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   430  		w.Header().Set("X-Got-Query", r.URL.RawQuery)
   431  		w.Write([]byte("hi"))
   432  	}))
   433  	defer backend.Close()
   434  
   435  	for i, tt := range proxyQueryTests {
   436  		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
   437  		if err != nil {
   438  			t.Fatal(err)
   439  		}
   440  		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
   441  		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
   442  		req.Close = true
   443  		res, err := frontend.Client().Do(req)
   444  		if err != nil {
   445  			t.Fatalf("%d. Get: %v", i, err)
   446  		}
   447  		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
   448  			t.Errorf("%d. got query %q; expected %q", i, g, e)
   449  		}
   450  		res.Body.Close()
   451  		frontend.Close()
   452  	}
   453  }
   454  
   455  func TestReverseProxyFlushInterval(t *testing.T) {
   456  	const expected = "hi"
   457  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   458  		w.Write([]byte(expected))
   459  	}))
   460  	defer backend.Close()
   461  
   462  	backendURL, err := url.Parse(backend.URL)
   463  	if err != nil {
   464  		t.Fatal(err)
   465  	}
   466  
   467  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   468  	proxyHandler.FlushInterval = time.Microsecond
   469  
   470  	frontend := httptest.NewServer(proxyHandler)
   471  	defer frontend.Close()
   472  
   473  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   474  	req.Close = true
   475  	res, err := frontend.Client().Do(req)
   476  	if err != nil {
   477  		t.Fatalf("Get: %v", err)
   478  	}
   479  	defer res.Body.Close()
   480  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   481  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   482  	}
   483  }
   484  
   485  type mockFlusher struct {
   486  	http.ResponseWriter
   487  	flushed bool
   488  }
   489  
   490  func (m *mockFlusher) Flush() {
   491  	m.flushed = true
   492  }
   493  
   494  type wrappedRW struct {
   495  	http.ResponseWriter
   496  }
   497  
   498  func (w *wrappedRW) Unwrap() http.ResponseWriter {
   499  	return w.ResponseWriter
   500  }
   501  
   502  func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
   503  	const expected = "hi"
   504  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   505  		w.Write([]byte(expected))
   506  	}))
   507  	defer backend.Close()
   508  
   509  	backendURL, err := url.Parse(backend.URL)
   510  	if err != nil {
   511  		t.Fatal(err)
   512  	}
   513  
   514  	mf := &mockFlusher{}
   515  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   516  	proxyHandler.FlushInterval = -1 // flush immediately
   517  	proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   518  		mf.ResponseWriter = w
   519  		w = &wrappedRW{mf}
   520  		proxyHandler.ServeHTTP(w, r)
   521  	})
   522  
   523  	frontend := httptest.NewServer(proxyWithMiddleware)
   524  	defer frontend.Close()
   525  
   526  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   527  	req.Close = true
   528  	res, err := frontend.Client().Do(req)
   529  	if err != nil {
   530  		t.Fatalf("Get: %v", err)
   531  	}
   532  	defer res.Body.Close()
   533  	if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
   534  		t.Errorf("got body %q; expected %q", bodyBytes, expected)
   535  	}
   536  	if !mf.flushed {
   537  		t.Errorf("response writer was not flushed")
   538  	}
   539  }
   540  
   541  func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
   542  	const expected = "hi"
   543  	stopCh := make(chan struct{})
   544  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   545  		w.Header().Add("MyHeader", expected)
   546  		w.WriteHeader(200)
   547  		w.(http.Flusher).Flush()
   548  		<-stopCh
   549  	}))
   550  	defer backend.Close()
   551  	defer close(stopCh)
   552  
   553  	backendURL, err := url.Parse(backend.URL)
   554  	if err != nil {
   555  		t.Fatal(err)
   556  	}
   557  
   558  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   559  	proxyHandler.FlushInterval = time.Microsecond
   560  
   561  	frontend := httptest.NewServer(proxyHandler)
   562  	defer frontend.Close()
   563  
   564  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   565  	req.Close = true
   566  
   567  	ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
   568  	defer cancel()
   569  	req = req.WithContext(ctx)
   570  
   571  	res, err := frontend.Client().Do(req)
   572  	if err != nil {
   573  		t.Fatalf("Get: %v", err)
   574  	}
   575  	defer res.Body.Close()
   576  
   577  	if res.Header.Get("MyHeader") != expected {
   578  		t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
   579  	}
   580  }
   581  
   582  func TestReverseProxyCancellation(t *testing.T) {
   583  	const backendResponse = "I am the backend"
   584  
   585  	reqInFlight := make(chan struct{})
   586  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   587  		close(reqInFlight) // cause the client to cancel its request
   588  
   589  		select {
   590  		case <-time.After(10 * time.Second):
   591  			// Note: this should only happen in broken implementations, and the
   592  			// closenotify case should be instantaneous.
   593  			t.Error("Handler never saw CloseNotify")
   594  			return
   595  		case <-w.(http.CloseNotifier).CloseNotify():
   596  		}
   597  
   598  		w.WriteHeader(http.StatusOK)
   599  		w.Write([]byte(backendResponse))
   600  	}))
   601  
   602  	defer backend.Close()
   603  
   604  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
   605  
   606  	backendURL, err := url.Parse(backend.URL)
   607  	if err != nil {
   608  		t.Fatal(err)
   609  	}
   610  
   611  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   612  
   613  	// Discards errors of the form:
   614  	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
   615  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
   616  
   617  	frontend := httptest.NewServer(proxyHandler)
   618  	defer frontend.Close()
   619  	frontendClient := frontend.Client()
   620  
   621  	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   622  	go func() {
   623  		<-reqInFlight
   624  		frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
   625  	}()
   626  	res, err := frontendClient.Do(getReq)
   627  	if res != nil {
   628  		t.Errorf("got response %v; want nil", res.Status)
   629  	}
   630  	if err == nil {
   631  		// This should be an error like:
   632  		// Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
   633  		//    use of closed network connection
   634  		t.Error("Server.Client().Do() returned nil error; want non-nil error")
   635  	}
   636  }
   637  
   638  func req(t *testing.T, v string) *http.Request {
   639  	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
   640  	if err != nil {
   641  		t.Fatal(err)
   642  	}
   643  	return req
   644  }
   645  
   646  // Issue 12344
   647  func TestNilBody(t *testing.T) {
   648  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   649  		w.Write([]byte("hi"))
   650  	}))
   651  	defer backend.Close()
   652  
   653  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   654  		backURL, _ := url.Parse(backend.URL)
   655  		rp := NewSingleHostReverseProxy(backURL)
   656  		r := req(t, "GET / HTTP/1.0\r\n\r\n")
   657  		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
   658  		rp.ServeHTTP(w, r)
   659  	}))
   660  	defer frontend.Close()
   661  
   662  	res, err := http.Get(frontend.URL)
   663  	if err != nil {
   664  		t.Fatal(err)
   665  	}
   666  	defer res.Body.Close()
   667  	slurp, err := io.ReadAll(res.Body)
   668  	if err != nil {
   669  		t.Fatal(err)
   670  	}
   671  	if string(slurp) != "hi" {
   672  		t.Errorf("Got %q; want %q", slurp, "hi")
   673  	}
   674  }
   675  
   676  // Issue 15524
   677  func TestUserAgentHeader(t *testing.T) {
   678  	var gotUA string
   679  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   680  		gotUA = r.Header.Get("User-Agent")
   681  	}))
   682  	defer backend.Close()
   683  	backendURL, err := url.Parse(backend.URL)
   684  	if err != nil {
   685  		t.Fatal(err)
   686  	}
   687  
   688  	proxyHandler := new(ReverseProxy)
   689  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   690  	proxyHandler.Director = func(req *http.Request) {
   691  		req.URL = backendURL
   692  	}
   693  	frontend := httptest.NewServer(proxyHandler)
   694  	defer frontend.Close()
   695  	frontendClient := frontend.Client()
   696  
   697  	for _, sentUA := range []string{"explicit UA", ""} {
   698  		getReq, _ := http.NewRequest("GET", frontend.URL, nil)
   699  		getReq.Header.Set("User-Agent", sentUA)
   700  		getReq.Close = true
   701  		res, err := frontendClient.Do(getReq)
   702  		if err != nil {
   703  			t.Fatalf("Get: %v", err)
   704  		}
   705  		res.Body.Close()
   706  		if got, want := gotUA, sentUA; got != want {
   707  			t.Errorf("got forwarded User-Agent %q, want %q", got, want)
   708  		}
   709  	}
   710  }
   711  
   712  type bufferPool struct {
   713  	get func() []byte
   714  	put func([]byte)
   715  }
   716  
   717  func (bp bufferPool) Get() []byte  { return bp.get() }
   718  func (bp bufferPool) Put(v []byte) { bp.put(v) }
   719  
   720  func TestReverseProxyGetPutBuffer(t *testing.T) {
   721  	const msg = "hi"
   722  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   723  		io.WriteString(w, msg)
   724  	}))
   725  	defer backend.Close()
   726  
   727  	backendURL, err := url.Parse(backend.URL)
   728  	if err != nil {
   729  		t.Fatal(err)
   730  	}
   731  
   732  	var (
   733  		mu  sync.Mutex
   734  		log []string
   735  	)
   736  	addLog := func(event string) {
   737  		mu.Lock()
   738  		defer mu.Unlock()
   739  		log = append(log, event)
   740  	}
   741  	rp := NewSingleHostReverseProxy(backendURL)
   742  	const size = 1234
   743  	rp.BufferPool = bufferPool{
   744  		get: func() []byte {
   745  			addLog("getBuf")
   746  			return make([]byte, size)
   747  		},
   748  		put: func(p []byte) {
   749  			addLog("putBuf-" + strconv.Itoa(len(p)))
   750  		},
   751  	}
   752  	frontend := httptest.NewServer(rp)
   753  	defer frontend.Close()
   754  
   755  	req, _ := http.NewRequest("GET", frontend.URL, nil)
   756  	req.Close = true
   757  	res, err := frontend.Client().Do(req)
   758  	if err != nil {
   759  		t.Fatalf("Get: %v", err)
   760  	}
   761  	slurp, err := io.ReadAll(res.Body)
   762  	res.Body.Close()
   763  	if err != nil {
   764  		t.Fatalf("reading body: %v", err)
   765  	}
   766  	if string(slurp) != msg {
   767  		t.Errorf("msg = %q; want %q", slurp, msg)
   768  	}
   769  	wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
   770  	mu.Lock()
   771  	defer mu.Unlock()
   772  	if !slices.Equal(log, wantLog) {
   773  		t.Errorf("Log events = %q; want %q", log, wantLog)
   774  	}
   775  }
   776  
   777  func TestReverseProxy_Post(t *testing.T) {
   778  	const backendResponse = "I am the backend"
   779  	const backendStatus = 200
   780  	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
   781  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   782  		slurp, err := io.ReadAll(r.Body)
   783  		if err != nil {
   784  			t.Errorf("Backend body read = %v", err)
   785  		}
   786  		if len(slurp) != len(requestBody) {
   787  			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
   788  		}
   789  		if !bytes.Equal(slurp, requestBody) {
   790  			t.Error("Backend read wrong request body.") // 1MB; omitting details
   791  		}
   792  		w.Write([]byte(backendResponse))
   793  	}))
   794  	defer backend.Close()
   795  	backendURL, err := url.Parse(backend.URL)
   796  	if err != nil {
   797  		t.Fatal(err)
   798  	}
   799  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   800  	frontend := httptest.NewServer(proxyHandler)
   801  	defer frontend.Close()
   802  
   803  	postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
   804  	res, err := frontend.Client().Do(postReq)
   805  	if err != nil {
   806  		t.Fatalf("Do: %v", err)
   807  	}
   808  	defer res.Body.Close()
   809  	if g, e := res.StatusCode, backendStatus; g != e {
   810  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
   811  	}
   812  	bodyBytes, _ := io.ReadAll(res.Body)
   813  	if g, e := string(bodyBytes), backendResponse; g != e {
   814  		t.Errorf("got body %q; expected %q", g, e)
   815  	}
   816  }
   817  
   818  type RoundTripperFunc func(*http.Request) (*http.Response, error)
   819  
   820  func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   821  	return fn(req)
   822  }
   823  
   824  // Issue 16036: send a Request with a nil Body when possible
   825  func TestReverseProxy_NilBody(t *testing.T) {
   826  	backendURL, _ := url.Parse("http://fake.tld/")
   827  	proxyHandler := NewSingleHostReverseProxy(backendURL)
   828  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   829  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   830  		if req.Body != nil {
   831  			t.Error("Body != nil; want a nil Body")
   832  		}
   833  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   834  	})
   835  	frontend := httptest.NewServer(proxyHandler)
   836  	defer frontend.Close()
   837  
   838  	res, err := frontend.Client().Get(frontend.URL)
   839  	if err != nil {
   840  		t.Fatal(err)
   841  	}
   842  	defer res.Body.Close()
   843  	if res.StatusCode != 502 {
   844  		t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
   845  	}
   846  }
   847  
   848  // Issue 33142: always allocate the request headers
   849  func TestReverseProxy_AllocatedHeader(t *testing.T) {
   850  	proxyHandler := new(ReverseProxy)
   851  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   852  	proxyHandler.Director = func(*http.Request) {}     // noop
   853  	proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
   854  		if req.Header == nil {
   855  			t.Error("Header == nil; want a non-nil Header")
   856  		}
   857  		return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
   858  	})
   859  
   860  	proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
   861  		Method:     "GET",
   862  		URL:        &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
   863  		Proto:      "HTTP/1.0",
   864  		ProtoMajor: 1,
   865  	})
   866  }
   867  
   868  // Issue 14237. Test ModifyResponse and that an error from it
   869  // causes the proxy to return StatusBadGateway, or StatusOK otherwise.
   870  func TestReverseProxyModifyResponse(t *testing.T) {
   871  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   872  		w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
   873  	}))
   874  	defer backendServer.Close()
   875  
   876  	rpURL, _ := url.Parse(backendServer.URL)
   877  	rproxy := NewSingleHostReverseProxy(rpURL)
   878  	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   879  	rproxy.ModifyResponse = func(resp *http.Response) error {
   880  		if resp.Header.Get("X-Hit-Mod") != "true" {
   881  			return fmt.Errorf("tried to by-pass proxy")
   882  		}
   883  		return nil
   884  	}
   885  
   886  	frontendProxy := httptest.NewServer(rproxy)
   887  	defer frontendProxy.Close()
   888  
   889  	tests := []struct {
   890  		url      string
   891  		wantCode int
   892  	}{
   893  		{frontendProxy.URL + "/mod", http.StatusOK},
   894  		{frontendProxy.URL + "/schedule", http.StatusBadGateway},
   895  	}
   896  
   897  	for i, tt := range tests {
   898  		resp, err := http.Get(tt.url)
   899  		if err != nil {
   900  			t.Fatalf("failed to reach proxy: %v", err)
   901  		}
   902  		if g, e := resp.StatusCode, tt.wantCode; g != e {
   903  			t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
   904  		}
   905  		resp.Body.Close()
   906  	}
   907  }
   908  
   909  type failingRoundTripper struct{}
   910  
   911  func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   912  	return nil, errors.New("some error")
   913  }
   914  
   915  type staticResponseRoundTripper struct{ res *http.Response }
   916  
   917  func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
   918  	return rt.res, nil
   919  }
   920  
   921  func TestReverseProxyErrorHandler(t *testing.T) {
   922  	tests := []struct {
   923  		name           string
   924  		wantCode       int
   925  		errorHandler   func(http.ResponseWriter, *http.Request, error)
   926  		transport      http.RoundTripper // defaults to failingRoundTripper
   927  		modifyResponse func(*http.Response) error
   928  	}{
   929  		{
   930  			name:     "default",
   931  			wantCode: http.StatusBadGateway,
   932  		},
   933  		{
   934  			name:         "errorhandler",
   935  			wantCode:     http.StatusTeapot,
   936  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   937  		},
   938  		{
   939  			name: "modifyresponse_noerr",
   940  			transport: staticResponseRoundTripper{
   941  				&http.Response{StatusCode: 345, Body: http.NoBody},
   942  			},
   943  			modifyResponse: func(res *http.Response) error {
   944  				res.StatusCode++
   945  				return nil
   946  			},
   947  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   948  			wantCode:     346,
   949  		},
   950  		{
   951  			name: "modifyresponse_err",
   952  			transport: staticResponseRoundTripper{
   953  				&http.Response{StatusCode: 345, Body: http.NoBody},
   954  			},
   955  			modifyResponse: func(res *http.Response) error {
   956  				res.StatusCode++
   957  				return errors.New("some error to trigger errorHandler")
   958  			},
   959  			errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
   960  			wantCode:     http.StatusTeapot,
   961  		},
   962  	}
   963  
   964  	for _, tt := range tests {
   965  		t.Run(tt.name, func(t *testing.T) {
   966  			target := &url.URL{
   967  				Scheme: "http",
   968  				Host:   "dummy.tld",
   969  				Path:   "/",
   970  			}
   971  			rproxy := NewSingleHostReverseProxy(target)
   972  			rproxy.Transport = tt.transport
   973  			rproxy.ModifyResponse = tt.modifyResponse
   974  			if rproxy.Transport == nil {
   975  				rproxy.Transport = failingRoundTripper{}
   976  			}
   977  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
   978  			if tt.errorHandler != nil {
   979  				rproxy.ErrorHandler = tt.errorHandler
   980  			}
   981  			frontendProxy := httptest.NewServer(rproxy)
   982  			defer frontendProxy.Close()
   983  
   984  			resp, err := http.Get(frontendProxy.URL + "/test")
   985  			if err != nil {
   986  				t.Fatalf("failed to reach proxy: %v", err)
   987  			}
   988  			if g, e := resp.StatusCode, tt.wantCode; g != e {
   989  				t.Errorf("got res.StatusCode %d; expected %d", g, e)
   990  			}
   991  			resp.Body.Close()
   992  		})
   993  	}
   994  }
   995  
   996  // Issue 16659: log errors from short read
   997  func TestReverseProxy_CopyBuffer(t *testing.T) {
   998  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   999  		out := "this call was relayed by the reverse proxy"
  1000  		// Coerce a wrong content length to induce io.UnexpectedEOF
  1001  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1002  		fmt.Fprintln(w, out)
  1003  	}))
  1004  	defer backendServer.Close()
  1005  
  1006  	rpURL, err := url.Parse(backendServer.URL)
  1007  	if err != nil {
  1008  		t.Fatal(err)
  1009  	}
  1010  
  1011  	var proxyLog bytes.Buffer
  1012  	rproxy := NewSingleHostReverseProxy(rpURL)
  1013  	rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
  1014  	donec := make(chan bool, 1)
  1015  	frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1016  		defer func() { donec <- true }()
  1017  		rproxy.ServeHTTP(w, r)
  1018  	}))
  1019  	defer frontendProxy.Close()
  1020  
  1021  	if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
  1022  		t.Fatalf("want non-nil error")
  1023  	}
  1024  	// The race detector complains about the proxyLog usage in logf in copyBuffer
  1025  	// and our usage below with proxyLog.Bytes() so we're explicitly using a
  1026  	// channel to ensure that the ReverseProxy's ServeHTTP is done before we
  1027  	// continue after Get.
  1028  	<-donec
  1029  
  1030  	expected := []string{
  1031  		"EOF",
  1032  		"read",
  1033  	}
  1034  	for _, phrase := range expected {
  1035  		if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
  1036  			t.Errorf("expected log to contain phrase %q", phrase)
  1037  		}
  1038  	}
  1039  }
  1040  
  1041  type staticTransport struct {
  1042  	res *http.Response
  1043  }
  1044  
  1045  func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  1046  	return t.res, nil
  1047  }
  1048  
  1049  func BenchmarkServeHTTP(b *testing.B) {
  1050  	res := &http.Response{
  1051  		StatusCode: 200,
  1052  		Body:       io.NopCloser(strings.NewReader("")),
  1053  	}
  1054  	proxy := &ReverseProxy{
  1055  		Director:  func(*http.Request) {},
  1056  		Transport: &staticTransport{res},
  1057  	}
  1058  
  1059  	w := httptest.NewRecorder()
  1060  	r := httptest.NewRequest("GET", "/", nil)
  1061  
  1062  	b.ReportAllocs()
  1063  	for i := 0; i < b.N; i++ {
  1064  		proxy.ServeHTTP(w, r)
  1065  	}
  1066  }
  1067  
  1068  func TestServeHTTPDeepCopy(t *testing.T) {
  1069  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1070  		w.Write([]byte("Hello Gopher!"))
  1071  	}))
  1072  	defer backend.Close()
  1073  	backendURL, err := url.Parse(backend.URL)
  1074  	if err != nil {
  1075  		t.Fatal(err)
  1076  	}
  1077  
  1078  	type result struct {
  1079  		before, after string
  1080  	}
  1081  
  1082  	resultChan := make(chan result, 1)
  1083  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1084  	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1085  		before := r.URL.String()
  1086  		proxyHandler.ServeHTTP(w, r)
  1087  		after := r.URL.String()
  1088  		resultChan <- result{before: before, after: after}
  1089  	}))
  1090  	defer frontend.Close()
  1091  
  1092  	want := result{before: "/", after: "/"}
  1093  
  1094  	res, err := frontend.Client().Get(frontend.URL)
  1095  	if err != nil {
  1096  		t.Fatalf("Do: %v", err)
  1097  	}
  1098  	res.Body.Close()
  1099  
  1100  	got := <-resultChan
  1101  	if got != want {
  1102  		t.Errorf("got = %+v; want = %+v", got, want)
  1103  	}
  1104  }
  1105  
  1106  // Issue 18327: verify we always do a deep copy of the Request.Header map
  1107  // before any mutations.
  1108  func TestClonesRequestHeaders(t *testing.T) {
  1109  	log.SetOutput(io.Discard)
  1110  	defer log.SetOutput(os.Stderr)
  1111  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1112  	req.RemoteAddr = "1.2.3.4:56789"
  1113  	rp := &ReverseProxy{
  1114  		Director: func(req *http.Request) {
  1115  			req.Header.Set("From-Director", "1")
  1116  		},
  1117  		Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
  1118  			if v := req.Header.Get("From-Director"); v != "1" {
  1119  				t.Errorf("From-Directory value = %q; want 1", v)
  1120  			}
  1121  			return nil, io.EOF
  1122  		}),
  1123  	}
  1124  	rp.ServeHTTP(httptest.NewRecorder(), req)
  1125  
  1126  	for _, h := range []string{
  1127  		"From-Director",
  1128  		"X-Forwarded-For",
  1129  	} {
  1130  		if req.Header.Get(h) != "" {
  1131  			t.Errorf("%v header mutation modified caller's request", h)
  1132  		}
  1133  	}
  1134  }
  1135  
  1136  type roundTripperFunc func(req *http.Request) (*http.Response, error)
  1137  
  1138  func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
  1139  	return fn(req)
  1140  }
  1141  
  1142  func TestModifyResponseClosesBody(t *testing.T) {
  1143  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1144  	req.RemoteAddr = "1.2.3.4:56789"
  1145  	closeCheck := new(checkCloser)
  1146  	logBuf := new(strings.Builder)
  1147  	outErr := errors.New("ModifyResponse error")
  1148  	rp := &ReverseProxy{
  1149  		Director: func(req *http.Request) {},
  1150  		Transport: &staticTransport{&http.Response{
  1151  			StatusCode: 200,
  1152  			Body:       closeCheck,
  1153  		}},
  1154  		ErrorLog: log.New(logBuf, "", 0),
  1155  		ModifyResponse: func(*http.Response) error {
  1156  			return outErr
  1157  		},
  1158  	}
  1159  	rec := httptest.NewRecorder()
  1160  	rp.ServeHTTP(rec, req)
  1161  	res := rec.Result()
  1162  	if g, e := res.StatusCode, http.StatusBadGateway; g != e {
  1163  		t.Errorf("got res.StatusCode %d; expected %d", g, e)
  1164  	}
  1165  	if !closeCheck.closed {
  1166  		t.Errorf("body should have been closed")
  1167  	}
  1168  	if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
  1169  		t.Errorf("ErrorLog %q does not contain %q", g, e)
  1170  	}
  1171  }
  1172  
  1173  type checkCloser struct {
  1174  	closed bool
  1175  }
  1176  
  1177  func (cc *checkCloser) Close() error {
  1178  	cc.closed = true
  1179  	return nil
  1180  }
  1181  
  1182  func (cc *checkCloser) Read(b []byte) (int, error) {
  1183  	return len(b), nil
  1184  }
  1185  
  1186  // Issue 23643: panic on body copy error
  1187  func TestReverseProxy_PanicBodyError(t *testing.T) {
  1188  	log.SetOutput(io.Discard)
  1189  	defer log.SetOutput(os.Stderr)
  1190  	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1191  		out := "this call was relayed by the reverse proxy"
  1192  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1193  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1194  		fmt.Fprintln(w, out)
  1195  	}))
  1196  	defer backendServer.Close()
  1197  
  1198  	rpURL, err := url.Parse(backendServer.URL)
  1199  	if err != nil {
  1200  		t.Fatal(err)
  1201  	}
  1202  
  1203  	rproxy := NewSingleHostReverseProxy(rpURL)
  1204  
  1205  	// Ensure that the handler panics when the body read encounters an
  1206  	// io.ErrUnexpectedEOF
  1207  	defer func() {
  1208  		err := recover()
  1209  		if err == nil {
  1210  			t.Fatal("handler should have panicked")
  1211  		}
  1212  		if err != http.ErrAbortHandler {
  1213  			t.Fatal("expected ErrAbortHandler, got", err)
  1214  		}
  1215  	}()
  1216  	req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
  1217  	rproxy.ServeHTTP(httptest.NewRecorder(), req)
  1218  }
  1219  
  1220  // Issue #46866: panic without closing incoming request body causes a panic
  1221  func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
  1222  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1223  		out := "this call was relayed by the reverse proxy"
  1224  		// Coerce a wrong content length to induce io.ErrUnexpectedEOF
  1225  		w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
  1226  		fmt.Fprintln(w, out)
  1227  	}))
  1228  	defer backend.Close()
  1229  	backendURL, err := url.Parse(backend.URL)
  1230  	if err != nil {
  1231  		t.Fatal(err)
  1232  	}
  1233  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1234  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1235  	frontend := httptest.NewServer(proxyHandler)
  1236  	defer frontend.Close()
  1237  	frontendClient := frontend.Client()
  1238  
  1239  	var wg sync.WaitGroup
  1240  	for i := 0; i < 2; i++ {
  1241  		wg.Add(1)
  1242  		go func() {
  1243  			defer wg.Done()
  1244  			for j := 0; j < 10; j++ {
  1245  				const reqLen = 6 * 1024 * 1024
  1246  				req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
  1247  				req.ContentLength = reqLen
  1248  				resp, _ := frontendClient.Transport.RoundTrip(req)
  1249  				if resp != nil {
  1250  					io.Copy(io.Discard, resp.Body)
  1251  					resp.Body.Close()
  1252  				}
  1253  			}
  1254  		}()
  1255  	}
  1256  	wg.Wait()
  1257  }
  1258  
  1259  func TestSelectFlushInterval(t *testing.T) {
  1260  	tests := []struct {
  1261  		name string
  1262  		p    *ReverseProxy
  1263  		res  *http.Response
  1264  		want time.Duration
  1265  	}{
  1266  		{
  1267  			name: "default",
  1268  			res:  &http.Response{},
  1269  			p:    &ReverseProxy{FlushInterval: 123},
  1270  			want: 123,
  1271  		},
  1272  		{
  1273  			name: "server-sent events overrides non-zero",
  1274  			res: &http.Response{
  1275  				Header: http.Header{
  1276  					"Content-Type": {"text/event-stream"},
  1277  				},
  1278  			},
  1279  			p:    &ReverseProxy{FlushInterval: 123},
  1280  			want: -1,
  1281  		},
  1282  		{
  1283  			name: "server-sent events overrides zero",
  1284  			res: &http.Response{
  1285  				Header: http.Header{
  1286  					"Content-Type": {"text/event-stream"},
  1287  				},
  1288  			},
  1289  			p:    &ReverseProxy{FlushInterval: 0},
  1290  			want: -1,
  1291  		},
  1292  		{
  1293  			name: "server-sent events with media-type parameters overrides non-zero",
  1294  			res: &http.Response{
  1295  				Header: http.Header{
  1296  					"Content-Type": {"text/event-stream;charset=utf-8"},
  1297  				},
  1298  			},
  1299  			p:    &ReverseProxy{FlushInterval: 123},
  1300  			want: -1,
  1301  		},
  1302  		{
  1303  			name: "server-sent events with media-type parameters overrides zero",
  1304  			res: &http.Response{
  1305  				Header: http.Header{
  1306  					"Content-Type": {"text/event-stream;charset=utf-8"},
  1307  				},
  1308  			},
  1309  			p:    &ReverseProxy{FlushInterval: 0},
  1310  			want: -1,
  1311  		},
  1312  		{
  1313  			name: "Content-Length: -1, overrides non-zero",
  1314  			res: &http.Response{
  1315  				ContentLength: -1,
  1316  			},
  1317  			p:    &ReverseProxy{FlushInterval: 123},
  1318  			want: -1,
  1319  		},
  1320  		{
  1321  			name: "Content-Length: -1, overrides zero",
  1322  			res: &http.Response{
  1323  				ContentLength: -1,
  1324  			},
  1325  			p:    &ReverseProxy{FlushInterval: 0},
  1326  			want: -1,
  1327  		},
  1328  	}
  1329  	for _, tt := range tests {
  1330  		t.Run(tt.name, func(t *testing.T) {
  1331  			got := tt.p.flushInterval(tt.res)
  1332  			if got != tt.want {
  1333  				t.Errorf("flushLatency = %v; want %v", got, tt.want)
  1334  			}
  1335  		})
  1336  	}
  1337  }
  1338  
// TestReverseProxyWebSocket proxies a websocket upgrade end to end. It checks
// three things:
//   - headers set by the frontend handler before ServeHTTP, and headers added
//     by ModifyResponse, both reach the client;
//   - the 101 response body is usable as an io.ReadWriteCloser;
//   - a line written by the client is echoed back by the backend through
//     the proxy.
func TestReverseProxyWebSocket(t *testing.T) {
	// The backend hijacks the connection, writes a raw 101 response, then
	// reads one line from the client and replies with it quoted.
	backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if upgradeType(r.Header) != "websocket" {
			t.Error("unexpected backend request")
			http.Error(w, "unexpected request", 400)
			return
		}
		c, _, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer c.Close()
		io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
		bs := bufio.NewScanner(c)
		if !bs.Scan() {
			t.Errorf("backend failed to read line from client: %v", bs.Err())
			return
		}
		fmt.Fprintf(c, "backend got %q\n", bs.Text())
	}))
	defer backendServer.Close()

	backURL, _ := url.Parse(backendServer.URL)
	rproxy := NewSingleHostReverseProxy(backURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	// ModifyResponse must also run for upgrade (101) responses.
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	// The frontend handler sets a header before proxying. After ServeHTTP
	// returns, it checks that the ModifyResponse header was copied onto its
	// own ResponseWriter.
	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		rproxy.ServeHTTP(rw, req)
		if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
			t.Errorf("response writer X-Modified header = %q; want %q", got, want)
		}
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	c := frontendProxy.Client()
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 101 {
		t.Fatalf("status = %v; want 101", res.Status)
	}

	got := res.Header.Get("X-Header")
	want := "X-Value"
	if got != want {
		t.Errorf("Header(XHeader) = %q; want %q", got, want)
	}

	if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
		t.Fatalf("not websocket upgrade; got %#v", res.Header)
	}
	// After a 101, the client exposes the upgraded connection as a writable body.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
	}
	defer rwc.Close()

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	// Round-trip one line through the proxied connection.
	io.WriteString(rwc, "Hello\n")
	bs := bufio.NewScanner(rwc)
	if !bs.Scan() {
		t.Fatalf("Scan: %v", bs.Err())
	}
	got = bs.Text()
	want = `backend got "Hello"`
	if got != want {
		t.Errorf("got %#q, want %#q", got, want)
	}
}
  1424  
// TestReverseProxyWebSocketCancellation checks that cancelling the inbound
// request's context tears down a proxied websocket connection.
//
// The backend streams n newline-terminated responses, one per second, and
// then a terminal message. After the client sees the first response, it
// closes triggerCancelCh, which cancels the context the frontend passed to
// ServeHTTP. The client must then observe EOF before the terminal message
// ever arrives.
func TestReverseProxyWebSocketCancellation(t *testing.T) {
	n := 5
	// Only ever closed, never sent on. Once closed, every receive succeeds
	// immediately, which the backend and frontend use as a "cancel happened"
	// signal.
	triggerCancelCh := make(chan bool, n)
	nthResponse := func(i int) string {
		return fmt.Sprintf("backend response #%d\n", i)
	}
	// No trailing newline: ReadString('\n') returns it together with io.EOF.
	terminalMsg := "final message"

	cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if g, ws := upgradeType(r.Header), "websocket"; g != ws {
			t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
			http.Error(w, "Unexpected request", 400)
			return
		}
		conn, bufrw, err := w.(http.Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
		if _, err := io.WriteString(conn, upgradeMsg); err != nil {
			t.Error(err)
			return
		}
		// Wait for the client's first line ("Hello") before streaming.
		if _, _, err := bufrw.ReadLine(); err != nil {
			t.Errorf("Failed to read line from client: %v", err)
			return
		}

		for i := 0; i < n; i++ {
			// Write failures are expected once the cancel has been
			// triggered. They are reported only if triggerCancelCh is
			// still open.
			if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
				select {
				case <-triggerCancelCh:
				default:
					t.Errorf("Writing response #%d failed: %v", i, err)
				}
				return
			}
			bufrw.Flush()
			time.Sleep(time.Second)
		}
		if _, err := bufrw.WriteString(terminalMsg); err != nil {
			select {
			case <-triggerCancelCh:
			default:
				t.Errorf("Failed to write terminal message: %v", err)
			}
		}
		bufrw.Flush()
	}))
	defer cst.Close()

	backendURL, _ := url.Parse(cst.URL)
	rproxy := NewSingleHostReverseProxy(backendURL)
	rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
	rproxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Add("X-Modified", "true")
		return nil
	}

	// The frontend proxies under a cancellable context that is cancelled as
	// soon as triggerCancelCh is closed.
	// NOTE(review): if the test fails before closing triggerCancelCh, this
	// goroutine stays blocked until the process exits.
	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.Header().Set("X-Header", "X-Value")
		ctx, cancel := context.WithCancel(req.Context())
		go func() {
			<-triggerCancelCh
			cancel()
		}()
		rproxy.ServeHTTP(rw, req.WithContext(ctx))
	})

	frontendProxy := httptest.NewServer(handler)
	defer frontendProxy.Close()

	req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
	req.Header.Set("Connection", "Upgrade")
	req.Header.Set("Upgrade", "websocket")

	res, err := frontendProxy.Client().Do(req)
	if err != nil {
		t.Fatalf("Dialing to frontend proxy: %v", err)
	}
	defer res.Body.Close()
	if g, w := res.StatusCode, 101; g != w {
		t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
	}

	if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
		t.Errorf("X-Header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
		t.Fatalf("Upgrade header mismatch\n\tgot:  %q\n\twant: %q", g, w)
	}

	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
	}

	if got, want := res.Header.Get("X-Modified"), "true"; got != want {
		t.Errorf("response X-Modified header = %q; want %q", got, want)
	}

	if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
		t.Fatalf("Failed to write first message: %v", err)
	}

	// Read loop.
	// Receiving terminalMsg means the cancel did not close the tunnel.
	// A clean EOF means it did.

	br := bufio.NewReader(rwc)
	for {
		line, err := br.ReadString('\n')
		switch {
		case line == terminalMsg: // this case before "err == io.EOF"
			t.Fatalf("The websocket request was not canceled, unfortunately!")

		case err == io.EOF:
			return

		case err != nil:
			t.Fatalf("Unexpected error: %v", err)

		case line == nthResponse(0): // We've gotten the first response back
			// Let's trigger a cancel.
			close(triggerCancelCh)
		}
	}
}
  1555  
  1556  func TestReverseProxyWebSocketHalfTCP(t *testing.T) {
  1557  	// Issue #35892: support TCP half-close when HTTP is upgraded in the ReverseProxy.
  1558  	// Specifically testing:
  1559  	// - the communication through the reverse proxy when the client or server closes
  1560  	//   either the read or write streams
  1561  	// - that closing the write stream is propagated through the proxy and results in reading
  1562  	//   EOF at the other end of the connection
  1563  
  1564  	switch runtime.GOOS {
  1565  	case "plan9":
  1566  		t.Skipf("not supported on %s", runtime.GOOS)
  1567  	}
  1568  
  1569  	mustRead := func(t *testing.T, conn *net.TCPConn, msg string) {
  1570  		b := make([]byte, len(msg))
  1571  		if _, err := conn.Read(b); err != nil {
  1572  			t.Errorf("failed to read: %v", err)
  1573  		}
  1574  
  1575  		if got, want := string(b), msg; got != want {
  1576  			t.Errorf("got %#q, want %#q", got, want)
  1577  		}
  1578  	}
  1579  
  1580  	mustReadError := func(t *testing.T, conn *net.TCPConn, e error) {
  1581  		b := make([]byte, 1)
  1582  		if _, err := conn.Read(b); !errors.Is(err, e) {
  1583  			t.Errorf("failed to read error: %v", err)
  1584  		}
  1585  	}
  1586  
  1587  	mustWrite := func(t *testing.T, conn *net.TCPConn, msg string) {
  1588  		if _, err := conn.Write([]byte(msg)); err != nil {
  1589  			t.Errorf("failed to write: %v", err)
  1590  		}
  1591  	}
  1592  
  1593  	mustCloseRead := func(t *testing.T, conn *net.TCPConn) {
  1594  		if err := conn.CloseRead(); err != nil {
  1595  			t.Errorf("failed to CloseRead: %v", err)
  1596  		}
  1597  	}
  1598  
  1599  	mustCloseWrite := func(t *testing.T, conn *net.TCPConn) {
  1600  		if err := conn.CloseWrite(); err != nil {
  1601  			t.Errorf("failed to CloseWrite: %v", err)
  1602  		}
  1603  	}
  1604  
  1605  	tests := map[string]func(t *testing.T, cli, srv *net.TCPConn){
  1606  		"server close read": func(t *testing.T, cli, srv *net.TCPConn) {
  1607  			mustCloseRead(t, srv)
  1608  			mustWrite(t, srv, "server sends")
  1609  			mustRead(t, cli, "server sends")
  1610  		},
  1611  		"server close write": func(t *testing.T, cli, srv *net.TCPConn) {
  1612  			mustCloseWrite(t, srv)
  1613  			mustWrite(t, cli, "client sends")
  1614  			mustRead(t, srv, "client sends")
  1615  			mustReadError(t, cli, io.EOF)
  1616  		},
  1617  		"client close read": func(t *testing.T, cli, srv *net.TCPConn) {
  1618  			mustCloseRead(t, cli)
  1619  			mustWrite(t, cli, "client sends")
  1620  			mustRead(t, srv, "client sends")
  1621  		},
  1622  		"client close write": func(t *testing.T, cli, srv *net.TCPConn) {
  1623  			mustCloseWrite(t, cli)
  1624  			mustWrite(t, srv, "server sends")
  1625  			mustRead(t, cli, "server sends")
  1626  			mustReadError(t, srv, io.EOF)
  1627  		},
  1628  	}
  1629  
  1630  	for name, test := range tests {
  1631  		t.Run(name, func(t *testing.T) {
  1632  			var srv *net.TCPConn
  1633  
  1634  			backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1635  				if g, ws := upgradeType(r.Header), "websocket"; g != ws {
  1636  					t.Fatalf("Unexpected upgrade type %q, want %q", g, ws)
  1637  				}
  1638  
  1639  				conn, _, err := w.(http.Hijacker).Hijack()
  1640  				if err != nil {
  1641  					conn.Close()
  1642  					t.Fatalf("hijack failed: %v", err)
  1643  				}
  1644  
  1645  				var ok bool
  1646  				if srv, ok = conn.(*net.TCPConn); !ok {
  1647  					conn.Close()
  1648  					t.Fatal("conn is not a TCPConn")
  1649  				}
  1650  
  1651  				upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
  1652  				if _, err := io.WriteString(srv, upgradeMsg); err != nil {
  1653  					srv.Close()
  1654  					t.Fatalf("backend upgrade failed: %v", err)
  1655  				}
  1656  			}))
  1657  			defer backendServer.Close()
  1658  
  1659  			backendURL, _ := url.Parse(backendServer.URL)
  1660  			rproxy := NewSingleHostReverseProxy(backendURL)
  1661  			rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1662  			frontendProxy := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
  1663  				rproxy.ServeHTTP(rw, req)
  1664  			}))
  1665  			defer frontendProxy.Close()
  1666  
  1667  			frontendURL, _ := url.Parse(frontendProxy.URL)
  1668  			addr, err := net.ResolveTCPAddr("tcp", frontendURL.Host)
  1669  			if err != nil {
  1670  				t.Fatalf("failed to resolve TCP address: %v", err)
  1671  			}
  1672  			cli, err := net.DialTCP("tcp", nil, addr)
  1673  			if err != nil {
  1674  				t.Fatalf("failed to dial TCP address: %v", err)
  1675  			}
  1676  			defer cli.Close()
  1677  
  1678  			req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
  1679  			req.Header.Set("Connection", "Upgrade")
  1680  			req.Header.Set("Upgrade", "websocket")
  1681  			if err := req.Write(cli); err != nil {
  1682  				t.Fatalf("failed to write request: %v", err)
  1683  			}
  1684  
  1685  			br := bufio.NewReader(cli)
  1686  			resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
  1687  			if err != nil {
  1688  				t.Fatalf("failed to read response: %v", err)
  1689  			}
  1690  			if resp.StatusCode != 101 {
  1691  				t.Fatalf("status code not 101: %v", resp.StatusCode)
  1692  			}
  1693  			if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
  1694  				strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
  1695  				t.Fatalf("frontend upgrade failed")
  1696  			}
  1697  			defer srv.Close()
  1698  
  1699  			test(t, cli, srv)
  1700  		})
  1701  	}
  1702  }
  1703  
  1704  func TestReverseProxyUpgradeNoCloseWrite(t *testing.T) {
  1705  	// The backend hijacks the connection,
  1706  	// reads all data from the client,
  1707  	// and returns.
  1708  	backendDone := make(chan struct{})
  1709  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1710  		w.Header().Set("Connection", "upgrade")
  1711  		w.Header().Set("Upgrade", "u")
  1712  		w.WriteHeader(101)
  1713  		conn, _, err := http.NewResponseController(w).Hijack()
  1714  		if err != nil {
  1715  			t.Errorf("Hijack: %v", err)
  1716  		}
  1717  		io.Copy(io.Discard, conn)
  1718  		close(backendDone)
  1719  	}))
  1720  	backendURL, err := url.Parse(backend.URL)
  1721  	if err != nil {
  1722  		t.Fatal(err)
  1723  	}
  1724  
  1725  	// The proxy includes a ModifyResponse function which replaces the response body
  1726  	// with its own wrapper, dropping the original body's CloseWrite method.
  1727  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1728  	proxyHandler.ModifyResponse = func(resp *http.Response) error {
  1729  		type readWriteCloserOnly struct {
  1730  			io.ReadWriteCloser
  1731  		}
  1732  		resp.Body = readWriteCloserOnly{resp.Body.(io.ReadWriteCloser)}
  1733  		return nil
  1734  	}
  1735  	frontend := httptest.NewServer(proxyHandler)
  1736  	defer frontend.Close()
  1737  
  1738  	// The client sends a request and closes the connection.
  1739  	req, _ := http.NewRequest("GET", frontend.URL, nil)
  1740  	req.Header.Set("Connection", "upgrade")
  1741  	req.Header.Set("Upgrade", "u")
  1742  	resp, err := frontend.Client().Do(req)
  1743  	if err != nil {
  1744  		t.Fatal(err)
  1745  	}
  1746  	resp.Body.Close()
  1747  
  1748  	// We expect that the client's closure of the connection is propagated to the backend.
  1749  	<-backendDone
  1750  }
  1751  
  1752  func TestUnannouncedTrailer(t *testing.T) {
  1753  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1754  		w.WriteHeader(http.StatusOK)
  1755  		w.(http.Flusher).Flush()
  1756  		w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
  1757  	}))
  1758  	defer backend.Close()
  1759  	backendURL, err := url.Parse(backend.URL)
  1760  	if err != nil {
  1761  		t.Fatal(err)
  1762  	}
  1763  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1764  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1765  	frontend := httptest.NewServer(proxyHandler)
  1766  	defer frontend.Close()
  1767  	frontendClient := frontend.Client()
  1768  
  1769  	res, err := frontendClient.Get(frontend.URL)
  1770  	if err != nil {
  1771  		t.Fatalf("Get: %v", err)
  1772  	}
  1773  
  1774  	io.ReadAll(res.Body)
  1775  	res.Body.Close()
  1776  	if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
  1777  		t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
  1778  	}
  1779  
  1780  }
  1781  
  1782  func TestSetURL(t *testing.T) {
  1783  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1784  		w.Write([]byte(r.Host))
  1785  	}))
  1786  	defer backend.Close()
  1787  	backendURL, err := url.Parse(backend.URL)
  1788  	if err != nil {
  1789  		t.Fatal(err)
  1790  	}
  1791  	proxyHandler := &ReverseProxy{
  1792  		Rewrite: func(r *ProxyRequest) {
  1793  			r.SetURL(backendURL)
  1794  		},
  1795  	}
  1796  	frontend := httptest.NewServer(proxyHandler)
  1797  	defer frontend.Close()
  1798  	frontendClient := frontend.Client()
  1799  
  1800  	res, err := frontendClient.Get(frontend.URL)
  1801  	if err != nil {
  1802  		t.Fatalf("Get: %v", err)
  1803  	}
  1804  	defer res.Body.Close()
  1805  
  1806  	body, err := io.ReadAll(res.Body)
  1807  	if err != nil {
  1808  		t.Fatalf("Reading body: %v", err)
  1809  	}
  1810  
  1811  	if got, want := string(body), backendURL.Host; got != want {
  1812  		t.Errorf("backend got Host %q, want %q", got, want)
  1813  	}
  1814  }
  1815  
  1816  func TestSingleJoinSlash(t *testing.T) {
  1817  	tests := []struct {
  1818  		slasha   string
  1819  		slashb   string
  1820  		expected string
  1821  	}{
  1822  		{"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1823  		{"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
  1824  		{"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
  1825  		{"https://www.google.com", "", "https://www.google.com/"},
  1826  		{"", "favicon.ico", "/favicon.ico"},
  1827  	}
  1828  	for _, tt := range tests {
  1829  		if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
  1830  			t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
  1831  				tt.slasha,
  1832  				tt.slashb,
  1833  				tt.expected,
  1834  				got)
  1835  		}
  1836  	}
  1837  }
  1838  
  1839  func TestJoinURLPath(t *testing.T) {
  1840  	tests := []struct {
  1841  		a        *url.URL
  1842  		b        *url.URL
  1843  		wantPath string
  1844  		wantRaw  string
  1845  	}{
  1846  		{&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
  1847  		{&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
  1848  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1849  		{&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
  1850  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
  1851  		{&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
  1852  	}
  1853  
  1854  	for _, tt := range tests {
  1855  		p, rp := joinURLPath(tt.a, tt.b)
  1856  		if p != tt.wantPath || rp != tt.wantRaw {
  1857  			t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
  1858  				tt.a.Path, tt.a.RawPath,
  1859  				tt.b.Path, tt.b.RawPath,
  1860  				tt.wantPath, tt.wantRaw,
  1861  				p, rp)
  1862  		}
  1863  	}
  1864  }
  1865  
  1866  func TestReverseProxyRewriteReplacesOut(t *testing.T) {
  1867  	const content = "response_content"
  1868  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1869  		w.Write([]byte(content))
  1870  	}))
  1871  	defer backend.Close()
  1872  	proxyHandler := &ReverseProxy{
  1873  		Rewrite: func(r *ProxyRequest) {
  1874  			r.Out, _ = http.NewRequest("GET", backend.URL, nil)
  1875  		},
  1876  	}
  1877  	frontend := httptest.NewServer(proxyHandler)
  1878  	defer frontend.Close()
  1879  
  1880  	res, err := frontend.Client().Get(frontend.URL)
  1881  	if err != nil {
  1882  		t.Fatalf("Get: %v", err)
  1883  	}
  1884  	defer res.Body.Close()
  1885  	body, _ := io.ReadAll(res.Body)
  1886  	if got, want := string(body), content; got != want {
  1887  		t.Errorf("got response %q, want %q", got, want)
  1888  	}
  1889  }
  1890  
  1891  func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
  1892  	// https://go.dev/issue/65123: We use httptrace.Got1xxResponse to capture 1xx responses
  1893  	// and proxy them. httptrace handlers can execute after RoundTrip returns, in particular
  1894  	// after experiencing connection errors. When this happens, we shouldn't modify the
  1895  	// ResponseWriter headers after ReverseProxy.ServeHTTP returns.
  1896  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1897  		for i := 0; i < 5; i++ {
  1898  			w.WriteHeader(103)
  1899  		}
  1900  	}))
  1901  	defer backend.Close()
  1902  	backendURL, err := url.Parse(backend.URL)
  1903  	if err != nil {
  1904  		t.Fatal(err)
  1905  	}
  1906  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1907  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1908  
  1909  	rw := &testResponseWriter{}
  1910  	func() {
  1911  		// Cancel the request (and cause RoundTrip to return) immediately upon
  1912  		// seeing a 1xx response.
  1913  		ctx, cancel := context.WithCancel(context.Background())
  1914  		defer cancel()
  1915  		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
  1916  			Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1917  				cancel()
  1918  				return nil
  1919  			},
  1920  		})
  1921  
  1922  		req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
  1923  		proxyHandler.ServeHTTP(rw, req)
  1924  	}()
  1925  	// Trigger data race while iterating over response headers.
  1926  	// When run with -race, this causes the condition in https://go.dev/issue/65123 often
  1927  	// enough to detect reliably.
  1928  	for _ = range rw.Header() {
  1929  	}
  1930  }
  1931  
  1932  func Test1xxResponses(t *testing.T) {
  1933  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1934  		h := w.Header()
  1935  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1936  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1937  		w.WriteHeader(http.StatusEarlyHints)
  1938  
  1939  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1940  		w.WriteHeader(http.StatusProcessing)
  1941  
  1942  		w.Write([]byte("Hello"))
  1943  	}))
  1944  	defer backend.Close()
  1945  	backendURL, err := url.Parse(backend.URL)
  1946  	if err != nil {
  1947  		t.Fatal(err)
  1948  	}
  1949  	proxyHandler := NewSingleHostReverseProxy(backendURL)
  1950  	proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
  1951  	frontend := httptest.NewServer(proxyHandler)
  1952  	defer frontend.Close()
  1953  	frontendClient := frontend.Client()
  1954  
  1955  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1956  		t.Helper()
  1957  
  1958  		if len(expected) != len(got) {
  1959  			t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
  1960  		}
  1961  
  1962  		for i := range expected {
  1963  			if i >= len(got) {
  1964  				t.Errorf("Expected %q link header; got nothing", expected[i])
  1965  
  1966  				continue
  1967  			}
  1968  
  1969  			if expected[i] != got[i] {
  1970  				t.Errorf("Expected %q link header; got %q", expected[i], got[i])
  1971  			}
  1972  		}
  1973  	}
  1974  
  1975  	var respCounter uint8
  1976  	trace := &httptrace.ClientTrace{
  1977  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1978  			switch code {
  1979  			case http.StatusEarlyHints:
  1980  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1981  			case http.StatusProcessing:
  1982  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1983  			default:
  1984  				t.Error("Unexpected 1xx response")
  1985  			}
  1986  
  1987  			respCounter++
  1988  
  1989  			return nil
  1990  		},
  1991  	}
  1992  	req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
  1993  
  1994  	res, err := frontendClient.Do(req)
  1995  	if err != nil {
  1996  		t.Fatalf("Get: %v", err)
  1997  	}
  1998  
  1999  	defer res.Body.Close()
  2000  
  2001  	if respCounter != 2 {
  2002  		t.Errorf("Expected 2 1xx responses; got %d", respCounter)
  2003  	}
  2004  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  2005  
  2006  	body, _ := io.ReadAll(res.Body)
  2007  	if string(body) != "Hello" {
  2008  		t.Errorf("Read body %q; want Hello", body)
  2009  	}
  2010  }
  2011  
  2012  const (
  2013  	testWantsCleanQuery = true
  2014  	testWantsRawQuery   = false
  2015  )
  2016  
  2017  func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
  2018  	testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
  2019  		proxyHandler := NewSingleHostReverseProxy(u)
  2020  		oldDirector := proxyHandler.Director
  2021  		proxyHandler.Director = func(r *http.Request) {
  2022  			oldDirector(r)
  2023  		}
  2024  		return proxyHandler
  2025  	})
  2026  }
  2027  
  2028  func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
  2029  	testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
  2030  		proxyHandler := NewSingleHostReverseProxy(u)
  2031  		oldDirector := proxyHandler.Director
  2032  		proxyHandler.Director = func(r *http.Request) {
  2033  			// Parsing the form causes ReverseProxy to remove unparsable
  2034  			// query parameters before forwarding.
  2035  			r.FormValue("a")
  2036  			oldDirector(r)
  2037  		}
  2038  		return proxyHandler
  2039  	})
  2040  }
  2041  
  2042  func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
  2043  	testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
  2044  		return &ReverseProxy{
  2045  			Rewrite: func(r *ProxyRequest) {
  2046  				r.SetURL(u)
  2047  			},
  2048  		}
  2049  	})
  2050  }
  2051  
  2052  func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
  2053  	testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
  2054  		return &ReverseProxy{
  2055  			Rewrite: func(r *ProxyRequest) {
  2056  				r.SetURL(u)
  2057  				r.Out.URL.RawQuery = r.In.URL.RawQuery
  2058  			},
  2059  		}
  2060  	})
  2061  }
  2062  
  2063  func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
  2064  	const content = "response_content"
  2065  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  2066  		w.Write([]byte(r.URL.RawQuery))
  2067  	}))
  2068  	defer backend.Close()
  2069  	backendURL, err := url.Parse(backend.URL)
  2070  	if err != nil {
  2071  		t.Fatal(err)
  2072  	}
  2073  	proxyHandler := newProxy(backendURL)
  2074  	frontend := httptest.NewServer(proxyHandler)
  2075  	defer frontend.Close()
  2076  
  2077  	// Don't spam output with logs of queries containing semicolons.
  2078  	backend.Config.ErrorLog = log.New(io.Discard, "", 0)
  2079  	frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
  2080  
  2081  	for _, test := range []struct {
  2082  		rawQuery   string
  2083  		cleanQuery string
  2084  	}{{
  2085  		rawQuery:   "a=1&a=2;b=3",
  2086  		cleanQuery: "a=1",
  2087  	}, {
  2088  		rawQuery:   "a=1&a=%zz&b=3",
  2089  		cleanQuery: "a=1&b=3",
  2090  	}} {
  2091  		res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
  2092  		if err != nil {
  2093  			t.Fatalf("Get: %v", err)
  2094  		}
  2095  		defer res.Body.Close()
  2096  		body, _ := io.ReadAll(res.Body)
  2097  		wantQuery := test.rawQuery
  2098  		if wantCleanQuery {
  2099  			wantQuery = test.cleanQuery
  2100  		}
  2101  		if got, want := string(body), wantQuery; got != want {
  2102  			t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
  2103  		}
  2104  	}
  2105  }
  2106  
  2107  // Issue #72954: We should not call WriteHeader on a ResponseWriter after hijacking
  2108  // the connection.
  2109  func TestReverseProxyHijackCopyError(t *testing.T) {
  2110  	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  2111  		w.Header().Set("Upgrade", "someproto")
  2112  		w.WriteHeader(http.StatusSwitchingProtocols)
  2113  	}))
  2114  	defer backend.Close()
  2115  	backendURL, err := url.Parse(backend.URL)
  2116  	if err != nil {
  2117  		t.Fatal(err)
  2118  	}
  2119  	proxyHandler := &ReverseProxy{
  2120  		Rewrite: func(r *ProxyRequest) {
  2121  			r.SetURL(backendURL)
  2122  		},
  2123  		ModifyResponse: func(resp *http.Response) error {
  2124  			resp.Body = &testReadWriteCloser{
  2125  				read: func([]byte) (int, error) {
  2126  					return 0, errors.New("read error")
  2127  				},
  2128  			}
  2129  			return nil
  2130  		},
  2131  	}
  2132  
  2133  	hijacked := false
  2134  	rw := &testResponseWriter{
  2135  		writeHeader: func(statusCode int) {
  2136  			if hijacked {
  2137  				t.Errorf("WriteHeader(%v) called after Hijack", statusCode)
  2138  			}
  2139  		},
  2140  		hijack: func() (net.Conn, *bufio.ReadWriter, error) {
  2141  			hijacked = true
  2142  			cli, srv := net.Pipe()
  2143  			go io.Copy(io.Discard, cli)
  2144  			return srv, bufio.NewReadWriter(bufio.NewReader(srv), bufio.NewWriter(srv)), nil
  2145  		},
  2146  	}
  2147  	req, _ := http.NewRequest("GET", "http://example.tld/", nil)
  2148  	req.Header.Set("Upgrade", "someproto")
  2149  	proxyHandler.ServeHTTP(rw, req)
  2150  }
  2151  
  2152  type testResponseWriter struct {
  2153  	h           http.Header
  2154  	writeHeader func(int)
  2155  	write       func([]byte) (int, error)
  2156  	hijack      func() (net.Conn, *bufio.ReadWriter, error)
  2157  }
  2158  
  2159  func (rw *testResponseWriter) Header() http.Header {
  2160  	if rw.h == nil {
  2161  		rw.h = make(http.Header)
  2162  	}
  2163  	return rw.h
  2164  }
  2165  
  2166  func (rw *testResponseWriter) WriteHeader(statusCode int) {
  2167  	if rw.writeHeader != nil {
  2168  		rw.writeHeader(statusCode)
  2169  	}
  2170  }
  2171  
  2172  func (rw *testResponseWriter) Write(p []byte) (int, error) {
  2173  	if rw.write != nil {
  2174  		return rw.write(p)
  2175  	}
  2176  	return len(p), nil
  2177  }
  2178  
  2179  func (rw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  2180  	if rw.hijack != nil {
  2181  		return rw.hijack()
  2182  	}
  2183  	return nil, nil, errors.ErrUnsupported
  2184  }
  2185  
  2186  type testReadWriteCloser struct {
  2187  	read  func([]byte) (int, error)
  2188  	write func([]byte) (int, error)
  2189  	close func() error
  2190  }
  2191  
  2192  func (rc *testReadWriteCloser) Read(p []byte) (int, error) {
  2193  	if rc.read != nil {
  2194  		return rc.read(p)
  2195  	}
  2196  	return 0, io.EOF
  2197  }
  2198  
  2199  func (rc *testReadWriteCloser) Write(p []byte) (int, error) {
  2200  	if rc.write != nil {
  2201  		return rc.write(p)
  2202  	}
  2203  	return len(p), nil
  2204  }
  2205  
  2206  func (rc *testReadWriteCloser) Close() error {
  2207  	if rc.close != nil {
  2208  		return rc.close()
  2209  	}
  2210  	return nil
  2211  }
  2212  

View as plain text