1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package tls
20
21
22
23
24
25
26 import (
27 "context"
28 "crypto"
29 "crypto/ecdsa"
30 "crypto/ed25519"
31 "crypto/mldsa"
32 "crypto/rsa"
33 "crypto/x509"
34 "encoding/pem"
35 "errors"
36 "fmt"
37 "net"
38 "os"
39 "strings"
40 )
41
42
43
44
45
46 func Server(conn net.Conn, config *Config) *Conn {
47 c := &Conn{
48 conn: conn,
49 config: config,
50 }
51 c.handshakeFn = c.serverHandshake
52 return c
53 }
54
55
56
57
58
59 func Client(conn net.Conn, config *Config) *Conn {
60 c := &Conn{
61 conn: conn,
62 config: config,
63 isClient: true,
64 }
65 c.handshakeFn = c.clientHandshake
66 return c
67 }
68
69
70 type listener struct {
71 net.Listener
72 config *Config
73 }
74
75
76
77 func (l *listener) Accept() (net.Conn, error) {
78 c, err := l.Listener.Accept()
79 if err != nil {
80 return nil, err
81 }
82 return Server(c, l.config), nil
83 }
84
85
86
87
88
89 func NewListener(inner net.Listener, config *Config) net.Listener {
90 l := new(listener)
91 l.Listener = inner
92 l.config = config
93 return l
94 }
95
96
97
98
99
100 func Listen(network, laddr string, config *Config) (net.Listener, error) {
101
102 if config == nil || len(config.Certificates) == 0 &&
103 config.GetCertificate == nil && config.GetConfigForClient == nil {
104 return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
105 }
106 l, err := net.Listen(network, laddr)
107 if err != nil {
108 return nil, err
109 }
110 return NewListener(l, config), nil
111 }
112
// timeoutError is a stateless error value describing a dial timeout. It
// reports itself as both a timeout and a temporary condition.
type timeoutError struct{}

// Error returns the fixed timeout message.
func (timeoutError) Error() string {
	return "tls: DialWithDialer timed out"
}

// Timeout always reports true.
func (timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (timeoutError) Temporary() bool {
	return true
}
118
119
120
121
122
123
124
125
126
127
128
129 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
130 return dial(context.Background(), dialer, network, addr, config)
131 }
132
133 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
134 if netDialer.Timeout != 0 {
135 var cancel context.CancelFunc
136 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
137 defer cancel()
138 }
139
140 if !netDialer.Deadline.IsZero() {
141 var cancel context.CancelFunc
142 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
143 defer cancel()
144 }
145
146 rawConn, err := netDialer.DialContext(ctx, network, addr)
147 if err != nil {
148 return nil, err
149 }
150
151 colonPos := strings.LastIndex(addr, ":")
152 if colonPos == -1 {
153 colonPos = len(addr)
154 }
155 hostname := addr[:colonPos]
156
157 if config == nil {
158 config = defaultConfig()
159 }
160
161
162 if config.ServerName == "" {
163
164 c := config.Clone()
165 c.ServerName = hostname
166 config = c
167 }
168
169 conn := Client(rawConn, config)
170 if err := conn.HandshakeContext(ctx); err != nil {
171 rawConn.Close()
172 return nil, err
173 }
174 return conn, nil
175 }
176
177
178
179
180
181
182
183 func Dial(network, addr string, config *Config) (*Conn, error) {
184 return DialWithDialer(new(net.Dialer), network, addr, config)
185 }
186
187
188
189 type Dialer struct {
190
191
192
193 NetDialer *net.Dialer
194
195
196
197
198
199 Config *Config
200 }
201
202
203
204
205
206
207
208
209 func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
210 return d.DialContext(context.Background(), network, addr)
211 }
212
213 func (d *Dialer) netDialer() *net.Dialer {
214 if d.NetDialer != nil {
215 return d.NetDialer
216 }
217 return new(net.Dialer)
218 }
219
220
221
222
223
224
225
226
227
228
229 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
230 c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
231 if err != nil {
232
233 return nil, err
234 }
235 return c, nil
236 }
237
238
239
240
241
242 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
243 certPEMBlock, err := os.ReadFile(certFile)
244 if err != nil {
245 return Certificate{}, err
246 }
247 keyPEMBlock, err := os.ReadFile(keyFile)
248 if err != nil {
249 return Certificate{}, err
250 }
251 return X509KeyPair(certPEMBlock, keyPEMBlock)
252 }
253
254
255
256 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
257 fail := func(err error) (Certificate, error) { return Certificate{}, err }
258
259 var cert Certificate
260 var skippedBlockTypes []string
261 for {
262 var certDERBlock *pem.Block
263 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
264 if certDERBlock == nil {
265 break
266 }
267 if certDERBlock.Type == "CERTIFICATE" {
268 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
269 } else {
270 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
271 }
272 }
273
274 if len(cert.Certificate) == 0 {
275 if len(skippedBlockTypes) == 0 {
276 return fail(errors.New("tls: failed to find any PEM data in certificate input"))
277 }
278 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
279 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
280 }
281 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
282 }
283
284 skippedBlockTypes = skippedBlockTypes[:0]
285 var keyDERBlock *pem.Block
286 for {
287 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
288 if keyDERBlock == nil {
289 if len(skippedBlockTypes) == 0 {
290 return fail(errors.New("tls: failed to find any PEM data in key input"))
291 }
292 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
293 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
294 }
295 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
296 }
297 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
298 break
299 }
300 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
301 }
302
303
304
305 x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
306 if err != nil {
307 return fail(err)
308 }
309 cert.Leaf = x509Cert
310
311 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
312 if err != nil {
313 return fail(err)
314 }
315
316 switch pub := x509Cert.PublicKey.(type) {
317 case *rsa.PublicKey:
318 priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
319 if !ok {
320 return fail(errors.New("tls: private key type does not match public key type"))
321 }
322 if !priv.PublicKey.Equal(pub) {
323 return fail(errors.New("tls: private key does not match public key"))
324 }
325 case *ecdsa.PublicKey:
326 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
327 if !ok {
328 return fail(errors.New("tls: private key type does not match public key type"))
329 }
330 if !priv.PublicKey.Equal(pub) {
331 return fail(errors.New("tls: private key does not match public key"))
332 }
333 case ed25519.PublicKey:
334 priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
335 if !ok {
336 return fail(errors.New("tls: private key type does not match public key type"))
337 }
338 if !priv.Public().(ed25519.PublicKey).Equal(pub) {
339 return fail(errors.New("tls: private key does not match public key"))
340 }
341 case *mldsa.PublicKey:
342 priv, ok := cert.PrivateKey.(*mldsa.PrivateKey)
343 if !ok {
344 return fail(errors.New("tls: private key type does not match public key type"))
345 }
346 if !priv.PublicKey().Equal(pub) {
347 return fail(errors.New("tls: private key does not match public key"))
348 }
349 default:
350 return fail(errors.New("tls: unknown public key algorithm"))
351 }
352
353 return cert, nil
354 }
355
356
357
358
359 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
360 key, err := x509.ParsePKCS8PrivateKey(der)
361 pkcs8Err := err
362 if err != nil {
363 key, err = x509.ParsePKCS1PrivateKey(der)
364 }
365 if err != nil {
366 key, err = x509.ParseECPrivateKey(der)
367 }
368 if err != nil {
369 return nil, fmt.Errorf("tls: failed to parse private key: %w", pkcs8Err)
370 }
371 switch key := key.(type) {
372 case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey, *mldsa.PrivateKey:
373 return key, nil
374 default:
375 return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
376 }
377 }
378
View as plain text