1
2
3
4
5 package template
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net/url"
12 "reflect"
13 "strings"
14 "sync"
15 "unicode"
16 "unicode/utf8"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33 type FuncMap map[string]any
34
35
36
37
38
39 func builtins() FuncMap {
40 return FuncMap{
41 "and": and,
42 "call": emptyCall,
43 "html": HTMLEscaper,
44 "index": index,
45 "slice": slice,
46 "js": JSEscaper,
47 "len": length,
48 "not": not,
49 "or": or,
50 "print": fmt.Sprint,
51 "printf": fmt.Sprintf,
52 "println": fmt.Sprintln,
53 "urlquery": URLQueryEscaper,
54
55
56 "eq": eq,
57 "ge": ge,
58 "gt": gt,
59 "le": le,
60 "lt": lt,
61 "ne": ne,
62 }
63 }
64
65 var builtinFuncsOnce struct {
66 sync.Once
67 v map[string]reflect.Value
68 }
69
70
71
72 func builtinFuncs() map[string]reflect.Value {
73 builtinFuncsOnce.Do(func() {
74 builtinFuncsOnce.v = createValueFuncs(builtins())
75 })
76 return builtinFuncsOnce.v
77 }
78
79
80 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
81 m := make(map[string]reflect.Value)
82 addValueFuncs(m, funcMap)
83 return m
84 }
85
86
87 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
88 for name, fn := range in {
89 if !goodName(name) {
90 panic(fmt.Errorf("function name %q is not a valid identifier", name))
91 }
92 v := reflect.ValueOf(fn)
93 if v.Kind() != reflect.Func {
94 panic("value for " + name + " not a function")
95 }
96 if err := goodFunc(name, v.Type()); err != nil {
97 panic(err)
98 }
99 out[name] = v
100 }
101 }
102
103
104
105 func addFuncs(out, in FuncMap) {
106 for name, fn := range in {
107 out[name] = fn
108 }
109 }
110
111
112 func goodFunc(name string, typ reflect.Type) error {
113
114 switch numOut := typ.NumOut(); {
115 case numOut == 1:
116 return nil
117 case numOut == 2 && typ.Out(1) == errorType:
118 return nil
119 case numOut == 2:
120 return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
121 default:
122 return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
123 }
124 }
125
126
// goodName reports whether name is a valid identifier: non-empty, starting
// with a letter or underscore, and continuing with letters, digits, or
// underscores (Unicode letters and digits included).
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, r := range name {
		if r == '_' || unicode.IsLetter(r) {
			continue
		}
		// Digits are allowed anywhere except the first position.
		if pos > 0 && unicode.IsDigit(r) {
			continue
		}
		return false
	}
	return true
}
142
143
144 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
145 if tmpl != nil && tmpl.common != nil {
146 tmpl.muFuncs.RLock()
147 defer tmpl.muFuncs.RUnlock()
148 if fn := tmpl.execFuncs[name]; fn.IsValid() {
149 return fn, false, true
150 }
151 }
152 if fn := builtinFuncs()[name]; fn.IsValid() {
153 return fn, true, true
154 }
155 return reflect.Value{}, false, false
156 }
157
158
159
160 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
161 if !value.IsValid() {
162 if !canBeNil(argType) {
163 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
164 }
165 value = reflect.Zero(argType)
166 }
167 if value.Type().AssignableTo(argType) {
168 return value, nil
169 }
170 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
171 value = value.Convert(argType)
172 return value, nil
173 }
174 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
175 }
176
// intLike reports whether k is one of the signed or unsigned integer kinds,
// including Uintptr.
func intLike(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
186
187
// indexArg converts an integer-kinded index into an int in the inclusive
// range [0, cap]. Note that x == cap is accepted (as slice bounds require).
// A nil (invalid) index, a non-integer index, or an out-of-range value
// (including unsigned values that overflow int64/int) is an error.
func indexArg(index reflect.Value, cap int) (int, error) {
	var raw int64
	switch index.Kind() {
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		raw = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		raw = int64(index.Uint())
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	// int(raw) < 0 catches truncation on platforms where int is 32 bits.
	n := int(raw)
	if raw < 0 || n < 0 || n > cap {
		return 0, fmt.Errorf("index out of range: %d", raw)
	}
	return n, nil
}
205
206
207
208
209
210
211 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
212 item = indirectInterface(item)
213 if !item.IsValid() {
214 return reflect.Value{}, fmt.Errorf("index of untyped nil")
215 }
216 for _, index := range indexes {
217 index = indirectInterface(index)
218 var isNil bool
219 if item, isNil = indirect(item); isNil {
220 return reflect.Value{}, fmt.Errorf("index of nil pointer")
221 }
222 switch item.Kind() {
223 case reflect.Array, reflect.Slice, reflect.String:
224 x, err := indexArg(index, item.Len())
225 if err != nil {
226 return reflect.Value{}, err
227 }
228 item = item.Index(x)
229 case reflect.Map:
230 index, err := prepareArg(index, item.Type().Key())
231 if err != nil {
232 return reflect.Value{}, err
233 }
234 if x := item.MapIndex(index); x.IsValid() {
235 item = x
236 } else {
237 item = reflect.Zero(item.Type().Elem())
238 }
239 case reflect.Invalid:
240
241 panic("unreachable")
242 default:
243 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
244 }
245 }
246 return item, nil
247 }
248
249
250
251
252
253
254
255 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
256 item = indirectInterface(item)
257 if !item.IsValid() {
258 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
259 }
260 if len(indexes) > 3 {
261 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
262 }
263 var cap int
264 switch item.Kind() {
265 case reflect.String:
266 if len(indexes) == 3 {
267 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
268 }
269 cap = item.Len()
270 case reflect.Array, reflect.Slice:
271 cap = item.Cap()
272 default:
273 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
274 }
275
276 idx := [3]int{0, item.Len()}
277 for i, index := range indexes {
278 x, err := indexArg(index, cap)
279 if err != nil {
280 return reflect.Value{}, err
281 }
282 idx[i] = x
283 }
284
285 if idx[0] > idx[1] {
286 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
287 }
288 if len(indexes) < 3 {
289 return item.Slice(idx[0], idx[1]), nil
290 }
291
292 if idx[1] > idx[2] {
293 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
294 }
295 return item.Slice3(idx[0], idx[1], idx[2]), nil
296 }
297
298
299
300
301 func length(item reflect.Value) (int, error) {
302 item, isNil := indirect(item)
303 if isNil {
304 return 0, fmt.Errorf("len of nil pointer")
305 }
306 switch item.Kind() {
307 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
308 return item.Len(), nil
309 }
310 return 0, fmt.Errorf("len of type %s", item.Type())
311 }
312
313
314
// emptyCall is the placeholder registered under "call" in builtins. Its
// arguments are ignored and it always panics with "unreachable"; presumably
// the executor dispatches "call" itself (NOTE(review): confirm in exec.go).
func emptyCall(_ reflect.Value, _ ...reflect.Value) reflect.Value {
	panic("unreachable")
}
318
319
320
321 func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
322 fn = indirectInterface(fn)
323 if !fn.IsValid() {
324 return reflect.Value{}, fmt.Errorf("call of nil")
325 }
326 typ := fn.Type()
327 if typ.Kind() != reflect.Func {
328 return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
329 }
330
331 if err := goodFunc(name, typ); err != nil {
332 return reflect.Value{}, err
333 }
334 numIn := typ.NumIn()
335 var dddType reflect.Type
336 if typ.IsVariadic() {
337 if len(args) < numIn-1 {
338 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
339 }
340 dddType = typ.In(numIn - 1).Elem()
341 } else {
342 if len(args) != numIn {
343 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
344 }
345 }
346 argv := make([]reflect.Value, len(args))
347 for i, arg := range args {
348 arg = indirectInterface(arg)
349
350 argType := dddType
351 if !typ.IsVariadic() || i < numIn-1 {
352 argType = typ.In(i)
353 }
354
355 var err error
356 if argv[i], err = prepareArg(arg, argType); err != nil {
357 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
358 }
359 }
360 return safeCall(fn, argv)
361 }
362
363
364
// safeCall runs fun.Call(args) and returns its first result. If fun returns
// a second, non-nil result, that result is returned as the error. A panic
// during the call is recovered and turned into err: an error panic value is
// returned as is, and any other value is formatted with %v. When a panic
// occurs, val is the zero Value.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		asErr, isErr := p.(error)
		if !isErr {
			asErr = fmt.Errorf("%v", p)
		}
		err = asErr
	}()
	out := fun.Call(args)
	if len(out) == 2 {
		if second := out[1]; !second.IsNil() {
			return out[0], second.Interface().(error)
		}
	}
	return out[0], nil
}
381
382
383
384 func truth(arg reflect.Value) bool {
385 t, _ := isTrue(indirectInterface(arg))
386 return t
387 }
388
389
390
// and is the placeholder registered under "and" in builtins. Its arguments
// are ignored and it always panics with "unreachable"; presumably the
// executor evaluates "and" itself so it can short-circuit
// (NOTE(review): confirm in exec.go).
func and(_ reflect.Value, _ ...reflect.Value) reflect.Value {
	panic("unreachable")
}

// or is the placeholder registered under "or" in builtins. Like and, it
// ignores its arguments and always panics with "unreachable".
func or(_ reflect.Value, _ ...reflect.Value) reflect.Value {
	panic("unreachable")
}
400
401
402 func not(arg reflect.Value) bool {
403 return !truth(arg)
404 }
405
406
407
408
409
// Errors reported by the comparison functions.
var (
	// errBadComparisonType: an operand's kind cannot be compared this way.
	errBadComparisonType = errors.New("invalid type for comparison")
	// errNoComparison: eq was given nothing to compare against.
	errNoComparison = errors.New("missing argument for comparison")
)

// kind is the coarse classification of reflect kinds used by the
// comparison functions (see basicKind).
type kind int

// Values of kind. They are spelled out explicitly; invalidKind is the
// zero value.
const (
	invalidKind kind = 0
	boolKind    kind = 1
	complexKind kind = 2
	intKind     kind = 3
	floatKind   kind = 4
	stringKind  kind = 5
	uintKind    kind = 6
)
426
427 func basicKind(v reflect.Value) (kind, error) {
428 switch v.Kind() {
429 case reflect.Bool:
430 return boolKind, nil
431 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
432 return intKind, nil
433 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
434 return uintKind, nil
435 case reflect.Float32, reflect.Float64:
436 return floatKind, nil
437 case reflect.Complex64, reflect.Complex128:
438 return complexKind, nil
439 case reflect.String:
440 return stringKind, nil
441 }
442 return invalidKind, errBadComparisonType
443 }
444
445
// isNil reports whether v is an invalid Value or a nil channel, function,
// interface, map, pointer, or slice. Values of any other kind are never nil.
func isNil(v reflect.Value) bool {
	if !v.IsValid() {
		return true
	}
	k := v.Kind()
	nillable := k == reflect.Chan || k == reflect.Func || k == reflect.Interface ||
		k == reflect.Map || k == reflect.Pointer || k == reflect.Slice
	// IsNil panics for other kinds, so it is only consulted when nillable.
	return nillable && v.IsNil()
}
456
457
458
// canCompare reports whether v1 and v2 may be compared by eq: they must
// have the same reflect kind, unless either is an invalid (untyped nil)
// Value, which can be compared against anything.
func canCompare(v1, v2 reflect.Value) bool {
	return v1.Kind() == v2.Kind() || !v1.IsValid() || !v2.IsValid()
}
468
469
470 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
471 arg1 = indirectInterface(arg1)
472 if len(arg2) == 0 {
473 return false, errNoComparison
474 }
475 k1, _ := basicKind(arg1)
476 for _, arg := range arg2 {
477 arg = indirectInterface(arg)
478 k2, _ := basicKind(arg)
479 truth := false
480 if k1 != k2 {
481
482 switch {
483 case k1 == intKind && k2 == uintKind:
484 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
485 case k1 == uintKind && k2 == intKind:
486 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
487 default:
488 if arg1.IsValid() && arg.IsValid() {
489 return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg.Type())
490 }
491 }
492 } else {
493 switch k1 {
494 case boolKind:
495 truth = arg1.Bool() == arg.Bool()
496 case complexKind:
497 truth = arg1.Complex() == arg.Complex()
498 case floatKind:
499 truth = arg1.Float() == arg.Float()
500 case intKind:
501 truth = arg1.Int() == arg.Int()
502 case stringKind:
503 truth = arg1.String() == arg.String()
504 case uintKind:
505 truth = arg1.Uint() == arg.Uint()
506 default:
507 if !canCompare(arg1, arg) {
508 return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
509 }
510 if isNil(arg1) || isNil(arg) {
511 truth = isNil(arg) == isNil(arg1)
512 } else {
513 if !arg.Type().Comparable() {
514 return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
515 }
516 truth = arg1.Interface() == arg.Interface()
517 }
518 }
519 }
520 if truth {
521 return true, nil
522 }
523 }
524 return false, nil
525 }
526
527
528 func ne(arg1, arg2 reflect.Value) (bool, error) {
529
530 equal, err := eq(arg1, arg2)
531 return !equal, err
532 }
533
534
535 func lt(arg1, arg2 reflect.Value) (bool, error) {
536 arg1 = indirectInterface(arg1)
537 k1, err := basicKind(arg1)
538 if err != nil {
539 return false, err
540 }
541 arg2 = indirectInterface(arg2)
542 k2, err := basicKind(arg2)
543 if err != nil {
544 return false, err
545 }
546 truth := false
547 if k1 != k2 {
548
549 switch {
550 case k1 == intKind && k2 == uintKind:
551 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
552 case k1 == uintKind && k2 == intKind:
553 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
554 default:
555 return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg2.Type())
556 }
557 } else {
558 switch k1 {
559 case boolKind, complexKind:
560 return false, errBadComparisonType
561 case floatKind:
562 truth = arg1.Float() < arg2.Float()
563 case intKind:
564 truth = arg1.Int() < arg2.Int()
565 case stringKind:
566 truth = arg1.String() < arg2.String()
567 case uintKind:
568 truth = arg1.Uint() < arg2.Uint()
569 default:
570 panic("invalid kind")
571 }
572 }
573 return truth, nil
574 }
575
576
577 func le(arg1, arg2 reflect.Value) (bool, error) {
578
579 lessThan, err := lt(arg1, arg2)
580 if lessThan || err != nil {
581 return lessThan, err
582 }
583 return eq(arg1, arg2)
584 }
585
586
587 func gt(arg1, arg2 reflect.Value) (bool, error) {
588
589 lessOrEqual, err := le(arg1, arg2)
590 if err != nil {
591 return false, err
592 }
593 return !lessOrEqual, nil
594 }
595
596
597 func ge(arg1, arg2 reflect.Value) (bool, error) {
598
599 lessThan, err := lt(arg1, arg2)
600 if err != nil {
601 return false, err
602 }
603 return !lessThan, nil
604 }
605
606
607
// Replacement text written by HTMLEscape for each special byte. Each
// special character must map to an HTML character reference; mapping a
// character to itself would disable escaping.
var (
	htmlQuot = []byte("&#34;") // shorter than "&quot;"
	htmlApos = []byte("&#39;") // shorter than "&apos;", and works in HTML 4
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD") // NUL becomes the Unicode replacement character
)
616
617
618 func HTMLEscape(w io.Writer, b []byte) {
619 last := 0
620 for i, c := range b {
621 var html []byte
622 switch c {
623 case '\000':
624 html = htmlNull
625 case '"':
626 html = htmlQuot
627 case '\'':
628 html = htmlApos
629 case '&':
630 html = htmlAmp
631 case '<':
632 html = htmlLt
633 case '>':
634 html = htmlGt
635 default:
636 continue
637 }
638 w.Write(b[last:i])
639 w.Write(html)
640 last = i + 1
641 }
642 w.Write(b[last:])
643 }
644
645
646 func HTMLEscapeString(s string) string {
647
648 if !strings.ContainsAny(s, "'\"&<>\000") {
649 return s
650 }
651 var b strings.Builder
652 HTMLEscape(&b, []byte(s))
653 return b.String()
654 }
655
656
657
658 func HTMLEscaper(args ...any) string {
659 return HTMLEscapeString(evalArgs(args))
660 }
661
662
663
// Escape sequences written by JSEscape.
var (
	// Bytes below U+0020 are written as \u00 followed by two digits from hex.
	jsLowUni = []byte("\\u00")
	hex      = []byte("0123456789ABCDEF")

	// Replacements for ASCII characters that are special inside JavaScript
	// string literals or in HTML.
	jsBackslash = []byte("\\\\")
	jsApos      = []byte("\\'")
	jsQuot      = []byte("\\\"")
	jsLt        = []byte("\\u003C")
	jsGt        = []byte("\\u003E")
	jsAmp       = []byte("\\u0026")
	jsEq        = []byte("\\u003D")
)
676
677
678 func JSEscape(w io.Writer, b []byte) {
679 last := 0
680 for i := 0; i < len(b); i++ {
681 c := b[i]
682
683 if !jsIsSpecial(rune(c)) {
684
685 continue
686 }
687 w.Write(b[last:i])
688
689 if c < utf8.RuneSelf {
690
691
692 switch c {
693 case '\\':
694 w.Write(jsBackslash)
695 case '\'':
696 w.Write(jsApos)
697 case '"':
698 w.Write(jsQuot)
699 case '<':
700 w.Write(jsLt)
701 case '>':
702 w.Write(jsGt)
703 case '&':
704 w.Write(jsAmp)
705 case '=':
706 w.Write(jsEq)
707 default:
708 w.Write(jsLowUni)
709 t, b := c>>4, c&0x0f
710 w.Write(hex[t : t+1])
711 w.Write(hex[b : b+1])
712 }
713 } else {
714
715 r, size := utf8.DecodeRune(b[i:])
716 if unicode.IsPrint(r) {
717 w.Write(b[i : i+size])
718 } else {
719 fmt.Fprintf(w, "\\u%04X", r)
720 }
721 i += size - 1
722 }
723 last = i + 1
724 }
725 w.Write(b[last:])
726 }
727
728
729 func JSEscapeString(s string) string {
730
731 if strings.IndexFunc(s, jsIsSpecial) < 0 {
732 return s
733 }
734 var b strings.Builder
735 JSEscape(&b, []byte(s))
736 return b.String()
737 }
738
// jsIsSpecial reports whether r needs handling by JSEscape: a rune below
// space, any non-ASCII rune, or one of \ ' " < > & =.
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	return strings.ContainsRune(`\'"<>&=`, r)
}
746
747
748
749 func JSEscaper(args ...any) string {
750 return JSEscapeString(evalArgs(args))
751 }
752
753
754
755 func URLQueryEscaper(args ...any) string {
756 return url.QueryEscape(evalArgs(args))
757 }
758
759
760
761
762
763
764
765
766 func evalArgs(args []any) string {
767 ok := false
768 var s string
769
770 if len(args) == 1 {
771 s, ok = args[0].(string)
772 }
773 if !ok {
774 for i, arg := range args {
775 a, ok := printableValue(reflect.ValueOf(arg))
776 if ok {
777 args[i] = a
778 }
779 }
780 s = fmt.Sprint(args...)
781 }
782 return s
783 }
784
View as plain text