1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package tls
20
21
22
23
24
25
26 import (
27 "bytes"
28 "context"
29 "crypto"
30 "crypto/ecdsa"
31 "crypto/ed25519"
32 "crypto/rsa"
33 "crypto/x509"
34 "encoding/pem"
35 "errors"
36 "fmt"
37 "internal/godebug"
38 "net"
39 "os"
40 "strings"
41 )
42
43
44
45
46
47 func Server(conn net.Conn, config *Config) *Conn {
48 c := &Conn{
49 conn: conn,
50 config: config,
51 }
52 c.handshakeFn = c.serverHandshake
53 return c
54 }
55
56
57
58
59
60 func Client(conn net.Conn, config *Config) *Conn {
61 c := &Conn{
62 conn: conn,
63 config: config,
64 isClient: true,
65 }
66 c.handshakeFn = c.clientHandshake
67 return c
68 }
69
70
// listener wraps an inner net.Listener so that every accepted connection is
// handed back as a server-side TLS *Conn built with config.
type listener struct {
	// Listener is the embedded inner listener that yields raw connections.
	net.Listener
	// config is passed to Server for each accepted connection.
	config *Config
}
75
76
77
78 func (l *listener) Accept() (net.Conn, error) {
79 c, err := l.Listener.Accept()
80 if err != nil {
81 return nil, err
82 }
83 return Server(c, l.config), nil
84 }
85
86
87
88
89
90 func NewListener(inner net.Listener, config *Config) net.Listener {
91 l := new(listener)
92 l.Listener = inner
93 l.config = config
94 return l
95 }
96
97
98
99
100
101 func Listen(network, laddr string, config *Config) (net.Listener, error) {
102
103 if config == nil || len(config.Certificates) == 0 &&
104 config.GetCertificate == nil && config.GetConfigForClient == nil {
105 return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
106 }
107 l, err := net.Listen(network, laddr)
108 if err != nil {
109 return nil, err
110 }
111 return NewListener(l, config), nil
112 }
113
// timeoutError is an error whose Timeout and Temporary methods both report
// true, giving it the method set of net.Error. Its message names
// DialWithDialer.
type timeoutError struct{}

func (e timeoutError) Error() string {
	return "tls: DialWithDialer timed out"
}

func (e timeoutError) Timeout() bool {
	return true
}

func (e timeoutError) Temporary() bool {
	return true
}
119
120
121
122
123
124
125
126
127
128
129
130 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
131 return dial(context.Background(), dialer, network, addr, config)
132 }
133
134 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
135 if netDialer.Timeout != 0 {
136 var cancel context.CancelFunc
137 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
138 defer cancel()
139 }
140
141 if !netDialer.Deadline.IsZero() {
142 var cancel context.CancelFunc
143 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
144 defer cancel()
145 }
146
147 rawConn, err := netDialer.DialContext(ctx, network, addr)
148 if err != nil {
149 return nil, err
150 }
151
152 colonPos := strings.LastIndex(addr, ":")
153 if colonPos == -1 {
154 colonPos = len(addr)
155 }
156 hostname := addr[:colonPos]
157
158 if config == nil {
159 config = defaultConfig()
160 }
161
162
163 if config.ServerName == "" {
164
165 c := config.Clone()
166 c.ServerName = hostname
167 config = c
168 }
169
170 conn := Client(rawConn, config)
171 if err := conn.HandshakeContext(ctx); err != nil {
172 rawConn.Close()
173 return nil, err
174 }
175 return conn, nil
176 }
177
178
179
180
181
182
183
184 func Dial(network, addr string, config *Config) (*Conn, error) {
185 return DialWithDialer(new(net.Dialer), network, addr, config)
186 }
187
188
189
// Dialer dials TLS connections given a TLS configuration and a net.Dialer
// for the underlying connection.
type Dialer struct {
	// NetDialer is the dialer used for the underlying connection. If nil,
	// a zero-value net.Dialer is allocated for each dial. Its Timeout and
	// Deadline, when set, also bound the TLS handshake.
	NetDialer *net.Dialer

	// Config is the TLS configuration for new connections. It is passed
	// through to dial unchanged; a nil Config is replaced there by
	// defaultConfig(), and an empty ServerName is filled in on a clone.
	Config *Config
}
202
203
204
205
206
207
208
209
210 func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
211 return d.DialContext(context.Background(), network, addr)
212 }
213
214 func (d *Dialer) netDialer() *net.Dialer {
215 if d.NetDialer != nil {
216 return d.NetDialer
217 }
218 return new(net.Dialer)
219 }
220
221
222
223
224
225
226
227
228
229
230 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
231 c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
232 if err != nil {
233
234 return nil, err
235 }
236 return c, nil
237 }
238
239
240
241
242
243
244
245
246
247 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
248 certPEMBlock, err := os.ReadFile(certFile)
249 if err != nil {
250 return Certificate{}, err
251 }
252 keyPEMBlock, err := os.ReadFile(keyFile)
253 if err != nil {
254 return Certificate{}, err
255 }
256 return X509KeyPair(certPEMBlock, keyPEMBlock)
257 }
258
// x509keypairleaf is the "x509keypairleaf" godebug setting read by
// X509KeyPair. When its value is "0", X509KeyPair leaves Certificate.Leaf
// unset and records a non-default use. Any other value makes X509KeyPair
// populate Leaf with the parsed first certificate.
var x509keypairleaf = godebug.New("x509keypairleaf")
260
261
262
263
264
265
266
267 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
268 fail := func(err error) (Certificate, error) { return Certificate{}, err }
269
270 var cert Certificate
271 var skippedBlockTypes []string
272 for {
273 var certDERBlock *pem.Block
274 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
275 if certDERBlock == nil {
276 break
277 }
278 if certDERBlock.Type == "CERTIFICATE" {
279 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
280 } else {
281 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
282 }
283 }
284
285 if len(cert.Certificate) == 0 {
286 if len(skippedBlockTypes) == 0 {
287 return fail(errors.New("tls: failed to find any PEM data in certificate input"))
288 }
289 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
290 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
291 }
292 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
293 }
294
295 skippedBlockTypes = skippedBlockTypes[:0]
296 var keyDERBlock *pem.Block
297 for {
298 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
299 if keyDERBlock == nil {
300 if len(skippedBlockTypes) == 0 {
301 return fail(errors.New("tls: failed to find any PEM data in key input"))
302 }
303 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
304 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
305 }
306 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
307 }
308 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
309 break
310 }
311 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
312 }
313
314
315
316 x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
317 if err != nil {
318 return fail(err)
319 }
320
321 if x509keypairleaf.Value() != "0" {
322 cert.Leaf = x509Cert
323 } else {
324 x509keypairleaf.IncNonDefault()
325 }
326
327 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
328 if err != nil {
329 return fail(err)
330 }
331
332 switch pub := x509Cert.PublicKey.(type) {
333 case *rsa.PublicKey:
334 priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
335 if !ok {
336 return fail(errors.New("tls: private key type does not match public key type"))
337 }
338 if pub.N.Cmp(priv.N) != 0 {
339 return fail(errors.New("tls: private key does not match public key"))
340 }
341 case *ecdsa.PublicKey:
342 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
343 if !ok {
344 return fail(errors.New("tls: private key type does not match public key type"))
345 }
346 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
347 return fail(errors.New("tls: private key does not match public key"))
348 }
349 case ed25519.PublicKey:
350 priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
351 if !ok {
352 return fail(errors.New("tls: private key type does not match public key type"))
353 }
354 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
355 return fail(errors.New("tls: private key does not match public key"))
356 }
357 default:
358 return fail(errors.New("tls: unknown public key algorithm"))
359 }
360
361 return cert, nil
362 }
363
364
365
366
// parsePrivateKey decodes a DER-encoded private key. It tries three
// encodings in order: PKCS #1 (RSA), then PKCS #8, then SEC 1 (EC).
//
// A PKCS #8 key that parses but is not RSA, ECDSA or Ed25519 is rejected
// outright rather than falling through to the SEC 1 attempt. If no encoding
// matches, a generic parse error is returned.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
	if rsaKey, err := x509.ParsePKCS1PrivateKey(der); err == nil {
		return rsaKey, nil
	}

	if pkcs8Key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
		switch pkcs8Key.(type) {
		case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
			return pkcs8Key, nil
		}
		return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
	}

	if ecKey, err := x509.ParseECPrivateKey(der); err == nil {
		return ecKey, nil
	}

	return nil, errors.New("tls: failed to parse private key")
}
385
View as plain text