1
2
3
4
5 package fips140cache
6
7 import (
8 "context"
9 "errors"
10 "runtime"
11 "sync"
12 "testing"
13 "time"
14 )
15
16 func TestCache(t *testing.T) {
17 c := new(Cache[key, value])
18 checkTrue := func(*value) bool { return true }
19 checkFalse := func(*value) bool { return false }
20 newNotCalled := func() (*value, error) {
21 t.Helper()
22 t.Fatal("new called")
23 return nil, nil
24 }
25
26 k1 := newKey()
27 v1 := &value{}
28
29 v, err := c.Get(k1, func() (*value, error) { return v1, nil }, checkTrue)
30 expectValue(t, v, err, v1)
31
32
33 v, err = c.Get(k1, newNotCalled, checkTrue)
34 expectValue(t, v, err, v1)
35
36
37 v2 := &value{}
38 v, err = c.Get(k1, func() (*value, error) { return v2, nil }, checkFalse)
39 expectValue(t, v, err, v2)
40 v, err = c.Get(k1, newNotCalled, checkTrue)
41 expectValue(t, v, err, v2)
42 expectMapSize(t, c, 1)
43
44
45 waitUnreachable(t, &k1)
46 expectMapSize(t, c, 0)
47
48
49 k2 := newKey()
50 err1 := errors.New("error")
51 _, err = c.Get(k2, func() (*value, error) { return nil, err1 }, checkTrue)
52 if err != err1 {
53 t.Errorf("got %v, want %v", err, err1)
54 }
55 expectMapSize(t, c, 0)
56
57
58 v, err = c.Get(k2, func() (*value, error) { return v1, nil }, checkTrue)
59 expectValue(t, v, err, v1)
60 _, err = c.Get(k2, func() (*value, error) { return v2, err1 }, checkFalse)
61 if err != err1 {
62 t.Errorf("got %v, want %v", err, err1)
63 }
64 v, err = c.Get(k2, newNotCalled, checkTrue)
65 expectValue(t, v, err, v1)
66 expectMapSize(t, c, 1)
67
68
69 k3 := newKey()
70 v, err = c.Get(k3, func() (*value, error) { return v1, nil }, checkTrue)
71 expectValue(t, v, err, v1)
72 expectMapSize(t, c, 2)
73 waitUnreachable(t, &k2)
74 waitUnreachable(t, &k3)
75 expectMapSize(t, c, 0)
76
77
78
79 keys := make([]*key, 100)
80 for i := range keys {
81 keys[i] = newKey()
82 v1, v2 := &value{}, &value{}
83 start := make(chan struct{})
84 var wg sync.WaitGroup
85 wg.Add(2)
86 go func() {
87 <-start
88 v, err := c.Get(keys[i], func() (*value, error) { return v1, nil }, checkTrue)
89 expectValue(t, v, err, v1, v2)
90 wg.Done()
91 }()
92 go func() {
93 <-start
94 v, err := c.Get(keys[i], func() (*value, error) { return v2, nil }, checkTrue)
95 expectValue(t, v, err, v1, v2)
96 wg.Done()
97 }()
98 close(start)
99 wg.Wait()
100 v3 := &value{}
101 v, err := c.Get(keys[i], func() (*value, error) { return v3, nil }, checkTrue)
102 expectValue(t, v, err, v1, v2)
103 }
104 for i := range keys {
105 waitUnreachable(t, &keys[i])
106 }
107 expectMapSize(t, c, 0)
108 }
109
// key is the cache key type used by these tests.
//
// Its only field is a blank pointer. This gives key a non-zero size, so
// separately allocated keys never share an address. It also makes key a
// pointer-containing type. Presumably that keeps keys out of the runtime's
// tiny-allocator batching of small pointer-free objects, which could delay
// the cleanups the test waits for (confirm).
type key struct{ _ *int }
113
// value is the cached value type used by these tests.
//
// The blank pointer field gives value a non-zero size, so separately
// allocated values have distinct addresses and can be told apart by the
// pointer comparisons the test performs.
type value struct{ _ *int }
117
118
119
120
121 func newKey() *key {
122 return &key{}
123 }
124
125 func expectValue(t *testing.T, v *value, err error, want ...*value) {
126 t.Helper()
127 if err != nil {
128 t.Fatal(err)
129 }
130 for _, w := range want {
131 if v == w {
132 return
133 }
134 }
135 t.Errorf("got %p, want %p", v, want)
136 }
137
138 func expectMapSize(t *testing.T, c *Cache[key, value], want int) {
139 t.Helper()
140 var size int
141
142 for range 10 {
143 size = 0
144 c.m.Range(func(_, _ any) bool {
145 size++
146 return true
147 })
148 if size == want {
149 return
150 }
151 time.Sleep(100 * time.Millisecond)
152 }
153 t.Errorf("got %d, want %d", size, want)
154 }
155
156 func waitUnreachable(t *testing.T, k **key) {
157 ctx, cancel := context.WithCancel(t.Context())
158 defer cancel()
159 runtime.AddCleanup(*k, func(_ *int) { cancel() }, nil)
160 *k = nil
161 for ctx.Err() == nil {
162 runtime.GC()
163 }
164 if ctx.Err() != context.Canceled {
165 t.Fatal(ctx.Err())
166 }
167 }
168
View as plain text