1
2
3
4
5 package template
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net/url"
12 "reflect"
13 "strings"
14 "sync"
15 "unicode"
16 "unicode/utf8"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// FuncMap maps template function names to Go functions. Entries are
// validated by addValueFuncs, which panics unless the name is a valid
// identifier (goodName), the value is a function, and that function returns
// either one value or a value followed by an error (goodFunc).
type FuncMap map[string]any
34
35
36
37
38
39 func builtins() FuncMap {
40 return FuncMap{
41 "and": and,
42 "call": emptyCall,
43 "html": HTMLEscaper,
44 "index": index,
45 "slice": slice,
46 "js": JSEscaper,
47 "len": length,
48 "not": not,
49 "or": or,
50 "print": fmt.Sprint,
51 "printf": fmt.Sprintf,
52 "println": fmt.Sprintln,
53 "urlquery": URLQueryEscaper,
54
55
56 "eq": eq,
57 "ge": ge,
58 "gt": gt,
59 "le": le,
60 "lt": lt,
61 "ne": ne,
62 }
63 }
64
65
66 var builtinFuncs = sync.OnceValue(func() map[string]reflect.Value {
67 funcMap := builtins()
68 m := make(map[string]reflect.Value, len(funcMap))
69 addValueFuncs(m, funcMap)
70 return m
71 })
72
73
74 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
75 for name, fn := range in {
76 if !goodName(name) {
77 panic(fmt.Errorf("function name %q is not a valid identifier", name))
78 }
79 v := reflect.ValueOf(fn)
80 if v.Kind() != reflect.Func {
81 panic("value for " + name + " not a function")
82 }
83 if err := goodFunc(name, v.Type()); err != nil {
84 panic(err)
85 }
86 out[name] = v
87 }
88 }
89
90
91
92 func addFuncs(out, in FuncMap) {
93 for name, fn := range in {
94 out[name] = fn
95 }
96 }
97
98
99 func goodFunc(name string, typ reflect.Type) error {
100
101 switch numOut := typ.NumOut(); {
102 case numOut == 1:
103 return nil
104 case numOut == 2 && typ.Out(1) == errorType:
105 return nil
106 case numOut == 2:
107 return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
108 default:
109 return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
110 }
111 }
112
113
// goodName reports whether name may be used as a template function name.
// It must be non-empty, begin with a letter or underscore, and otherwise
// consist of letters, digits and underscores (Unicode letters and digits are
// accepted).
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' {
			continue
		}
		// The first rune must be a letter; later runes may also be digits.
		if !unicode.IsLetter(ch) && (pos == 0 || !unicode.IsDigit(ch)) {
			return false
		}
	}
	return true
}
129
130
131 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
132 if tmpl != nil && tmpl.common != nil {
133 tmpl.muFuncs.RLock()
134 defer tmpl.muFuncs.RUnlock()
135 if fn := tmpl.execFuncs[name]; fn.IsValid() {
136 return fn, false, true
137 }
138 }
139 if fn := builtinFuncs()[name]; fn.IsValid() {
140 return fn, true, true
141 }
142 return reflect.Value{}, false, false
143 }
144
145
146
147 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
148 if !value.IsValid() {
149 if !canBeNil(argType) {
150 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
151 }
152 value = reflect.Zero(argType)
153 }
154 if value.Type().AssignableTo(argType) {
155 return value, nil
156 }
157 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
158 value = value.Convert(argType)
159 return value, nil
160 }
161 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
162 }
163
// intLike reports whether k is a signed or unsigned integer kind
// (uintptr included).
func intLike(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
173
174
// indexArg converts index, which must hold a signed or unsigned integer, to
// an int in the closed range [0, cap]. The upper bound is inclusive because
// slice bounds may equal the length/capacity; callers needing an element
// index must additionally reject cap itself.
//
// It returns an error for a zero Value (nil), for non-integer kinds, and for
// values outside [0, cap].
func indexArg(index reflect.Value, cap int) (int, error) {
	switch index.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		x := index.Int()
		// Compare in int64: converting to int first could truncate a large
		// value into range where int is 32 bits.
		if x < 0 || x > int64(cap) {
			return 0, fmt.Errorf("index out of range: %d", x)
		}
		return int(x), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		x := index.Uint()
		// Compare in uint64 so values above MaxInt64 are neither wrapped to
		// negative numbers nor reported as such. Callers in this file pass a
		// length or capacity (never negative), but guard anyway.
		if cap < 0 || x > uint64(cap) {
			return 0, fmt.Errorf("index out of range: %d", x)
		}
		return int(x), nil
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
}
192
193
194
195
196
197
198 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
199 item = indirectInterface(item)
200 if !item.IsValid() {
201 return reflect.Value{}, fmt.Errorf("index of untyped nil")
202 }
203 for _, index := range indexes {
204 index = indirectInterface(index)
205 var isNil bool
206 if item, isNil = indirect(item); isNil {
207 return reflect.Value{}, fmt.Errorf("index of nil pointer")
208 }
209 switch item.Kind() {
210 case reflect.Array, reflect.Slice, reflect.String:
211 x, err := indexArg(index, item.Len())
212 if err != nil {
213 return reflect.Value{}, err
214 }
215 item = item.Index(x)
216 case reflect.Map:
217 index, err := prepareArg(index, item.Type().Key())
218 if err != nil {
219 return reflect.Value{}, err
220 }
221 if x := item.MapIndex(index); x.IsValid() {
222 item = x
223 } else {
224 item = reflect.Zero(item.Type().Elem())
225 }
226 case reflect.Invalid:
227
228 panic("unreachable")
229 default:
230 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
231 }
232 }
233 return item, nil
234 }
235
236
237
238
239
240
241
242 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
243 item = indirectInterface(item)
244 if !item.IsValid() {
245 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
246 }
247 var isNil bool
248 if item, isNil = indirect(item); isNil {
249 return reflect.Value{}, fmt.Errorf("slice of nil pointer")
250 }
251 if len(indexes) > 3 {
252 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
253 }
254 var cap int
255 switch item.Kind() {
256 case reflect.String:
257 if len(indexes) == 3 {
258 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
259 }
260 cap = item.Len()
261 case reflect.Array, reflect.Slice:
262 cap = item.Cap()
263 default:
264 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
265 }
266
267 idx := [3]int{0, item.Len()}
268 for i, index := range indexes {
269 x, err := indexArg(index, cap)
270 if err != nil {
271 return reflect.Value{}, err
272 }
273 idx[i] = x
274 }
275
276 if idx[0] > idx[1] {
277 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
278 }
279 if len(indexes) < 3 {
280 return item.Slice(idx[0], idx[1]), nil
281 }
282
283 if idx[1] > idx[2] {
284 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
285 }
286 return item.Slice3(idx[0], idx[1], idx[2]), nil
287 }
288
289
290
291
292 func length(item reflect.Value) (int, error) {
293 item, isNil := indirect(item)
294 if isNil {
295 return 0, fmt.Errorf("len of nil pointer")
296 }
297 switch item.Kind() {
298 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
299 return item.Len(), nil
300 }
301 return 0, fmt.Errorf("len of type %s", item.Type())
302 }
303
304
305
// emptyCall is the placeholder registered under the name "call" in builtins.
// Its body only panics, so it is never meant to run; the function named call
// in this file does the actual invocation (and takes the function's name for
// error messages). NOTE(review): presumably the executor routes "call" to
// call directly — confirm in exec.go.
func emptyCall(fn reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
309
310
311
312 func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
313 fn = indirectInterface(fn)
314 if !fn.IsValid() {
315 return reflect.Value{}, fmt.Errorf("call of nil")
316 }
317 typ := fn.Type()
318 if typ.Kind() != reflect.Func {
319 return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
320 }
321
322 if err := goodFunc(name, typ); err != nil {
323 return reflect.Value{}, err
324 }
325 numIn := typ.NumIn()
326 var dddType reflect.Type
327 if typ.IsVariadic() {
328 if len(args) < numIn-1 {
329 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
330 }
331 dddType = typ.In(numIn - 1).Elem()
332 } else {
333 if len(args) != numIn {
334 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
335 }
336 }
337 argv := make([]reflect.Value, len(args))
338 for i, arg := range args {
339 arg = indirectInterface(arg)
340
341 argType := dddType
342 if !typ.IsVariadic() || i < numIn-1 {
343 argType = typ.In(i)
344 }
345
346 var err error
347 if argv[i], err = prepareArg(arg, argType); err != nil {
348 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
349 }
350 }
351 return safeCall(fn, argv)
352 }
353
354
355
// safeCall calls fun with args and returns its first result. If fun has a
// second result that is non-nil, it is returned as err (asserted to error).
// A panic inside fun is recovered: an error panic value becomes err as-is,
// and any other value is formatted with %v. After a panic, val is the zero
// Value.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if asErr, isErr := r.(error); isErr {
			err = asErr
			return
		}
		err = fmt.Errorf("%v", r)
	}()
	results := fun.Call(args)
	val = results[0]
	if len(results) == 2 && !results[1].IsNil() {
		err = results[1].Interface().(error)
	}
	return val, err
}
372
373
374
375 func truth(arg reflect.Value) bool {
376 t, _ := isTrue(indirectInterface(arg))
377 return t
378 }
379
380
381
// and is registered as the "and" builtin, but its body only panics: it is
// never expected to run. NOTE(review): presumably the executor evaluates
// "and" itself (e.g. to short-circuit) and uses this entry only for the
// function table — confirm in exec.go.
func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
385
386
387
// or is registered as the "or" builtin, but its body only panics: it is
// never expected to run. NOTE(review): presumably the executor evaluates
// "or" itself (e.g. to short-circuit) and uses this entry only for the
// function table — confirm in exec.go.
func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
391
392
393 func not(arg reflect.Value) bool {
394 return !truth(arg)
395 }
396
397
398
399
400
// Sentinel errors shared by the comparison builtins.
var (
	// errNoComparison: eq was given nothing to compare its first argument to.
	errNoComparison = errors.New("missing argument for comparison")
	// errBadComparisonType: an operand's kind cannot be compared or ordered
	// (returned by basicKind and lt).
	errBadComparisonType = errors.New("invalid type for comparison")
)
405
// kind is the coarse classification of a reflect.Value used by the
// comparison functions (see basicKind). Within this file the constants are
// only compared with one another, never used for their numeric values.
type kind int

const (
	invalidKind kind = iota // not one of the basic kinds below
	boolKind
	complexKind
	intKind   // any signed integer kind
	floatKind // float32 or float64
	stringKind
	uintKind // any unsigned integer kind, including uintptr
)
417
418 func basicKind(v reflect.Value) (kind, error) {
419 switch v.Kind() {
420 case reflect.Bool:
421 return boolKind, nil
422 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
423 return intKind, nil
424 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
425 return uintKind, nil
426 case reflect.Float32, reflect.Float64:
427 return floatKind, nil
428 case reflect.Complex64, reflect.Complex128:
429 return complexKind, nil
430 case reflect.String:
431 return stringKind, nil
432 }
433 return invalidKind, errBadComparisonType
434 }
435
436
// isNil reports whether v is the zero reflect.Value, or a nil value of a kind
// that can be nil (chan, func, interface, map, pointer, slice). Values of any
// other kind are never nil.
func isNil(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Invalid:
		return true
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
		return v.IsNil()
	default:
		return false
	}
}
447
448
449
// canCompare reports whether eq may compare v1 and v2 with its generic path:
// their reflect kinds must match, unless either one is the zero Value
// (an untyped nil, for example), which may be compared with anything.
func canCompare(v1, v2 reflect.Value) bool {
	k1, k2 := v1.Kind(), v2.Kind()
	switch {
	case k1 == k2:
		return true
	case k1 == reflect.Invalid, k2 == reflect.Invalid:
		return true
	}
	return false
}
459
460
461 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
462 arg1 = indirectInterface(arg1)
463 if len(arg2) == 0 {
464 return false, errNoComparison
465 }
466 k1, _ := basicKind(arg1)
467 for _, arg := range arg2 {
468 arg = indirectInterface(arg)
469 k2, _ := basicKind(arg)
470 truth := false
471 if k1 != k2 {
472
473 switch {
474 case k1 == intKind && k2 == uintKind:
475 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
476 case k1 == uintKind && k2 == intKind:
477 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
478 default:
479 if arg1.IsValid() && arg.IsValid() {
480 return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg.Type())
481 }
482 }
483 } else {
484 switch k1 {
485 case boolKind:
486 truth = arg1.Bool() == arg.Bool()
487 case complexKind:
488 truth = arg1.Complex() == arg.Complex()
489 case floatKind:
490 truth = arg1.Float() == arg.Float()
491 case intKind:
492 truth = arg1.Int() == arg.Int()
493 case stringKind:
494 truth = arg1.String() == arg.String()
495 case uintKind:
496 truth = arg1.Uint() == arg.Uint()
497 default:
498 if !canCompare(arg1, arg) {
499 return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
500 }
501 if isNil(arg1) || isNil(arg) {
502 truth = isNil(arg) == isNil(arg1)
503 } else {
504 if !arg.Type().Comparable() {
505 return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
506 }
507 truth = arg1.Interface() == arg.Interface()
508 }
509 }
510 }
511 if truth {
512 return true, nil
513 }
514 }
515 return false, nil
516 }
517
518
519 func ne(arg1, arg2 reflect.Value) (bool, error) {
520
521 equal, err := eq(arg1, arg2)
522 return !equal, err
523 }
524
525
526 func lt(arg1, arg2 reflect.Value) (bool, error) {
527 arg1 = indirectInterface(arg1)
528 k1, err := basicKind(arg1)
529 if err != nil {
530 return false, err
531 }
532 arg2 = indirectInterface(arg2)
533 k2, err := basicKind(arg2)
534 if err != nil {
535 return false, err
536 }
537 truth := false
538 if k1 != k2 {
539
540 switch {
541 case k1 == intKind && k2 == uintKind:
542 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
543 case k1 == uintKind && k2 == intKind:
544 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
545 default:
546 return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg2.Type())
547 }
548 } else {
549 switch k1 {
550 case boolKind, complexKind:
551 return false, errBadComparisonType
552 case floatKind:
553 truth = arg1.Float() < arg2.Float()
554 case intKind:
555 truth = arg1.Int() < arg2.Int()
556 case stringKind:
557 truth = arg1.String() < arg2.String()
558 case uintKind:
559 truth = arg1.Uint() < arg2.Uint()
560 default:
561 panic("invalid kind")
562 }
563 }
564 return truth, nil
565 }
566
567
568 func le(arg1, arg2 reflect.Value) (bool, error) {
569
570 lessThan, err := lt(arg1, arg2)
571 if lessThan || err != nil {
572 return lessThan, err
573 }
574 return eq(arg1, arg2)
575 }
576
577
578 func gt(arg1, arg2 reflect.Value) (bool, error) {
579
580 lessOrEqual, err := le(arg1, arg2)
581 if err != nil {
582 return false, err
583 }
584 return !lessOrEqual, nil
585 }
586
587
588 func ge(arg1, arg2 reflect.Value) (bool, error) {
589
590 lessThan, err := lt(arg1, arg2)
591 if err != nil {
592 return false, err
593 }
594 return !lessThan, nil
595 }
596
597
598
599 var (
600 htmlQuot = []byte(""")
601 htmlApos = []byte("'")
602 htmlAmp = []byte("&")
603 htmlLt = []byte("<")
604 htmlGt = []byte(">")
605 htmlNull = []byte("\uFFFD")
606 )
607
608
609 func HTMLEscape(w io.Writer, b []byte) {
610 last := 0
611 for i, c := range b {
612 var html []byte
613 switch c {
614 case '\000':
615 html = htmlNull
616 case '"':
617 html = htmlQuot
618 case '\'':
619 html = htmlApos
620 case '&':
621 html = htmlAmp
622 case '<':
623 html = htmlLt
624 case '>':
625 html = htmlGt
626 default:
627 continue
628 }
629 w.Write(b[last:i])
630 w.Write(html)
631 last = i + 1
632 }
633 w.Write(b[last:])
634 }
635
636
637 func HTMLEscapeString(s string) string {
638
639 if !strings.ContainsAny(s, "'\"&<>\000") {
640 return s
641 }
642 var b strings.Builder
643 HTMLEscape(&b, []byte(s))
644 return b.String()
645 }
646
647
648
649 func HTMLEscaper(args ...any) string {
650 return HTMLEscapeString(evalArgs(args))
651 }
652
653
654
// Replacement text written by JSEscape for special ASCII bytes. Control
// bytes are written as jsLowUni followed by two digits from hex.
var (
	jsLowUni = []byte(`\u00`)
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\u003C`)
	jsGt        = []byte(`\u003E`)
	jsAmp       = []byte(`\u0026`)
	jsEq        = []byte(`\u003D`)
)

// JSEscape writes to w the escaped JavaScript equivalent of the plain text
// data b.
//
//   - Backslash and quotes are backslash-escaped.
//   - < > & = and ASCII control bytes become \u escapes.
//   - Printable non-ASCII runes are copied unchanged.
//   - Other non-ASCII runes become \u escapes; runes above U+FFFF are written
//     as a UTF-16 surrogate pair.
//
// Errors returned by w are ignored.
func JSEscape(w io.Writer, b []byte) {
	last := 0
	for i := 0; i < len(b); i++ {
		c := b[i]

		if !jsIsSpecial(rune(c)) {
			// Fast path: nothing to do.
			continue
		}
		w.Write(b[last:i])

		if c < utf8.RuneSelf {
			// Quotes, backslashes and HTML-significant characters get
			// quoted; control characters are written as \u00XX.
			switch c {
			case '\\':
				w.Write(jsBackslash)
			case '\'':
				w.Write(jsApos)
			case '"':
				w.Write(jsQuot)
			case '<':
				w.Write(jsLt)
			case '>':
				w.Write(jsGt)
			case '&':
				w.Write(jsAmp)
			case '=':
				w.Write(jsEq)
			default:
				w.Write(jsLowUni)
				hi, lo := c>>4, c&0x0f
				w.Write(hex[hi : hi+1])
				w.Write(hex[lo : lo+1])
			}
		} else {
			// Non-ASCII: decode the whole rune.
			r, size := utf8.DecodeRune(b[i:])
			switch {
			case unicode.IsPrint(r):
				w.Write(b[i : i+size])
			case r > 0xFFFF:
				// A JavaScript \u escape holds exactly four hex digits, so
				// "\u%04X" would emit five or six digits that JavaScript
				// reads as a BMP escape followed by literal digits. Encode
				// the rune as a UTF-16 surrogate pair instead.
				r -= 0x10000
				fmt.Fprintf(w, "\\u%04X\\u%04X", 0xD800+(r>>10), 0xDC00+(r&0x3FF))
			default:
				fmt.Fprintf(w, "\\u%04X", r)
			}
			i += size - 1
		}
		last = i + 1
	}
	w.Write(b[last:])
}

// JSEscapeString returns the escaped JavaScript equivalent of the plain text
// data s, using the same rules as JSEscape.
func JSEscapeString(s string) string {
	// Avoid allocation if we can.
	if strings.IndexFunc(s, jsIsSpecial) < 0 {
		return s
	}
	var b strings.Builder
	JSEscape(&b, []byte(s))
	return b.String()
}

// jsIsSpecial reports whether r needs attention from JSEscape: the quoting
// characters \ ' " < > & =, ASCII control characters, and any rune at or
// above utf8.RuneSelf.
func jsIsSpecial(r rune) bool {
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	}
	return r < ' ' || utf8.RuneSelf <= r
}
737
738
739
740 func JSEscaper(args ...any) string {
741 return JSEscapeString(evalArgs(args))
742 }
743
744
745
746 func URLQueryEscaper(args ...any) string {
747 return url.QueryEscape(evalArgs(args))
748 }
749
750
751
752
753
754
755
756
757 func evalArgs(args []any) string {
758 ok := false
759 var s string
760
761 if len(args) == 1 {
762 s, ok = args[0].(string)
763 }
764 if !ok {
765 for i, arg := range args {
766 a, ok := printableValue(reflect.ValueOf(arg))
767 if ok {
768 args[i] = a
769 }
770 }
771 s = fmt.Sprint(args...)
772 }
773 return s
774 }
775
View as plain text