1
2
3
4
5 package template
6
import (
	"errors"
	"fmt"
	"io"
	"net/url"
	"reflect"
	"strings"
	"sync"
	"unicode"
	"unicode/utf16"
	"unicode/utf8"
)
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// FuncMap is the type of the map defining the mapping from names to
// functions. Values are held as any; addValueFuncs, which converts a FuncMap
// into reflect.Values, panics if a name is not a valid identifier, if a value
// is not a function, or if the function's result signature is rejected by
// goodFunc (one result, or two results where the second is an error).
type FuncMap map[string]any
34
35
36
37
38
39 func builtins() FuncMap {
40 return FuncMap{
41 "and": and,
42 "call": emptyCall,
43 "html": HTMLEscaper,
44 "index": index,
45 "slice": slice,
46 "js": JSEscaper,
47 "len": length,
48 "not": not,
49 "or": or,
50 "print": fmt.Sprint,
51 "printf": fmt.Sprintf,
52 "println": fmt.Sprintln,
53 "urlquery": URLQueryEscaper,
54
55
56 "eq": eq,
57 "ge": ge,
58 "gt": gt,
59 "le": le,
60 "lt": lt,
61 "ne": ne,
62 }
63 }
64
65
66 var builtinFuncs = sync.OnceValue(func() map[string]reflect.Value {
67 funcMap := builtins()
68 m := make(map[string]reflect.Value, len(funcMap))
69 addValueFuncs(m, funcMap)
70 return m
71 })
72
73
74 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
75 for name, fn := range in {
76 if !goodName(name) {
77 panic(fmt.Errorf("function name %q is not a valid identifier", name))
78 }
79 v := reflect.ValueOf(fn)
80 if v.Kind() != reflect.Func {
81 panic("value for " + name + " not a function")
82 }
83 if err := goodFunc(name, v.Type()); err != nil {
84 panic(err)
85 }
86 out[name] = v
87 }
88 }
89
90
91
92 func addFuncs(out, in FuncMap) {
93 for name, fn := range in {
94 out[name] = fn
95 }
96 }
97
98
99 func goodFunc(name string, typ reflect.Type) error {
100
101 switch numOut := typ.NumOut(); {
102 case numOut == 1:
103 return nil
104 case numOut == 2 && typ.Out(1) == errorType:
105 return nil
106 case numOut == 2:
107 return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
108 default:
109 return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
110 }
111 }
112
113
// goodName reports whether name may be used as a template function name:
// it must be non-empty, begin with a letter or underscore, and contain only
// letters, digits and underscores (Unicode letters and digits included).
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			continue
		}
		// Not a letter or underscore: only a digit is allowed, and never first.
		if pos == 0 || !unicode.IsDigit(ch) {
			return false
		}
	}
	return true
}
129
130
131 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
132 if tmpl != nil && tmpl.common != nil {
133 tmpl.muFuncs.RLock()
134 defer tmpl.muFuncs.RUnlock()
135 if fn := tmpl.execFuncs[name]; fn.IsValid() {
136 return fn, false, true
137 }
138 }
139 if fn := builtinFuncs()[name]; fn.IsValid() {
140 return fn, true, true
141 }
142 return reflect.Value{}, false, false
143 }
144
145
146
147 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
148 if !value.IsValid() {
149 if !canBeNil(argType) {
150 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
151 }
152 value = reflect.Zero(argType)
153 }
154 if value.Type().AssignableTo(argType) {
155 return value, nil
156 }
157 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
158 value = value.Convert(argType)
159 return value, nil
160 }
161 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
162 }
163
// intLike reports whether typ is one of the signed or unsigned integer
// kinds, uintptr included.
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
173
174
// indexArg converts index, which must hold a signed or unsigned integer, to
// an int in the inclusive range [0, cap]. The upper bound is inclusive so the
// result can be used as a slice bound; callers that address an element must
// additionally reject x == cap.
//
// A nil (invalid) index, a non-integer index, or a value outside [0, cap] is
// reported as an error.
func indexArg(index reflect.Value, cap int) (int, error) {
	var x int64
	switch index.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		x = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// Values above math.MaxInt64 wrap to negative here and are rejected
		// by the range check below.
		x = int64(index.Uint())
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	// Compare in int64 before converting: on 32-bit platforms a truncating
	// int(x) maps e.g. 1<<32 to 0, which would slip past an int-based check.
	// Since cap is an int, x <= cap guarantees x fits in an int.
	if x < 0 || x > int64(cap) {
		return 0, fmt.Errorf("index out of range: %d", x)
	}
	return int(x), nil
}
192
193
194
195
196
197
198 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
199 item = indirectInterface(item)
200 if !item.IsValid() {
201 return reflect.Value{}, fmt.Errorf("index of untyped nil")
202 }
203 for _, index := range indexes {
204 index = indirectInterface(index)
205 var isNil bool
206 if item, isNil = indirect(item); isNil {
207 return reflect.Value{}, fmt.Errorf("index of nil pointer")
208 }
209 switch item.Kind() {
210 case reflect.Array, reflect.Slice, reflect.String:
211 x, err := indexArg(index, item.Len())
212 if err != nil {
213 return reflect.Value{}, err
214 }
215 item = item.Index(x)
216 case reflect.Map:
217 index, err := prepareArg(index, item.Type().Key())
218 if err != nil {
219 return reflect.Value{}, err
220 }
221 if x := item.MapIndex(index); x.IsValid() {
222 item = x
223 } else {
224 item = reflect.Zero(item.Type().Elem())
225 }
226 case reflect.Invalid:
227
228 panic("unreachable")
229 default:
230 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
231 }
232 }
233 return item, nil
234 }
235
236
237
238
239
240
241
242 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
243 item = indirectInterface(item)
244 if !item.IsValid() {
245 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
246 }
247 if len(indexes) > 3 {
248 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
249 }
250 var cap int
251 switch item.Kind() {
252 case reflect.String:
253 if len(indexes) == 3 {
254 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
255 }
256 cap = item.Len()
257 case reflect.Array, reflect.Slice:
258 cap = item.Cap()
259 default:
260 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
261 }
262
263 idx := [3]int{0, item.Len()}
264 for i, index := range indexes {
265 x, err := indexArg(index, cap)
266 if err != nil {
267 return reflect.Value{}, err
268 }
269 idx[i] = x
270 }
271
272 if idx[0] > idx[1] {
273 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
274 }
275 if len(indexes) < 3 {
276 return item.Slice(idx[0], idx[1]), nil
277 }
278
279 if idx[1] > idx[2] {
280 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
281 }
282 return item.Slice3(idx[0], idx[1], idx[2]), nil
283 }
284
285
286
287
288 func length(item reflect.Value) (int, error) {
289 item, isNil := indirect(item)
290 if isNil {
291 return 0, fmt.Errorf("len of nil pointer")
292 }
293 switch item.Kind() {
294 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
295 return item.Len(), nil
296 }
297 return 0, fmt.Errorf("len of type %s", item.Type())
298 }
299
300
301
// emptyCall is the value registered under "call" in builtins, so the name
// resolves and its signature passes validation. Its body only panics
// ("unreachable").
// NOTE(review): presumably the evaluator special-cases "call" and dispatches
// to the call function in this file instead — confirm in exec.go.
func emptyCall(fn reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
305
306
307
308 func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
309 fn = indirectInterface(fn)
310 if !fn.IsValid() {
311 return reflect.Value{}, fmt.Errorf("call of nil")
312 }
313 typ := fn.Type()
314 if typ.Kind() != reflect.Func {
315 return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
316 }
317
318 if err := goodFunc(name, typ); err != nil {
319 return reflect.Value{}, err
320 }
321 numIn := typ.NumIn()
322 var dddType reflect.Type
323 if typ.IsVariadic() {
324 if len(args) < numIn-1 {
325 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
326 }
327 dddType = typ.In(numIn - 1).Elem()
328 } else {
329 if len(args) != numIn {
330 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
331 }
332 }
333 argv := make([]reflect.Value, len(args))
334 for i, arg := range args {
335 arg = indirectInterface(arg)
336
337 argType := dddType
338 if !typ.IsVariadic() || i < numIn-1 {
339 argType = typ.In(i)
340 }
341
342 var err error
343 if argv[i], err = prepareArg(arg, argType); err != nil {
344 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
345 }
346 }
347 return safeCall(fn, argv)
348 }
349
350
351
// safeCall calls fun with args and returns its first result. If fun returns
// a second result that is non-nil, that result is returned as the error.
// A panic during the call is recovered and turned into err: panic values that
// are errors are returned as-is, anything else is formatted with %v, and val
// is then the zero Value.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		switch p := r.(type) {
		case error:
			err = p
		default:
			err = fmt.Errorf("%v", p)
		}
	}()
	out := fun.Call(args)
	var callErr error
	if len(out) == 2 && !out[1].IsNil() {
		callErr = out[1].Interface().(error)
	}
	return out[0], callErr
}
368
369
370
371 func truth(arg reflect.Value) bool {
372 t, _ := isTrue(indirectInterface(arg))
373 return t
374 }
375
376
377
// and is the value registered as the "and" builtin, so the name resolves and
// its signature passes validation. Its body only panics ("unreachable").
// NOTE(review): presumably the evaluator implements "and" itself (evaluating
// arguments lazily) and never calls this function — confirm in exec.go.
func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
381
382
383
// or is the value registered as the "or" builtin, so the name resolves and
// its signature passes validation. Its body only panics ("unreachable").
// NOTE(review): presumably the evaluator implements "or" itself (evaluating
// arguments lazily) and never calls this function — confirm in exec.go.
func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
387
388
389 func not(arg reflect.Value) bool {
390 return !truth(arg)
391 }
392
393
394
395
396
// Errors returned by the comparison functions.
var (
	// errBadComparisonType reports an operand whose kind the comparison does
	// not support (returned by basicKind and by lt for bools/complex values).
	errBadComparisonType = errors.New("invalid type for comparison")
	// errNoComparison is returned by eq when no second operand is supplied.
	errNoComparison = errors.New("missing argument for comparison")
)

// kind is the coarse category a reflect.Value falls into for comparison
// purposes. invalidKind is the zero value; the constant order is fixed.
type kind int

const (
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)

// basicKind maps v to its comparison category. Anything that is not a bool,
// signed or unsigned integer, float, complex number or string (including the
// zero Value) yields invalidKind and errBadComparisonType.
func basicKind(v reflect.Value) (kind, error) {
	var k kind
	switch v.Kind() {
	case reflect.Bool:
		k = boolKind
	case reflect.String:
		k = stringKind
	case reflect.Float32, reflect.Float64:
		k = floatKind
	case reflect.Complex64, reflect.Complex128:
		k = complexKind
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		k = uintKind
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		k = intKind
	default:
		return invalidKind, errBadComparisonType
	}
	return k, nil
}
431
432
// isNil reports whether v is the zero Value or a nil chan, func, interface,
// map, pointer or slice. Values of any other kind are never nil.
func isNil(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Invalid:
		return true
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return v.IsNil()
	default:
		return false
	}
}
443
444
445
// canCompare reports whether v1 and v2 have the same reflect.Kind, or
// whether either one is the zero Value (an untyped nil).
func canCompare(v1, v2 reflect.Value) bool {
	k1, k2 := v1.Kind(), v2.Kind()
	return k1 == k2 || k1 == reflect.Invalid || k2 == reflect.Invalid
}
455
456
457 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
458 arg1 = indirectInterface(arg1)
459 if len(arg2) == 0 {
460 return false, errNoComparison
461 }
462 k1, _ := basicKind(arg1)
463 for _, arg := range arg2 {
464 arg = indirectInterface(arg)
465 k2, _ := basicKind(arg)
466 truth := false
467 if k1 != k2 {
468
469 switch {
470 case k1 == intKind && k2 == uintKind:
471 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
472 case k1 == uintKind && k2 == intKind:
473 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
474 default:
475 if arg1.IsValid() && arg.IsValid() {
476 return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg.Type())
477 }
478 }
479 } else {
480 switch k1 {
481 case boolKind:
482 truth = arg1.Bool() == arg.Bool()
483 case complexKind:
484 truth = arg1.Complex() == arg.Complex()
485 case floatKind:
486 truth = arg1.Float() == arg.Float()
487 case intKind:
488 truth = arg1.Int() == arg.Int()
489 case stringKind:
490 truth = arg1.String() == arg.String()
491 case uintKind:
492 truth = arg1.Uint() == arg.Uint()
493 default:
494 if !canCompare(arg1, arg) {
495 return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
496 }
497 if isNil(arg1) || isNil(arg) {
498 truth = isNil(arg) == isNil(arg1)
499 } else {
500 if !arg.Type().Comparable() {
501 return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
502 }
503 truth = arg1.Interface() == arg.Interface()
504 }
505 }
506 }
507 if truth {
508 return true, nil
509 }
510 }
511 return false, nil
512 }
513
514
515 func ne(arg1, arg2 reflect.Value) (bool, error) {
516
517 equal, err := eq(arg1, arg2)
518 return !equal, err
519 }
520
521
522 func lt(arg1, arg2 reflect.Value) (bool, error) {
523 arg1 = indirectInterface(arg1)
524 k1, err := basicKind(arg1)
525 if err != nil {
526 return false, err
527 }
528 arg2 = indirectInterface(arg2)
529 k2, err := basicKind(arg2)
530 if err != nil {
531 return false, err
532 }
533 truth := false
534 if k1 != k2 {
535
536 switch {
537 case k1 == intKind && k2 == uintKind:
538 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
539 case k1 == uintKind && k2 == intKind:
540 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
541 default:
542 return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg2.Type())
543 }
544 } else {
545 switch k1 {
546 case boolKind, complexKind:
547 return false, errBadComparisonType
548 case floatKind:
549 truth = arg1.Float() < arg2.Float()
550 case intKind:
551 truth = arg1.Int() < arg2.Int()
552 case stringKind:
553 truth = arg1.String() < arg2.String()
554 case uintKind:
555 truth = arg1.Uint() < arg2.Uint()
556 default:
557 panic("invalid kind")
558 }
559 }
560 return truth, nil
561 }
562
563
564 func le(arg1, arg2 reflect.Value) (bool, error) {
565
566 lessThan, err := lt(arg1, arg2)
567 if lessThan || err != nil {
568 return lessThan, err
569 }
570 return eq(arg1, arg2)
571 }
572
573
574 func gt(arg1, arg2 reflect.Value) (bool, error) {
575
576 lessOrEqual, err := le(arg1, arg2)
577 if err != nil {
578 return false, err
579 }
580 return !lessOrEqual, nil
581 }
582
583
584 func ge(arg1, arg2 reflect.Value) (bool, error) {
585
586 lessThan, err := lt(arg1, arg2)
587 if err != nil {
588 return false, err
589 }
590 return !lessThan, nil
591 }
592
593
594
// Replacement byte sequences written by HTMLEscape. The quote characters use
// numeric character references because they are shorter than the named
// forms; NUL is replaced by U+FFFD REPLACEMENT CHARACTER.
var (
	htmlQuot = []byte("&#34;") // shorter than "&quot;"
	htmlApos = []byte("&#39;") // shorter than "&apos;"; apos was not in HTML until HTML5
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD")
)
603
604
605 func HTMLEscape(w io.Writer, b []byte) {
606 last := 0
607 for i, c := range b {
608 var html []byte
609 switch c {
610 case '\000':
611 html = htmlNull
612 case '"':
613 html = htmlQuot
614 case '\'':
615 html = htmlApos
616 case '&':
617 html = htmlAmp
618 case '<':
619 html = htmlLt
620 case '>':
621 html = htmlGt
622 default:
623 continue
624 }
625 w.Write(b[last:i])
626 w.Write(html)
627 last = i + 1
628 }
629 w.Write(b[last:])
630 }
631
632
633 func HTMLEscapeString(s string) string {
634
635 if !strings.ContainsAny(s, "'\"&<>\000") {
636 return s
637 }
638 var b strings.Builder
639 HTMLEscape(&b, []byte(s))
640 return b.String()
641 }
642
643
644
645 func HTMLEscaper(args ...any) string {
646 return HTMLEscapeString(evalArgs(args))
647 }
648
649
650
// Byte sequences written by JSEscape.
//
// jsLowUni followed by two digits from hex forms a \u00XX escape for ASCII
// control characters; hex holds the uppercase hexadecimal digits. Backslash
// and the quotes get backslash escapes, while < > & = are written as \u
// escapes rather than literally.
var (
	jsLowUni = []byte(`\u00`)
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\u003C`)
	jsGt        = []byte(`\u003E`)
	jsAmp       = []byte(`\u0026`)
	jsEq        = []byte(`\u003D`)
)
663
664
665 func JSEscape(w io.Writer, b []byte) {
666 last := 0
667 for i := 0; i < len(b); i++ {
668 c := b[i]
669
670 if !jsIsSpecial(rune(c)) {
671
672 continue
673 }
674 w.Write(b[last:i])
675
676 if c < utf8.RuneSelf {
677
678
679 switch c {
680 case '\\':
681 w.Write(jsBackslash)
682 case '\'':
683 w.Write(jsApos)
684 case '"':
685 w.Write(jsQuot)
686 case '<':
687 w.Write(jsLt)
688 case '>':
689 w.Write(jsGt)
690 case '&':
691 w.Write(jsAmp)
692 case '=':
693 w.Write(jsEq)
694 default:
695 w.Write(jsLowUni)
696 t, b := c>>4, c&0x0f
697 w.Write(hex[t : t+1])
698 w.Write(hex[b : b+1])
699 }
700 } else {
701
702 r, size := utf8.DecodeRune(b[i:])
703 if unicode.IsPrint(r) {
704 w.Write(b[i : i+size])
705 } else {
706 fmt.Fprintf(w, "\\u%04X", r)
707 }
708 i += size - 1
709 }
710 last = i + 1
711 }
712 w.Write(b[last:])
713 }
714
715
716 func JSEscapeString(s string) string {
717
718 if strings.IndexFunc(s, jsIsSpecial) < 0 {
719 return s
720 }
721 var b strings.Builder
722 JSEscape(&b, []byte(s))
723 return b.String()
724 }
725
// jsIsSpecial reports whether r needs attention from JSEscape: any rune
// below ' ' (control characters) or at or above utf8.RuneSelf (non-ASCII),
// and the characters \ ' " < > & =.
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	default:
		return false
	}
}
733
734
735
736 func JSEscaper(args ...any) string {
737 return JSEscapeString(evalArgs(args))
738 }
739
740
741
742 func URLQueryEscaper(args ...any) string {
743 return url.QueryEscape(evalArgs(args))
744 }
745
746
747
748
749
750
751
752
753 func evalArgs(args []any) string {
754 ok := false
755 var s string
756
757 if len(args) == 1 {
758 s, ok = args[0].(string)
759 }
760 if !ok {
761 for i, arg := range args {
762 a, ok := printableValue(reflect.ValueOf(arg))
763 if ok {
764 args[i] = a
765 }
766 }
767 s = fmt.Sprint(args...)
768 }
769 return s
770 }
771
View as plain text