Source file src/crypto/internal/fips140cache/cache_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fips140cache
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"runtime"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  func TestCache(t *testing.T) {
    17  	c := new(Cache[key, value])
    18  	checkTrue := func(*value) bool { return true }
    19  	checkFalse := func(*value) bool { return false }
    20  	newNotCalled := func() (*value, error) {
    21  		t.Helper()
    22  		t.Fatal("new called")
    23  		return nil, nil
    24  	}
    25  
    26  	k1 := newKey()
    27  	v1 := &value{}
    28  
    29  	v, err := c.Get(k1, func() (*value, error) { return v1, nil }, checkTrue)
    30  	expectValue(t, v, err, v1)
    31  
    32  	// Cached value is returned if check is true.
    33  	v, err = c.Get(k1, newNotCalled, checkTrue)
    34  	expectValue(t, v, err, v1)
    35  
    36  	// New value is returned and cached if check is false.
    37  	v2 := &value{}
    38  	v, err = c.Get(k1, func() (*value, error) { return v2, nil }, checkFalse)
    39  	expectValue(t, v, err, v2)
    40  	v, err = c.Get(k1, newNotCalled, checkTrue)
    41  	expectValue(t, v, err, v2)
    42  	expectMapSize(t, c, 1)
    43  
    44  	// Cache is evicted when key becomes unreachable.
    45  	waitUnreachable(t, &k1)
    46  	expectMapSize(t, c, 0)
    47  
    48  	// Value is not cached if new returns an error.
    49  	k2 := newKey()
    50  	err1 := errors.New("error")
    51  	_, err = c.Get(k2, func() (*value, error) { return nil, err1 }, checkTrue)
    52  	if err != err1 {
    53  		t.Errorf("got %v, want %v", err, err1)
    54  	}
    55  	expectMapSize(t, c, 0)
    56  
    57  	// Value is not replaced if check is false and new returns an error.
    58  	v, err = c.Get(k2, func() (*value, error) { return v1, nil }, checkTrue)
    59  	expectValue(t, v, err, v1)
    60  	_, err = c.Get(k2, func() (*value, error) { return v2, err1 }, checkFalse)
    61  	if err != err1 {
    62  		t.Errorf("got %v, want %v", err, err1)
    63  	}
    64  	v, err = c.Get(k2, newNotCalled, checkTrue)
    65  	expectValue(t, v, err, v1)
    66  	expectMapSize(t, c, 1)
    67  
    68  	// Cache is evicted for keys used only once.
    69  	k3 := newKey()
    70  	v, err = c.Get(k3, func() (*value, error) { return v1, nil }, checkTrue)
    71  	expectValue(t, v, err, v1)
    72  	expectMapSize(t, c, 2)
    73  	waitUnreachable(t, &k2)
    74  	waitUnreachable(t, &k3)
    75  	expectMapSize(t, c, 0)
    76  
    77  	// When two goroutines race, the returned value may be the new or old one,
    78  	// but the map must shrink to 0.
    79  	keys := make([]*key, 100)
    80  	for i := range keys {
    81  		keys[i] = newKey()
    82  		v1, v2 := &value{}, &value{}
    83  		start := make(chan struct{})
    84  		var wg sync.WaitGroup
    85  		wg.Add(2)
    86  		go func() {
    87  			<-start
    88  			v, err := c.Get(keys[i], func() (*value, error) { return v1, nil }, checkTrue)
    89  			expectValue(t, v, err, v1, v2)
    90  			wg.Done()
    91  		}()
    92  		go func() {
    93  			<-start
    94  			v, err := c.Get(keys[i], func() (*value, error) { return v2, nil }, checkTrue)
    95  			expectValue(t, v, err, v1, v2)
    96  			wg.Done()
    97  		}()
    98  		close(start)
    99  		wg.Wait()
   100  		v3 := &value{}
   101  		v, err := c.Get(keys[i], func() (*value, error) { return v3, nil }, checkTrue)
   102  		expectValue(t, v, err, v1, v2)
   103  	}
   104  	for i := range keys {
   105  		waitUnreachable(t, &keys[i])
   106  	}
   107  	expectMapSize(t, c, 0)
   108  }
   109  
   110  type key struct {
   111  	_ *int
   112  }
   113  
   114  type value struct {
   115  	_ *int
   116  }
   117  
   118  // newKey allocates a key value on the heap.
   119  //
   120  //go:noinline
   121  func newKey() *key {
   122  	return &key{}
   123  }
   124  
   125  func expectValue(t *testing.T, v *value, err error, want ...*value) {
   126  	t.Helper()
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	for _, w := range want {
   131  		if v == w {
   132  			return
   133  		}
   134  	}
   135  	t.Errorf("got %p, want %p", v, want)
   136  }
   137  
   138  func expectMapSize(t *testing.T, c *Cache[key, value], want int) {
   139  	t.Helper()
   140  	var size int
   141  	// Loop a few times because the AddCleanup might not be done yet.
   142  	for range 10 {
   143  		size = 0
   144  		c.m.Range(func(_, _ any) bool {
   145  			size++
   146  			return true
   147  		})
   148  		if size == want {
   149  			return
   150  		}
   151  		time.Sleep(100 * time.Millisecond)
   152  	}
   153  	t.Errorf("got %d, want %d", size, want)
   154  }
   155  
   156  func waitUnreachable(t *testing.T, k **key) {
   157  	ctx, cancel := context.WithCancel(t.Context())
   158  	defer cancel()
   159  	runtime.AddCleanup(*k, func(_ *int) { cancel() }, nil)
   160  	*k = nil
   161  	for ctx.Err() == nil {
   162  		runtime.GC()
   163  	}
   164  	if ctx.Err() != context.Canceled {
   165  		t.Fatal(ctx.Err())
   166  	}
   167  }
   168  

View as plain text