Source file
src/context/afterfunc_test.go
1
2
3
4
5 package context_test
6
7 import (
8 "context"
9 "sync"
10 "testing"
11 "time"
12 )
13
14
15
// afterFuncContext is a context.Context that is not one of the types created
// by package context. Besides the Context methods it has an
// AfterFunc(func()) func() bool method; the tests in this file check that
// context.WithCancel, context.WithTimeout, and context.AfterFunc register
// callbacks on a custom parent through that method, and unregister them when
// the derived context or registration is stopped.
//
// The zero value is a live (uncanceled) context; call cancel to cancel it.
type afterFuncContext struct {
	mu         sync.Mutex       // guards every field below
	afterFuncs map[*byte]func() // pending callbacks, keyed by a per-registration token
	done       chan struct{}    // allocated lazily by Done or cancel; closed by cancel
	err        error            // nil until cancel is called
}

var _ context.Context = (*afterFuncContext)(nil)

// Deadline reports that c has no deadline.
func (c *afterFuncContext) Deadline() (time.Time, bool) {
	return time.Time{}, false
}

// Done returns the channel that is closed when c is canceled, allocating it
// on first use.
func (c *afterFuncContext) Done() <-chan struct{} {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.done == nil {
		c.done = make(chan struct{})
	}
	return c.done
}

// Err returns nil until c is canceled, and afterwards the error passed to the
// first call of cancel.
func (c *afterFuncContext) Err() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.err
}

// Value carries no values; it always returns nil.
func (c *afterFuncContext) Value(key any) any {
	return nil
}

// AfterFunc arranges for f to run in its own goroutine once c is canceled and
// returns a stop function that unregisters f. stop returns true if it
// prevented f from being run, and false if f was already started or already
// unregistered.
//
// If c is already canceled, f is started immediately (and stop reports
// false), mirroring context.AfterFunc's behavior for a context that is
// already done.
func (c *afterFuncContext) AfterFunc(f func()) func() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.err != nil {
		go f()
		return func() bool { return false }
	}
	// A fresh *byte is a unique key per registration. new(struct{}) would not
	// do: pointers to distinct zero-size variables may compare equal.
	k := new(byte)
	if c.afterFuncs == nil {
		c.afterFuncs = make(map[*byte]func())
	}
	c.afterFuncs[k] = f
	return func() bool {
		c.mu.Lock()
		defer c.mu.Unlock()
		_, ok := c.afterFuncs[k]
		delete(c.afterFuncs, k)
		return ok
	}
}

// cancel cancels c with err. It records err as the result of Err, closes the
// Done channel, and starts every registered AfterFunc callback in its own
// goroutine. Only the first call has any effect. err must be non-nil.
func (c *afterFuncContext) cancel(err error) {
	if err == nil {
		// A nil err would defeat the "already canceled" check below and
		// lead to closing done twice.
		panic("afterFuncContext: cancel called with nil error")
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.err != nil {
		return
	}
	c.err = err
	// Close Done so that a non-nil Err and a closed Done agree, as the
	// context.Context contract requires. Allocate the channel first if Done
	// has not been called yet, so a later Done call returns it closed.
	if c.done == nil {
		c.done = make(chan struct{})
	}
	close(c.done)
	for _, f := range c.afterFuncs {
		go f()
	}
	c.afterFuncs = nil
}
77
78 func TestCustomContextAfterFuncCancel(t *testing.T) {
79 ctx0 := &afterFuncContext{}
80 ctx1, cancel := context.WithCancel(ctx0)
81 defer cancel()
82 ctx0.cancel(context.Canceled)
83 <-ctx1.Done()
84 }
85
86 func TestCustomContextAfterFuncTimeout(t *testing.T) {
87 ctx0 := &afterFuncContext{}
88 ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
89 defer cancel()
90 ctx0.cancel(context.Canceled)
91 <-ctx1.Done()
92 }
93
94 func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
95 ctx0 := &afterFuncContext{}
96 donec := make(chan struct{})
97 stop := context.AfterFunc(ctx0, func() {
98 close(donec)
99 })
100 defer stop()
101 ctx0.cancel(context.Canceled)
102 <-donec
103 }
104
105 func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
106 ctx0 := &afterFuncContext{}
107 _, cancel1 := context.WithCancel(ctx0)
108 _, cancel2 := context.WithCancel(ctx0)
109 if got, want := len(ctx0.afterFuncs), 2; got != want {
110 t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
111 }
112 cancel1()
113 cancel2()
114 if got, want := len(ctx0.afterFuncs), 0; got != want {
115 t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
116 }
117 }
118
119 func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
120 ctx0 := &afterFuncContext{}
121 _, cancel := context.WithTimeout(ctx0, veryLongDuration)
122 if got, want := len(ctx0.afterFuncs), 1; got != want {
123 t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
124 }
125 cancel()
126 if got, want := len(ctx0.afterFuncs), 0; got != want {
127 t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
128 }
129 }
130
131 func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
132 ctx0 := &afterFuncContext{}
133 stop := context.AfterFunc(ctx0, func() {})
134 if got, want := len(ctx0.afterFuncs), 1; got != want {
135 t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
136 }
137 stop()
138 if got, want := len(ctx0.afterFuncs), 0; got != want {
139 t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
140 }
141 }
142
View as plain text