Source file src/context/x_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	. "context"
     9  	"errors"
    10  	"fmt"
    11  	"internal/asan"
    12  	"math/rand"
    13  	"runtime"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
// Each XTestFoo in context_test.go must be called from a TestFoo here to run.
// The XTest functions need access to unexported context internals (see the
// per-call notes below), so they presumably live in package context's own test
// files; these exported wrappers are what `go test` discovers and runs.

// TestParentFinishesChild forwards to XTestParentFinishesChild.
func TestParentFinishesChild(t *testing.T) {
	XTestParentFinishesChild(t) // uses unexported context types
}

// TestChildFinishesFirst forwards to XTestChildFinishesFirst.
func TestChildFinishesFirst(t *testing.T) {
	XTestChildFinishesFirst(t) // uses unexported context types
}

// TestCancelRemoves forwards to XTestCancelRemoves.
func TestCancelRemoves(t *testing.T) {
	XTestCancelRemoves(t) // uses unexported context types
}

// TestCustomContextGoroutines forwards to XTestCustomContextGoroutines.
func TestCustomContextGoroutines(t *testing.T) {
	XTestCustomContextGoroutines(t) // reads the context.goroutines counter
}
    33  
// The following are regular tests in package context_test.

// otherContext is a Context that's not one of the types defined in context.go.
// This lets us test code paths that differ based on the underlying type of the
// Context. Every Context method is promoted from the embedded value, so an
// otherContext behaves exactly like the Context it wraps.
type otherContext struct {
	Context
}

// Durations shared by the tests below.
const (
	shortDuration    = 1 * time.Millisecond // a reasonable duration to block in a test
	veryLongDuration = 1000 * time.Hour     // an arbitrary upper bound on the test's running time
)
    47  
    48  // quiescent returns an arbitrary duration by which the program should have
    49  // completed any remaining work and reached a steady (idle) state.
    50  func quiescent(t *testing.T) time.Duration {
    51  	deadline, ok := t.Deadline()
    52  	if !ok {
    53  		return 5 * time.Second
    54  	}
    55  
    56  	const arbitraryCleanupMargin = 1 * time.Second
    57  	return time.Until(deadline) - arbitraryCleanupMargin
    58  }
    59  func TestBackground(t *testing.T) {
    60  	c := Background()
    61  	if c == nil {
    62  		t.Fatalf("Background returned nil")
    63  	}
    64  	select {
    65  	case x := <-c.Done():
    66  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    67  	default:
    68  	}
    69  	if got, want := fmt.Sprint(c), "context.Background"; got != want {
    70  		t.Errorf("Background().String() = %q want %q", got, want)
    71  	}
    72  }
    73  
    74  func TestTODO(t *testing.T) {
    75  	c := TODO()
    76  	if c == nil {
    77  		t.Fatalf("TODO returned nil")
    78  	}
    79  	select {
    80  	case x := <-c.Done():
    81  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    82  	default:
    83  	}
    84  	if got, want := fmt.Sprint(c), "context.TODO"; got != want {
    85  		t.Errorf("TODO().String() = %q want %q", got, want)
    86  	}
    87  }
    88  
    89  func TestWithCancel(t *testing.T) {
    90  	c1, cancel := WithCancel(Background())
    91  
    92  	if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
    93  		t.Errorf("c1.String() = %q want %q", got, want)
    94  	}
    95  
    96  	o := otherContext{c1}
    97  	c2, _ := WithCancel(o)
    98  	contexts := []Context{c1, o, c2}
    99  
   100  	for i, c := range contexts {
   101  		if d := c.Done(); d == nil {
   102  			t.Errorf("c[%d].Done() == %v want non-nil", i, d)
   103  		}
   104  		if e := c.Err(); e != nil {
   105  			t.Errorf("c[%d].Err() == %v want nil", i, e)
   106  		}
   107  
   108  		select {
   109  		case x := <-c.Done():
   110  			t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
   111  		default:
   112  		}
   113  	}
   114  
   115  	cancel() // Should propagate synchronously.
   116  	for i, c := range contexts {
   117  		select {
   118  		case <-c.Done():
   119  		default:
   120  			t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
   121  		}
   122  		if e := c.Err(); e != Canceled {
   123  			t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
   124  		}
   125  	}
   126  }
   127  
   128  func testDeadline(c Context, name string, t *testing.T) {
   129  	t.Helper()
   130  	d := quiescent(t)
   131  	timer := time.NewTimer(d)
   132  	defer timer.Stop()
   133  	select {
   134  	case <-timer.C:
   135  		t.Fatalf("%s: context not timed out after %v", name, d)
   136  	case <-c.Done():
   137  	}
   138  	if e := c.Err(); e != DeadlineExceeded {
   139  		t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded)
   140  	}
   141  }
   142  
   143  func TestDeadline(t *testing.T) {
   144  	t.Parallel()
   145  
   146  	c, _ := WithDeadline(Background(), time.Now().Add(shortDuration))
   147  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   148  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   149  	}
   150  	testDeadline(c, "WithDeadline", t)
   151  
   152  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   153  	o := otherContext{c}
   154  	testDeadline(o, "WithDeadline+otherContext", t)
   155  
   156  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   157  	o = otherContext{c}
   158  	c, _ = WithDeadline(o, time.Now().Add(veryLongDuration))
   159  	testDeadline(c, "WithDeadline+otherContext+WithDeadline", t)
   160  
   161  	c, _ = WithDeadline(Background(), time.Now().Add(-shortDuration))
   162  	testDeadline(c, "WithDeadline+inthepast", t)
   163  
   164  	c, _ = WithDeadline(Background(), time.Now())
   165  	testDeadline(c, "WithDeadline+now", t)
   166  }
   167  
   168  func TestTimeout(t *testing.T) {
   169  	t.Parallel()
   170  
   171  	c, _ := WithTimeout(Background(), shortDuration)
   172  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   173  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   174  	}
   175  	testDeadline(c, "WithTimeout", t)
   176  
   177  	c, _ = WithTimeout(Background(), shortDuration)
   178  	o := otherContext{c}
   179  	testDeadline(o, "WithTimeout+otherContext", t)
   180  
   181  	c, _ = WithTimeout(Background(), shortDuration)
   182  	o = otherContext{c}
   183  	c, _ = WithTimeout(o, veryLongDuration)
   184  	testDeadline(c, "WithTimeout+otherContext+WithTimeout", t)
   185  }
   186  
   187  func TestCanceledTimeout(t *testing.T) {
   188  	c, _ := WithTimeout(Background(), time.Second)
   189  	o := otherContext{c}
   190  	c, cancel := WithTimeout(o, veryLongDuration)
   191  	cancel() // Should propagate synchronously.
   192  	select {
   193  	case <-c.Done():
   194  	default:
   195  		t.Errorf("<-c.Done() blocked, but shouldn't have")
   196  	}
   197  	if e := c.Err(); e != Canceled {
   198  		t.Errorf("c.Err() == %v want %v", e, Canceled)
   199  	}
   200  }
   201  
   202  type key1 int
   203  type key2 int
   204  
   205  func (k key2) String() string { return fmt.Sprintf("%[1]T(%[1]d)", k) }
   206  
   207  var k1 = key1(1)
   208  var k2 = key2(1) // same int as k1, different type
   209  var k3 = key2(3) // same type as k2, different int
   210  
   211  func TestValues(t *testing.T) {
   212  	check := func(c Context, nm, v1, v2, v3 string) {
   213  		if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
   214  			t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
   215  		}
   216  		if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
   217  			t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
   218  		}
   219  		if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
   220  			t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
   221  		}
   222  	}
   223  
   224  	c0 := Background()
   225  	check(c0, "c0", "", "", "")
   226  
   227  	c1 := WithValue(Background(), k1, "c1k1")
   228  	check(c1, "c1", "c1k1", "", "")
   229  
   230  	if got, want := fmt.Sprint(c1), `context.Background.WithValue(context_test.key1, c1k1)`; got != want {
   231  		t.Errorf("c.String() = %q want %q", got, want)
   232  	}
   233  
   234  	c2 := WithValue(c1, k2, "c2k2")
   235  	check(c2, "c2", "c1k1", "c2k2", "")
   236  
   237  	if got, want := fmt.Sprint(c2), `context.Background.WithValue(context_test.key1, c1k1).WithValue(context_test.key2(1), c2k2)`; got != want {
   238  		t.Errorf("c.String() = %q want %q", got, want)
   239  	}
   240  
   241  	c3 := WithValue(c2, k3, "c3k3")
   242  	check(c3, "c2", "c1k1", "c2k2", "c3k3")
   243  
   244  	c4 := WithValue(c3, k1, nil)
   245  	check(c4, "c4", "", "c2k2", "c3k3")
   246  
   247  	if got, want := fmt.Sprint(c4), `context.Background.WithValue(context_test.key1, c1k1).WithValue(context_test.key2(1), c2k2).WithValue(context_test.key2(3), c3k3).WithValue(context_test.key1, <nil>)`; got != want {
   248  		t.Errorf("c.String() = %q want %q", got, want)
   249  	}
   250  
   251  	o0 := otherContext{Background()}
   252  	check(o0, "o0", "", "", "")
   253  
   254  	o1 := otherContext{WithValue(Background(), k1, "c1k1")}
   255  	check(o1, "o1", "c1k1", "", "")
   256  
   257  	o2 := WithValue(o1, k2, "o2k2")
   258  	check(o2, "o2", "c1k1", "o2k2", "")
   259  
   260  	o3 := otherContext{c4}
   261  	check(o3, "o3", "", "c2k2", "c3k3")
   262  
   263  	o4 := WithValue(o3, k3, nil)
   264  	check(o4, "o4", "", "c2k2", "")
   265  }
   266  
   267  func TestAllocs(t *testing.T) {
   268  	if asan.Enabled {
   269  		t.Skip("test allocates more with -asan")
   270  	}
   271  	bg := Background()
   272  	for _, test := range []struct {
   273  		desc       string
   274  		f          func()
   275  		limit      float64
   276  		gccgoLimit float64
   277  	}{
   278  		{
   279  			desc:       "Background()",
   280  			f:          func() { Background() },
   281  			limit:      0,
   282  			gccgoLimit: 0,
   283  		},
   284  		{
   285  			desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
   286  			f: func() {
   287  				c := WithValue(bg, k1, nil)
   288  				c.Value(k1)
   289  			},
   290  			limit:      3,
   291  			gccgoLimit: 3,
   292  		},
   293  		{
   294  			desc: "WithTimeout(bg, 1*time.Nanosecond)",
   295  			f: func() {
   296  				c, _ := WithTimeout(bg, 1*time.Nanosecond)
   297  				<-c.Done()
   298  			},
   299  			limit:      12,
   300  			gccgoLimit: 15,
   301  		},
   302  		{
   303  			desc: "WithCancel(bg)",
   304  			f: func() {
   305  				c, cancel := WithCancel(bg)
   306  				cancel()
   307  				<-c.Done()
   308  			},
   309  			limit:      5,
   310  			gccgoLimit: 8,
   311  		},
   312  		{
   313  			desc: "WithTimeout(bg, 5*time.Millisecond)",
   314  			f: func() {
   315  				c, cancel := WithTimeout(bg, 5*time.Millisecond)
   316  				cancel()
   317  				<-c.Done()
   318  			},
   319  			limit:      8,
   320  			gccgoLimit: 25,
   321  		},
   322  	} {
   323  		limit := test.limit
   324  		if runtime.Compiler == "gccgo" {
   325  			// gccgo does not yet do escape analysis.
   326  			// TODO(iant): Remove this when gccgo does do escape analysis.
   327  			limit = test.gccgoLimit
   328  		}
   329  		numRuns := 100
   330  		if testing.Short() {
   331  			numRuns = 10
   332  		}
   333  		if n := testing.AllocsPerRun(numRuns, test.f); n > limit {
   334  			t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
   335  		}
   336  	}
   337  }
   338  
// TestSimultaneousCancels builds a tree of WithCancel contexts, runs every
// cancel function on its own goroutine at once, and checks two things within
// quiescent(t): every context becomes done, and every cancel call returns.
// The tree is the root plus four children per expanded node, grown until it
// holds at least 100 contexts. On timeout it dumps all goroutine stacks.
func TestSimultaneousCancels(t *testing.T) {
	root, cancel := WithCancel(Background())
	// m maps each context to its cancel func; it is also the set of all nodes.
	m := map[Context]CancelFunc{root: cancel}
	// q is a FIFO of contexts still to be given children (breadth-first).
	q := []Context{root}
	// Create a tree of contexts.
	for len(q) != 0 && len(m) < 100 {
		parent := q[0]
		q = q[1:]
		for i := 0; i < 4; i++ {
			ctx, cancel := WithCancel(parent)
			m[ctx] = cancel
			q = append(q, ctx)
		}
	}
	// Start all the cancels in a random order.
	// (Go randomizes map iteration order, which is what makes it random.)
	var wg sync.WaitGroup
	wg.Add(len(m))
	for _, cancel := range m {
		go func(cancel CancelFunc) {
			cancel()
			wg.Done()
		}(cancel)
	}

	// One timer bounds both waits below; when it fires, stuck is closed.
	d := quiescent(t)
	stuck := make(chan struct{})
	timer := time.AfterFunc(d, func() { close(stuck) })
	defer timer.Stop()

	// Wait on all the contexts in a random order.
	for ctx := range m {
		select {
		case <-ctx.Done():
		case <-stuck:
			buf := make([]byte, 10<<10)
			n := runtime.Stack(buf, true)
			t.Fatalf("timed out after %v waiting for <-ctx.Done(); stacks:\n%s", d, buf[:n])
		}
	}
	// Wait for all the cancel functions to return.
	// wg.Wait runs on a helper goroutine so it can be raced against stuck.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-stuck:
		buf := make([]byte, 10<<10)
		n := runtime.Stack(buf, true)
		t.Fatalf("timed out after %v waiting for cancel functions; stacks:\n%s", d, buf[:n])
	}
}
   392  
   393  func TestInterlockedCancels(t *testing.T) {
   394  	parent, cancelParent := WithCancel(Background())
   395  	child, cancelChild := WithCancel(parent)
   396  	go func() {
   397  		<-parent.Done()
   398  		cancelChild()
   399  	}()
   400  	cancelParent()
   401  	d := quiescent(t)
   402  	timer := time.NewTimer(d)
   403  	defer timer.Stop()
   404  	select {
   405  	case <-child.Done():
   406  	case <-timer.C:
   407  		buf := make([]byte, 10<<10)
   408  		n := runtime.Stack(buf, true)
   409  		t.Fatalf("timed out after %v waiting for child.Done(); stacks:\n%s", d, buf[:n])
   410  	}
   411  }
   412  
// TestLayersCancel runs testLayers in explicit-cancel mode, seeded from the
// current time so each run explores a different layer order.
func TestLayersCancel(t *testing.T) {
	testLayers(t, time.Now().UnixNano(), false)
}

// TestLayersTimeout runs testLayers in timeout mode, seeded from the current
// time so each run explores a different layer order.
func TestLayersTimeout(t *testing.T) {
	testLayers(t, time.Now().UnixNano(), true)
}
   420  
   421  func testLayers(t *testing.T, seed int64, testTimeout bool) {
   422  	t.Parallel()
   423  
   424  	r := rand.New(rand.NewSource(seed))
   425  	prefix := fmt.Sprintf("seed=%d", seed)
   426  	errorf := func(format string, a ...any) {
   427  		t.Errorf(prefix+format, a...)
   428  	}
   429  	const (
   430  		minLayers = 30
   431  	)
   432  	type value int
   433  	var (
   434  		vals      []*value
   435  		cancels   []CancelFunc
   436  		numTimers int
   437  		ctx       = Background()
   438  	)
   439  	for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
   440  		switch r.Intn(3) {
   441  		case 0:
   442  			v := new(value)
   443  			ctx = WithValue(ctx, v, v)
   444  			vals = append(vals, v)
   445  		case 1:
   446  			var cancel CancelFunc
   447  			ctx, cancel = WithCancel(ctx)
   448  			cancels = append(cancels, cancel)
   449  		case 2:
   450  			var cancel CancelFunc
   451  			d := veryLongDuration
   452  			if testTimeout {
   453  				d = shortDuration
   454  			}
   455  			ctx, cancel = WithTimeout(ctx, d)
   456  			cancels = append(cancels, cancel)
   457  			numTimers++
   458  		}
   459  	}
   460  	checkValues := func(when string) {
   461  		for _, key := range vals {
   462  			if val := ctx.Value(key).(*value); key != val {
   463  				errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
   464  			}
   465  		}
   466  	}
   467  	if !testTimeout {
   468  		select {
   469  		case <-ctx.Done():
   470  			errorf("ctx should not be canceled yet")
   471  		default:
   472  		}
   473  	}
   474  	if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
   475  		t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
   476  	}
   477  	t.Log(ctx)
   478  	checkValues("before cancel")
   479  	if testTimeout {
   480  		d := quiescent(t)
   481  		timer := time.NewTimer(d)
   482  		defer timer.Stop()
   483  		select {
   484  		case <-ctx.Done():
   485  		case <-timer.C:
   486  			errorf("ctx should have timed out after %v", d)
   487  		}
   488  		checkValues("after timeout")
   489  	} else {
   490  		cancel := cancels[r.Intn(len(cancels))]
   491  		cancel()
   492  		select {
   493  		case <-ctx.Done():
   494  		default:
   495  			errorf("ctx should be canceled")
   496  		}
   497  		checkValues("after cancel")
   498  	}
   499  }
   500  
   501  func TestWithCancelCanceledParent(t *testing.T) {
   502  	parent, pcancel := WithCancelCause(Background())
   503  	cause := fmt.Errorf("Because!")
   504  	pcancel(cause)
   505  
   506  	c, _ := WithCancel(parent)
   507  	select {
   508  	case <-c.Done():
   509  	default:
   510  		t.Errorf("child not done immediately upon construction")
   511  	}
   512  	if got, want := c.Err(), Canceled; got != want {
   513  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   514  	}
   515  	if got, want := Cause(c), cause; got != want {
   516  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   517  	}
   518  }
   519  
   520  func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) {
   521  	// Cancel the parent goroutine concurrently with creating a child.
   522  	for i := 0; i < 100; i++ {
   523  		parent, pcancel := WithCancelCause(Background())
   524  		cause := fmt.Errorf("Because!")
   525  		go pcancel(cause)
   526  
   527  		c, _ := WithCancel(parent)
   528  		<-c.Done()
   529  		if got, want := c.Err(), Canceled; got != want {
   530  			t.Errorf("child not canceled; got = %v, want = %v", got, want)
   531  		}
   532  		if got, want := Cause(c), cause; got != want {
   533  			t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   534  		}
   535  	}
   536  }
   537  
   538  func TestWithValueChecksKey(t *testing.T) {
   539  	panicVal := recoveredValue(func() { _ = WithValue(Background(), []byte("foo"), "bar") })
   540  	if panicVal == nil {
   541  		t.Error("expected panic")
   542  	}
   543  	panicVal = recoveredValue(func() { _ = WithValue(Background(), nil, "bar") })
   544  	if got, want := fmt.Sprint(panicVal), "nil key"; got != want {
   545  		t.Errorf("panic = %q; want %q", got, want)
   546  	}
   547  }
   548  
   549  func TestInvalidDerivedFail(t *testing.T) {
   550  	panicVal := recoveredValue(func() { _, _ = WithCancel(nil) })
   551  	if panicVal == nil {
   552  		t.Error("expected panic")
   553  	}
   554  	panicVal = recoveredValue(func() { _, _ = WithDeadline(nil, time.Now().Add(shortDuration)) })
   555  	if panicVal == nil {
   556  		t.Error("expected panic")
   557  	}
   558  	panicVal = recoveredValue(func() { _ = WithValue(nil, "foo", "bar") })
   559  	if panicVal == nil {
   560  		t.Error("expected panic")
   561  	}
   562  }
   563  
   564  func recoveredValue(fn func()) (v any) {
   565  	defer func() { v = recover() }()
   566  	fn()
   567  	return
   568  }
   569  
   570  func TestDeadlineExceededSupportsTimeout(t *testing.T) {
   571  	i, ok := DeadlineExceeded.(interface {
   572  		Timeout() bool
   573  	})
   574  	if !ok {
   575  		t.Fatal("DeadlineExceeded does not support Timeout interface")
   576  	}
   577  	if !i.Timeout() {
   578  		t.Fatal("wrong value for timeout")
   579  	}
   580  }
   581  func TestCause(t *testing.T) {
   582  	var (
   583  		forever       = 1e6 * time.Second
   584  		parentCause   = fmt.Errorf("parentCause")
   585  		childCause    = fmt.Errorf("childCause")
   586  		tooSlow       = fmt.Errorf("tooSlow")
   587  		finishedEarly = fmt.Errorf("finishedEarly")
   588  	)
   589  	for _, test := range []struct {
   590  		name  string
   591  		ctx   func() Context
   592  		err   error
   593  		cause error
   594  	}{
   595  		{
   596  			name:  "Background",
   597  			ctx:   Background,
   598  			err:   nil,
   599  			cause: nil,
   600  		},
   601  		{
   602  			name:  "TODO",
   603  			ctx:   TODO,
   604  			err:   nil,
   605  			cause: nil,
   606  		},
   607  		{
   608  			name: "WithCancel",
   609  			ctx: func() Context {
   610  				ctx, cancel := WithCancel(Background())
   611  				cancel()
   612  				return ctx
   613  			},
   614  			err:   Canceled,
   615  			cause: Canceled,
   616  		},
   617  		{
   618  			name: "WithCancelCause",
   619  			ctx: func() Context {
   620  				ctx, cancel := WithCancelCause(Background())
   621  				cancel(parentCause)
   622  				return ctx
   623  			},
   624  			err:   Canceled,
   625  			cause: parentCause,
   626  		},
   627  		{
   628  			name: "WithCancelCause nil",
   629  			ctx: func() Context {
   630  				ctx, cancel := WithCancelCause(Background())
   631  				cancel(nil)
   632  				return ctx
   633  			},
   634  			err:   Canceled,
   635  			cause: Canceled,
   636  		},
   637  		{
   638  			name: "WithCancelCause: parent cause before child",
   639  			ctx: func() Context {
   640  				ctx, cancelParent := WithCancelCause(Background())
   641  				ctx, cancelChild := WithCancelCause(ctx)
   642  				cancelParent(parentCause)
   643  				cancelChild(childCause)
   644  				return ctx
   645  			},
   646  			err:   Canceled,
   647  			cause: parentCause,
   648  		},
   649  		{
   650  			name: "WithCancelCause: parent cause after child",
   651  			ctx: func() Context {
   652  				ctx, cancelParent := WithCancelCause(Background())
   653  				ctx, cancelChild := WithCancelCause(ctx)
   654  				cancelChild(childCause)
   655  				cancelParent(parentCause)
   656  				return ctx
   657  			},
   658  			err:   Canceled,
   659  			cause: childCause,
   660  		},
   661  		{
   662  			name: "WithCancelCause: parent cause before nil",
   663  			ctx: func() Context {
   664  				ctx, cancelParent := WithCancelCause(Background())
   665  				ctx, cancelChild := WithCancel(ctx)
   666  				cancelParent(parentCause)
   667  				cancelChild()
   668  				return ctx
   669  			},
   670  			err:   Canceled,
   671  			cause: parentCause,
   672  		},
   673  		{
   674  			name: "WithCancelCause: parent cause after nil",
   675  			ctx: func() Context {
   676  				ctx, cancelParent := WithCancelCause(Background())
   677  				ctx, cancelChild := WithCancel(ctx)
   678  				cancelChild()
   679  				cancelParent(parentCause)
   680  				return ctx
   681  			},
   682  			err:   Canceled,
   683  			cause: Canceled,
   684  		},
   685  		{
   686  			name: "WithCancelCause: child cause after nil",
   687  			ctx: func() Context {
   688  				ctx, cancelParent := WithCancel(Background())
   689  				ctx, cancelChild := WithCancelCause(ctx)
   690  				cancelParent()
   691  				cancelChild(childCause)
   692  				return ctx
   693  			},
   694  			err:   Canceled,
   695  			cause: Canceled,
   696  		},
   697  		{
   698  			name: "WithCancelCause: child cause before nil",
   699  			ctx: func() Context {
   700  				ctx, cancelParent := WithCancel(Background())
   701  				ctx, cancelChild := WithCancelCause(ctx)
   702  				cancelChild(childCause)
   703  				cancelParent()
   704  				return ctx
   705  			},
   706  			err:   Canceled,
   707  			cause: childCause,
   708  		},
   709  		{
   710  			name: "WithTimeout",
   711  			ctx: func() Context {
   712  				ctx, cancel := WithTimeout(Background(), 0)
   713  				cancel()
   714  				return ctx
   715  			},
   716  			err:   DeadlineExceeded,
   717  			cause: DeadlineExceeded,
   718  		},
   719  		{
   720  			name: "WithTimeout canceled",
   721  			ctx: func() Context {
   722  				ctx, cancel := WithTimeout(Background(), forever)
   723  				cancel()
   724  				return ctx
   725  			},
   726  			err:   Canceled,
   727  			cause: Canceled,
   728  		},
   729  		{
   730  			name: "WithTimeoutCause",
   731  			ctx: func() Context {
   732  				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
   733  				cancel()
   734  				return ctx
   735  			},
   736  			err:   DeadlineExceeded,
   737  			cause: tooSlow,
   738  		},
   739  		{
   740  			name: "WithTimeoutCause canceled",
   741  			ctx: func() Context {
   742  				ctx, cancel := WithTimeoutCause(Background(), forever, tooSlow)
   743  				cancel()
   744  				return ctx
   745  			},
   746  			err:   Canceled,
   747  			cause: Canceled,
   748  		},
   749  		{
   750  			name: "WithTimeoutCause stacked",
   751  			ctx: func() Context {
   752  				ctx, cancel := WithCancelCause(Background())
   753  				ctx, _ = WithTimeoutCause(ctx, 0, tooSlow)
   754  				cancel(finishedEarly)
   755  				return ctx
   756  			},
   757  			err:   DeadlineExceeded,
   758  			cause: tooSlow,
   759  		},
   760  		{
   761  			name: "WithTimeoutCause stacked canceled",
   762  			ctx: func() Context {
   763  				ctx, cancel := WithCancelCause(Background())
   764  				ctx, _ = WithTimeoutCause(ctx, forever, tooSlow)
   765  				cancel(finishedEarly)
   766  				return ctx
   767  			},
   768  			err:   Canceled,
   769  			cause: finishedEarly,
   770  		},
   771  		{
   772  			name: "WithoutCancel",
   773  			ctx: func() Context {
   774  				return WithoutCancel(Background())
   775  			},
   776  			err:   nil,
   777  			cause: nil,
   778  		},
   779  		{
   780  			name: "WithoutCancel canceled",
   781  			ctx: func() Context {
   782  				ctx, cancel := WithCancelCause(Background())
   783  				ctx = WithoutCancel(ctx)
   784  				cancel(finishedEarly)
   785  				return ctx
   786  			},
   787  			err:   nil,
   788  			cause: nil,
   789  		},
   790  		{
   791  			name: "WithoutCancel timeout",
   792  			ctx: func() Context {
   793  				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
   794  				ctx = WithoutCancel(ctx)
   795  				cancel()
   796  				return ctx
   797  			},
   798  			err:   nil,
   799  			cause: nil,
   800  		},
   801  	} {
   802  		test := test
   803  		t.Run(test.name, func(t *testing.T) {
   804  			t.Parallel()
   805  			ctx := test.ctx()
   806  			if got, want := ctx.Err(), test.err; want != got {
   807  				t.Errorf("ctx.Err() = %v want %v", got, want)
   808  			}
   809  			if got, want := Cause(ctx), test.cause; want != got {
   810  				t.Errorf("Cause(ctx) = %v want %v", got, want)
   811  			}
   812  		})
   813  	}
   814  }
   815  
   816  func TestCauseRace(t *testing.T) {
   817  	cause := errors.New("TestCauseRace")
   818  	ctx, cancel := WithCancelCause(Background())
   819  	go func() {
   820  		cancel(cause)
   821  	}()
   822  	for {
   823  		// Poll Cause, rather than waiting for Done, to test that
   824  		// access to the underlying cause is synchronized properly.
   825  		if err := Cause(ctx); err != nil {
   826  			if err != cause {
   827  				t.Errorf("Cause returned %v, want %v", err, cause)
   828  			}
   829  			break
   830  		}
   831  		runtime.Gosched()
   832  	}
   833  }
   834  
   835  func TestWithoutCancel(t *testing.T) {
   836  	key, value := "key", "value"
   837  	ctx := WithValue(Background(), key, value)
   838  	ctx = WithoutCancel(ctx)
   839  	if d, ok := ctx.Deadline(); !d.IsZero() || ok != false {
   840  		t.Errorf("ctx.Deadline() = %v, %v want zero, false", d, ok)
   841  	}
   842  	if done := ctx.Done(); done != nil {
   843  		t.Errorf("ctx.Deadline() = %v want nil", done)
   844  	}
   845  	if err := ctx.Err(); err != nil {
   846  		t.Errorf("ctx.Err() = %v want nil", err)
   847  	}
   848  	if v := ctx.Value(key); v != value {
   849  		t.Errorf("ctx.Value(%q) = %q want %q", key, v, value)
   850  	}
   851  }
   852  
   853  type customDoneContext struct {
   854  	Context
   855  	donec chan struct{}
   856  }
   857  
   858  func (c *customDoneContext) Done() <-chan struct{} {
   859  	return c.donec
   860  }
   861  
   862  func TestCustomContextPropagation(t *testing.T) {
   863  	cause := errors.New("TestCustomContextPropagation")
   864  	donec := make(chan struct{})
   865  	ctx1, cancel1 := WithCancelCause(Background())
   866  	ctx2 := &customDoneContext{
   867  		Context: ctx1,
   868  		donec:   donec,
   869  	}
   870  	ctx3, cancel3 := WithCancel(ctx2)
   871  	defer cancel3()
   872  
   873  	cancel1(cause)
   874  	close(donec)
   875  
   876  	<-ctx3.Done()
   877  	if got, want := ctx3.Err(), Canceled; got != want {
   878  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   879  	}
   880  	if got, want := Cause(ctx3), cause; got != want {
   881  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   882  	}
   883  }
   884  
   885  // customCauseContext is a custom Context used to test context.Cause.
   886  type customCauseContext struct {
   887  	mu   sync.Mutex
   888  	done chan struct{}
   889  	err  error
   890  
   891  	cancelChild CancelFunc
   892  }
   893  
   894  func (ccc *customCauseContext) Deadline() (deadline time.Time, ok bool) {
   895  	return
   896  }
   897  
   898  func (ccc *customCauseContext) Done() <-chan struct{} {
   899  	ccc.mu.Lock()
   900  	defer ccc.mu.Unlock()
   901  	return ccc.done
   902  }
   903  
   904  func (ccc *customCauseContext) Err() error {
   905  	ccc.mu.Lock()
   906  	defer ccc.mu.Unlock()
   907  	return ccc.err
   908  }
   909  
   910  func (ccc *customCauseContext) Value(key any) any {
   911  	return nil
   912  }
   913  
   914  func (ccc *customCauseContext) cancel() {
   915  	ccc.mu.Lock()
   916  	ccc.err = Canceled
   917  	close(ccc.done)
   918  	cancelChild := ccc.cancelChild
   919  	ccc.mu.Unlock()
   920  
   921  	if cancelChild != nil {
   922  		cancelChild()
   923  	}
   924  }
   925  
   926  func (ccc *customCauseContext) setCancelChild(cancelChild CancelFunc) {
   927  	ccc.cancelChild = cancelChild
   928  }
   929  
   930  func TestCustomContextCause(t *testing.T) {
   931  	// Test if we cancel a custom context, Err and Cause return Canceled.
   932  	ccc := &customCauseContext{
   933  		done: make(chan struct{}),
   934  	}
   935  	ccc.cancel()
   936  	if got := ccc.Err(); got != Canceled {
   937  		t.Errorf("ccc.Err() = %v, want %v", got, Canceled)
   938  	}
   939  	if got := Cause(ccc); got != Canceled {
   940  		t.Errorf("Cause(ccc) = %v, want %v", got, Canceled)
   941  	}
   942  
   943  	// Test that if we pass a custom context to WithCancelCause,
   944  	// and then cancel that child context with a cause,
   945  	// that the cause of the child canceled context is correct
   946  	// but that the parent custom context is not canceled.
   947  	ccc = &customCauseContext{
   948  		done: make(chan struct{}),
   949  	}
   950  	ctx, causeFunc := WithCancelCause(ccc)
   951  	cause := errors.New("TestCustomContextCause")
   952  	causeFunc(cause)
   953  	if got := ctx.Err(); got != Canceled {
   954  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   955  	}
   956  	if got := Cause(ctx); got != cause {
   957  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   958  	}
   959  	if got := ccc.Err(); got != nil {
   960  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, nil)
   961  	}
   962  	if got := Cause(ccc); got != nil {
   963  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, nil)
   964  	}
   965  
   966  	// Test that if we now cancel the parent custom context,
   967  	// the cause of the child canceled context is still correct,
   968  	// and the parent custom context is canceled without a cause.
   969  	ccc.cancel()
   970  	if got := ctx.Err(); got != Canceled {
   971  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   972  	}
   973  	if got := Cause(ctx); got != cause {
   974  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   975  	}
   976  	if got := ccc.Err(); got != Canceled {
   977  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   978  	}
   979  	if got := Cause(ccc); got != Canceled {
   980  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   981  	}
   982  
   983  	// Test that if we associate a custom context with a child,
   984  	// then canceling the custom context cancels the child.
   985  	ccc = &customCauseContext{
   986  		done: make(chan struct{}),
   987  	}
   988  	ctx, cancelFunc := WithCancel(ccc)
   989  	ccc.setCancelChild(cancelFunc)
   990  	ccc.cancel()
   991  	if got := ctx.Err(); got != Canceled {
   992  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   993  	}
   994  	if got := Cause(ctx); got != Canceled {
   995  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, Canceled)
   996  	}
   997  	if got := ccc.Err(); got != Canceled {
   998  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   999  	}
  1000  	if got := Cause(ccc); got != Canceled {
  1001  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
  1002  	}
  1003  }
  1004  
  1005  func TestAfterFuncCalledAfterCancel(t *testing.T) {
  1006  	ctx, cancel := WithCancel(Background())
  1007  	donec := make(chan struct{})
  1008  	stop := AfterFunc(ctx, func() {
  1009  		close(donec)
  1010  	})
  1011  	select {
  1012  	case <-donec:
  1013  		t.Fatalf("AfterFunc called before context is done")
  1014  	case <-time.After(shortDuration):
  1015  	}
  1016  	cancel()
  1017  	select {
  1018  	case <-donec:
  1019  	case <-time.After(veryLongDuration):
  1020  		t.Fatalf("AfterFunc not called after context is canceled")
  1021  	}
  1022  	if stop() {
  1023  		t.Fatalf("stop() = true, want false")
  1024  	}
  1025  }
  1026  
  1027  func TestAfterFuncCalledAfterTimeout(t *testing.T) {
  1028  	ctx, cancel := WithTimeout(Background(), shortDuration)
  1029  	defer cancel()
  1030  	donec := make(chan struct{})
  1031  	AfterFunc(ctx, func() {
  1032  		close(donec)
  1033  	})
  1034  	select {
  1035  	case <-donec:
  1036  	case <-time.After(veryLongDuration):
  1037  		t.Fatalf("AfterFunc not called after context is canceled")
  1038  	}
  1039  }
  1040  
  1041  func TestAfterFuncCalledImmediately(t *testing.T) {
  1042  	ctx, cancel := WithCancel(Background())
  1043  	cancel()
  1044  	donec := make(chan struct{})
  1045  	AfterFunc(ctx, func() {
  1046  		close(donec)
  1047  	})
  1048  	select {
  1049  	case <-donec:
  1050  	case <-time.After(veryLongDuration):
  1051  		t.Fatalf("AfterFunc not called for already-canceled context")
  1052  	}
  1053  }
  1054  
  1055  func TestAfterFuncNotCalledAfterStop(t *testing.T) {
  1056  	ctx, cancel := WithCancel(Background())
  1057  	donec := make(chan struct{})
  1058  	stop := AfterFunc(ctx, func() {
  1059  		close(donec)
  1060  	})
  1061  	if !stop() {
  1062  		t.Fatalf("stop() = false, want true")
  1063  	}
  1064  	cancel()
  1065  	select {
  1066  	case <-donec:
  1067  		t.Fatalf("AfterFunc called for already-canceled context")
  1068  	case <-time.After(shortDuration):
  1069  	}
  1070  	if stop() {
  1071  		t.Fatalf("stop() = true, want false")
  1072  	}
  1073  }
  1074  
  1075  // This test verifies that canceling a context does not block waiting for AfterFuncs to finish.
  1076  func TestAfterFuncCalledAsynchronously(t *testing.T) {
  1077  	ctx, cancel := WithCancel(Background())
  1078  	donec := make(chan struct{})
  1079  	stop := AfterFunc(ctx, func() {
  1080  		// The channel send blocks until donec is read from.
  1081  		donec <- struct{}{}
  1082  	})
  1083  	defer stop()
  1084  	cancel()
  1085  	// After cancel returns, read from donec and unblock the AfterFunc.
  1086  	select {
  1087  	case <-donec:
  1088  	case <-time.After(veryLongDuration):
  1089  		t.Fatalf("AfterFunc not called after context is canceled")
  1090  	}
  1091  }
  1092  

View as plain text