Source file src/context/afterfunc_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	"context"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  // afterFuncContext is a context that's not one of the types
    15  // defined in context.go, that supports registering AfterFuncs.
    16  type afterFuncContext struct {
    17  	mu         sync.Mutex
    18  	afterFuncs map[*byte]func()
    19  	done       chan struct{}
    20  	err        error
    21  }
    22  
    23  var _ context.Context = (*afterFuncContext)(nil)
    24  
    25  func (c *afterFuncContext) Deadline() (time.Time, bool) {
    26  	return time.Time{}, false
    27  }
    28  
    29  func (c *afterFuncContext) Done() <-chan struct{} {
    30  	c.mu.Lock()
    31  	defer c.mu.Unlock()
    32  	if c.done == nil {
    33  		c.done = make(chan struct{})
    34  	}
    35  	return c.done
    36  }
    37  
    38  func (c *afterFuncContext) Err() error {
    39  	c.mu.Lock()
    40  	defer c.mu.Unlock()
    41  	return c.err
    42  }
    43  
    44  func (c *afterFuncContext) Value(key any) any {
    45  	return nil
    46  }
    47  
    48  func (c *afterFuncContext) AfterFunc(f func()) func() bool {
    49  	c.mu.Lock()
    50  	defer c.mu.Unlock()
    51  	k := new(byte)
    52  	if c.afterFuncs == nil {
    53  		c.afterFuncs = make(map[*byte]func())
    54  	}
    55  	c.afterFuncs[k] = f
    56  	return func() bool {
    57  		c.mu.Lock()
    58  		defer c.mu.Unlock()
    59  		_, ok := c.afterFuncs[k]
    60  		delete(c.afterFuncs, k)
    61  		return ok
    62  	}
    63  }
    64  
    65  func (c *afterFuncContext) cancel(err error) {
    66  	c.mu.Lock()
    67  	defer c.mu.Unlock()
    68  	if c.err != nil {
    69  		return
    70  	}
    71  	c.err = err
    72  	for _, f := range c.afterFuncs {
    73  		go f()
    74  	}
    75  	c.afterFuncs = nil
    76  }
    77  
    78  func TestCustomContextAfterFuncCancel(t *testing.T) {
    79  	ctx0 := &afterFuncContext{}
    80  	ctx1, cancel := context.WithCancel(ctx0)
    81  	defer cancel()
    82  	ctx0.cancel(context.Canceled)
    83  	<-ctx1.Done()
    84  }
    85  
    86  func TestCustomContextAfterFuncTimeout(t *testing.T) {
    87  	ctx0 := &afterFuncContext{}
    88  	ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
    89  	defer cancel()
    90  	ctx0.cancel(context.Canceled)
    91  	<-ctx1.Done()
    92  }
    93  
    94  func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
    95  	ctx0 := &afterFuncContext{}
    96  	donec := make(chan struct{})
    97  	stop := context.AfterFunc(ctx0, func() {
    98  		close(donec)
    99  	})
   100  	defer stop()
   101  	ctx0.cancel(context.Canceled)
   102  	<-donec
   103  }
   104  
   105  func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
   106  	ctx0 := &afterFuncContext{}
   107  	_, cancel1 := context.WithCancel(ctx0)
   108  	_, cancel2 := context.WithCancel(ctx0)
   109  	if got, want := len(ctx0.afterFuncs), 2; got != want {
   110  		t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
   111  	}
   112  	cancel1()
   113  	cancel2()
   114  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   115  		t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
   116  	}
   117  }
   118  
   119  func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
   120  	ctx0 := &afterFuncContext{}
   121  	_, cancel := context.WithTimeout(ctx0, veryLongDuration)
   122  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   123  		t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
   124  	}
   125  	cancel()
   126  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   127  		t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
   128  	}
   129  }
   130  
   131  func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
   132  	ctx0 := &afterFuncContext{}
   133  	stop := context.AfterFunc(ctx0, func() {})
   134  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   135  		t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
   136  	}
   137  	stop()
   138  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   139  		t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
   140  	}
   141  }
   142  

View as plain text