Source file src/text/template/funcs.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package template
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/url"
    12  	"reflect"
    13  	"strings"
    14  	"sync"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values of
// which the second has type error. In that case, if the second (error)
// return value evaluates to non-nil during execution, execution terminates and
// Execute returns that error.
//
// Entries are validated when the map is converted to reflected values. Any of
// the following causes a panic:
//   - a key that is not a valid identifier (a letter or underscore, followed
//     by letters, digits, or underscores);
//   - a value that is not a function;
//   - a function whose results do not match the rule above.
//
// Errors returned by Execute wrap the underlying error; call [errors.As] to
// unwrap them.
//
// When template execution invokes a function with an argument list, that list
// must be assignable to the function's parameter types. Functions meant to
// apply to arguments of arbitrary type can use parameters of type interface{} or
// of type [reflect.Value]. Similarly, functions meant to return a result of arbitrary
// type can return interface{} or [reflect.Value].
type FuncMap map[string]any
    34  
    35  // builtins returns the FuncMap.
    36  // It is not a global variable so the linker can dead code eliminate
    37  // more when this isn't called. See golang.org/issue/36021.
    38  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    39  func builtins() FuncMap {
    40  	return FuncMap{
    41  		"and":      and,
    42  		"call":     emptyCall,
    43  		"html":     HTMLEscaper,
    44  		"index":    index,
    45  		"slice":    slice,
    46  		"js":       JSEscaper,
    47  		"len":      length,
    48  		"not":      not,
    49  		"or":       or,
    50  		"print":    fmt.Sprint,
    51  		"printf":   fmt.Sprintf,
    52  		"println":  fmt.Sprintln,
    53  		"urlquery": URLQueryEscaper,
    54  
    55  		// Comparisons
    56  		"eq": eq, // ==
    57  		"ge": ge, // >=
    58  		"gt": gt, // >
    59  		"le": le, // <=
    60  		"lt": lt, // <
    61  		"ne": ne, // !=
    62  	}
    63  }
    64  
    65  var builtinFuncsOnce struct {
    66  	sync.Once
    67  	v map[string]reflect.Value
    68  }
    69  
    70  // builtinFuncsOnce lazily computes & caches the builtinFuncs map.
    71  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    72  func builtinFuncs() map[string]reflect.Value {
    73  	builtinFuncsOnce.Do(func() {
    74  		builtinFuncsOnce.v = createValueFuncs(builtins())
    75  	})
    76  	return builtinFuncsOnce.v
    77  }
    78  
    79  // createValueFuncs turns a FuncMap into a map[string]reflect.Value
    80  func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
    81  	m := make(map[string]reflect.Value)
    82  	addValueFuncs(m, funcMap)
    83  	return m
    84  }
    85  
    86  // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
    87  func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
    88  	for name, fn := range in {
    89  		if !goodName(name) {
    90  			panic(fmt.Errorf("function name %q is not a valid identifier", name))
    91  		}
    92  		v := reflect.ValueOf(fn)
    93  		if v.Kind() != reflect.Func {
    94  			panic("value for " + name + " not a function")
    95  		}
    96  		if err := goodFunc(name, v.Type()); err != nil {
    97  			panic(err)
    98  		}
    99  		out[name] = v
   100  	}
   101  }
   102  
   103  // addFuncs adds to values the functions in funcs. It does no checking of the input -
   104  // call addValueFuncs first.
   105  func addFuncs(out, in FuncMap) {
   106  	for name, fn := range in {
   107  		out[name] = fn
   108  	}
   109  }
   110  
   111  // goodFunc reports whether the function or method has the right result signature.
   112  func goodFunc(name string, typ reflect.Type) error {
   113  	// We allow functions with 1 result or 2 results where the second is an error.
   114  	switch numOut := typ.NumOut(); {
   115  	case numOut == 1:
   116  		return nil
   117  	case numOut == 2 && typ.Out(1) == errorType:
   118  		return nil
   119  	case numOut == 2:
   120  		return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
   121  	default:
   122  		return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
   123  	}
   124  }
   125  
   126  // goodName reports whether the function name is a valid identifier.
   127  func goodName(name string) bool {
   128  	if name == "" {
   129  		return false
   130  	}
   131  	for i, r := range name {
   132  		switch {
   133  		case r == '_':
   134  		case i == 0 && !unicode.IsLetter(r):
   135  			return false
   136  		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
   137  			return false
   138  		}
   139  	}
   140  	return true
   141  }
   142  
   143  // findFunction looks for a function in the template, and global map.
   144  func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
   145  	if tmpl != nil && tmpl.common != nil {
   146  		tmpl.muFuncs.RLock()
   147  		defer tmpl.muFuncs.RUnlock()
   148  		if fn := tmpl.execFuncs[name]; fn.IsValid() {
   149  			return fn, false, true
   150  		}
   151  	}
   152  	if fn := builtinFuncs()[name]; fn.IsValid() {
   153  		return fn, true, true
   154  	}
   155  	return reflect.Value{}, false, false
   156  }
   157  
   158  // prepareArg checks if value can be used as an argument of type argType, and
   159  // converts an invalid value to appropriate zero if possible.
   160  func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
   161  	if !value.IsValid() {
   162  		if !canBeNil(argType) {
   163  			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
   164  		}
   165  		value = reflect.Zero(argType)
   166  	}
   167  	if value.Type().AssignableTo(argType) {
   168  		return value, nil
   169  	}
   170  	if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
   171  		value = value.Convert(argType)
   172  		return value, nil
   173  	}
   174  	return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
   175  }
   176  
   177  func intLike(typ reflect.Kind) bool {
   178  	switch typ {
   179  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   180  		return true
   181  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   182  		return true
   183  	}
   184  	return false
   185  }
   186  
   187  // indexArg checks if a reflect.Value can be used as an index, and converts it to int if possible.
   188  func indexArg(index reflect.Value, cap int) (int, error) {
   189  	var x int64
   190  	switch index.Kind() {
   191  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   192  		x = index.Int()
   193  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   194  		x = int64(index.Uint())
   195  	case reflect.Invalid:
   196  		return 0, fmt.Errorf("cannot index slice/array with nil")
   197  	default:
   198  		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
   199  	}
   200  	if x < 0 || int(x) < 0 || int(x) > cap {
   201  		return 0, fmt.Errorf("index out of range: %d", x)
   202  	}
   203  	return int(x), nil
   204  }
   205  
   206  // Indexing.
   207  
   208  // index returns the result of indexing its first argument by the following
   209  // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
   210  // indexed item must be a map, slice, or array.
   211  func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   212  	item = indirectInterface(item)
   213  	if !item.IsValid() {
   214  		return reflect.Value{}, fmt.Errorf("index of untyped nil")
   215  	}
   216  	for _, index := range indexes {
   217  		index = indirectInterface(index)
   218  		var isNil bool
   219  		if item, isNil = indirect(item); isNil {
   220  			return reflect.Value{}, fmt.Errorf("index of nil pointer")
   221  		}
   222  		switch item.Kind() {
   223  		case reflect.Array, reflect.Slice, reflect.String:
   224  			x, err := indexArg(index, item.Len())
   225  			if err != nil {
   226  				return reflect.Value{}, err
   227  			}
   228  			item = item.Index(x)
   229  		case reflect.Map:
   230  			index, err := prepareArg(index, item.Type().Key())
   231  			if err != nil {
   232  				return reflect.Value{}, err
   233  			}
   234  			if x := item.MapIndex(index); x.IsValid() {
   235  				item = x
   236  			} else {
   237  				item = reflect.Zero(item.Type().Elem())
   238  			}
   239  		case reflect.Invalid:
   240  			// the loop holds invariant: item.IsValid()
   241  			panic("unreachable")
   242  		default:
   243  			return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
   244  		}
   245  	}
   246  	return item, nil
   247  }
   248  
   249  // Slicing.
   250  
   251  // slice returns the result of slicing its first argument by the remaining
   252  // arguments. Thus "slice x 1 2" is, in Go syntax, x[1:2], while "slice x"
   253  // is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
   254  // argument must be a string, slice, or array.
   255  func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   256  	item = indirectInterface(item)
   257  	if !item.IsValid() {
   258  		return reflect.Value{}, fmt.Errorf("slice of untyped nil")
   259  	}
   260  	if len(indexes) > 3 {
   261  		return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
   262  	}
   263  	var cap int
   264  	switch item.Kind() {
   265  	case reflect.String:
   266  		if len(indexes) == 3 {
   267  			return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
   268  		}
   269  		cap = item.Len()
   270  	case reflect.Array, reflect.Slice:
   271  		cap = item.Cap()
   272  	default:
   273  		return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
   274  	}
   275  	// set default values for cases item[:], item[i:].
   276  	idx := [3]int{0, item.Len()}
   277  	for i, index := range indexes {
   278  		x, err := indexArg(index, cap)
   279  		if err != nil {
   280  			return reflect.Value{}, err
   281  		}
   282  		idx[i] = x
   283  	}
   284  	// given item[i:j], make sure i <= j.
   285  	if idx[0] > idx[1] {
   286  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
   287  	}
   288  	if len(indexes) < 3 {
   289  		return item.Slice(idx[0], idx[1]), nil
   290  	}
   291  	// given item[i:j:k], make sure i <= j <= k.
   292  	if idx[1] > idx[2] {
   293  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
   294  	}
   295  	return item.Slice3(idx[0], idx[1], idx[2]), nil
   296  }
   297  
   298  // Length
   299  
   300  // length returns the length of the item, with an error if it has no defined length.
   301  func length(item reflect.Value) (int, error) {
   302  	item, isNil := indirect(item)
   303  	if isNil {
   304  		return 0, fmt.Errorf("len of nil pointer")
   305  	}
   306  	switch item.Kind() {
   307  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
   308  		return item.Len(), nil
   309  	}
   310  	return 0, fmt.Errorf("len of type %s", item.Type())
   311  }
   312  
   313  // Function invocation
   314  
// emptyCall is the placeholder registered under the name "call" in builtins.
// Registering it lets findFunction resolve "call" from the builtin table.
// As the panic comment states, evalCall handles "call" itself, so this body
// is not expected to run.
func emptyCall(fn reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable") // implemented as a special case in evalCall
}
   318  
   319  // call returns the result of evaluating the first argument as a function.
   320  // The function must return 1 result, or 2 results, the second of which is an error.
   321  func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
   322  	fn = indirectInterface(fn)
   323  	if !fn.IsValid() {
   324  		return reflect.Value{}, fmt.Errorf("call of nil")
   325  	}
   326  	typ := fn.Type()
   327  	if typ.Kind() != reflect.Func {
   328  		return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
   329  	}
   330  
   331  	if err := goodFunc(name, typ); err != nil {
   332  		return reflect.Value{}, err
   333  	}
   334  	numIn := typ.NumIn()
   335  	var dddType reflect.Type
   336  	if typ.IsVariadic() {
   337  		if len(args) < numIn-1 {
   338  			return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
   339  		}
   340  		dddType = typ.In(numIn - 1).Elem()
   341  	} else {
   342  		if len(args) != numIn {
   343  			return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
   344  		}
   345  	}
   346  	argv := make([]reflect.Value, len(args))
   347  	for i, arg := range args {
   348  		arg = indirectInterface(arg)
   349  		// Compute the expected type. Clumsy because of variadics.
   350  		argType := dddType
   351  		if !typ.IsVariadic() || i < numIn-1 {
   352  			argType = typ.In(i)
   353  		}
   354  
   355  		var err error
   356  		if argv[i], err = prepareArg(arg, argType); err != nil {
   357  			return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
   358  		}
   359  	}
   360  	return safeCall(fn, argv)
   361  }
   362  
   363  // safeCall runs fun.Call(args), and returns the resulting value and error, if
   364  // any. If the call panics, the panic value is returned as an error.
   365  func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
   366  	defer func() {
   367  		if r := recover(); r != nil {
   368  			if e, ok := r.(error); ok {
   369  				err = e
   370  			} else {
   371  				err = fmt.Errorf("%v", r)
   372  			}
   373  		}
   374  	}()
   375  	ret := fun.Call(args)
   376  	if len(ret) == 2 && !ret[1].IsNil() {
   377  		return ret[0], ret[1].Interface().(error)
   378  	}
   379  	return ret[0], nil
   380  }
   381  
   382  // Boolean logic.
   383  
   384  func truth(arg reflect.Value) bool {
   385  	t, _ := isTrue(indirectInterface(arg))
   386  	return t
   387  }
   388  
   389  // and computes the Boolean AND of its arguments, returning
   390  // the first false argument it encounters, or the last argument.
   391  func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   392  	panic("unreachable") // implemented as a special case in evalCall
   393  }
   394  
   395  // or computes the Boolean OR of its arguments, returning
   396  // the first true argument it encounters, or the last argument.
   397  func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   398  	panic("unreachable") // implemented as a special case in evalCall
   399  }
   400  
   401  // not returns the Boolean negation of its argument.
   402  func not(arg reflect.Value) bool {
   403  	return !truth(arg)
   404  }
   405  
   406  // Comparison.
   407  
   408  // TODO: Perhaps allow comparison between signed and unsigned integers.
   409  
   410  var (
   411  	errBadComparisonType = errors.New("invalid type for comparison")
   412  	errNoComparison      = errors.New("missing argument for comparison")
   413  )
   414  
   415  type kind int
   416  
   417  const (
   418  	invalidKind kind = iota
   419  	boolKind
   420  	complexKind
   421  	intKind
   422  	floatKind
   423  	stringKind
   424  	uintKind
   425  )
   426  
   427  func basicKind(v reflect.Value) (kind, error) {
   428  	switch v.Kind() {
   429  	case reflect.Bool:
   430  		return boolKind, nil
   431  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   432  		return intKind, nil
   433  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   434  		return uintKind, nil
   435  	case reflect.Float32, reflect.Float64:
   436  		return floatKind, nil
   437  	case reflect.Complex64, reflect.Complex128:
   438  		return complexKind, nil
   439  	case reflect.String:
   440  		return stringKind, nil
   441  	}
   442  	return invalidKind, errBadComparisonType
   443  }
   444  
   445  // isNil returns true if v is the zero reflect.Value, or nil of its type.
   446  func isNil(v reflect.Value) bool {
   447  	if !v.IsValid() {
   448  		return true
   449  	}
   450  	switch v.Kind() {
   451  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
   452  		return v.IsNil()
   453  	}
   454  	return false
   455  }
   456  
   457  // canCompare reports whether v1 and v2 are both the same kind, or one is nil.
   458  // Called only when dealing with nillable types, or there's about to be an error.
   459  func canCompare(v1, v2 reflect.Value) bool {
   460  	k1 := v1.Kind()
   461  	k2 := v2.Kind()
   462  	if k1 == k2 {
   463  		return true
   464  	}
   465  	// We know the type can be compared to nil.
   466  	return k1 == reflect.Invalid || k2 == reflect.Invalid
   467  }
   468  
   469  // eq evaluates the comparison a == b || a == c || ...
   470  func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
   471  	arg1 = indirectInterface(arg1)
   472  	if len(arg2) == 0 {
   473  		return false, errNoComparison
   474  	}
   475  	k1, _ := basicKind(arg1)
   476  	for _, arg := range arg2 {
   477  		arg = indirectInterface(arg)
   478  		k2, _ := basicKind(arg)
   479  		truth := false
   480  		if k1 != k2 {
   481  			// Special case: Can compare integer values regardless of type's sign.
   482  			switch {
   483  			case k1 == intKind && k2 == uintKind:
   484  				truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
   485  			case k1 == uintKind && k2 == intKind:
   486  				truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
   487  			default:
   488  				if arg1.IsValid() && arg.IsValid() {
   489  					return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg.Type())
   490  				}
   491  			}
   492  		} else {
   493  			switch k1 {
   494  			case boolKind:
   495  				truth = arg1.Bool() == arg.Bool()
   496  			case complexKind:
   497  				truth = arg1.Complex() == arg.Complex()
   498  			case floatKind:
   499  				truth = arg1.Float() == arg.Float()
   500  			case intKind:
   501  				truth = arg1.Int() == arg.Int()
   502  			case stringKind:
   503  				truth = arg1.String() == arg.String()
   504  			case uintKind:
   505  				truth = arg1.Uint() == arg.Uint()
   506  			default:
   507  				if !canCompare(arg1, arg) {
   508  					return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
   509  				}
   510  				if isNil(arg1) || isNil(arg) {
   511  					truth = isNil(arg) == isNil(arg1)
   512  				} else {
   513  					if !arg.Type().Comparable() {
   514  						return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
   515  					}
   516  					truth = arg1.Interface() == arg.Interface()
   517  				}
   518  			}
   519  		}
   520  		if truth {
   521  			return true, nil
   522  		}
   523  	}
   524  	return false, nil
   525  }
   526  
   527  // ne evaluates the comparison a != b.
   528  func ne(arg1, arg2 reflect.Value) (bool, error) {
   529  	// != is the inverse of ==.
   530  	equal, err := eq(arg1, arg2)
   531  	return !equal, err
   532  }
   533  
   534  // lt evaluates the comparison a < b.
   535  func lt(arg1, arg2 reflect.Value) (bool, error) {
   536  	arg1 = indirectInterface(arg1)
   537  	k1, err := basicKind(arg1)
   538  	if err != nil {
   539  		return false, err
   540  	}
   541  	arg2 = indirectInterface(arg2)
   542  	k2, err := basicKind(arg2)
   543  	if err != nil {
   544  		return false, err
   545  	}
   546  	truth := false
   547  	if k1 != k2 {
   548  		// Special case: Can compare integer values regardless of type's sign.
   549  		switch {
   550  		case k1 == intKind && k2 == uintKind:
   551  			truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
   552  		case k1 == uintKind && k2 == intKind:
   553  			truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
   554  		default:
   555  			return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg2.Type())
   556  		}
   557  	} else {
   558  		switch k1 {
   559  		case boolKind, complexKind:
   560  			return false, errBadComparisonType
   561  		case floatKind:
   562  			truth = arg1.Float() < arg2.Float()
   563  		case intKind:
   564  			truth = arg1.Int() < arg2.Int()
   565  		case stringKind:
   566  			truth = arg1.String() < arg2.String()
   567  		case uintKind:
   568  			truth = arg1.Uint() < arg2.Uint()
   569  		default:
   570  			panic("invalid kind")
   571  		}
   572  	}
   573  	return truth, nil
   574  }
   575  
   576  // le evaluates the comparison <= b.
   577  func le(arg1, arg2 reflect.Value) (bool, error) {
   578  	// <= is < or ==.
   579  	lessThan, err := lt(arg1, arg2)
   580  	if lessThan || err != nil {
   581  		return lessThan, err
   582  	}
   583  	return eq(arg1, arg2)
   584  }
   585  
   586  // gt evaluates the comparison a > b.
   587  func gt(arg1, arg2 reflect.Value) (bool, error) {
   588  	// > is the inverse of <=.
   589  	lessOrEqual, err := le(arg1, arg2)
   590  	if err != nil {
   591  		return false, err
   592  	}
   593  	return !lessOrEqual, nil
   594  }
   595  
   596  // ge evaluates the comparison a >= b.
   597  func ge(arg1, arg2 reflect.Value) (bool, error) {
   598  	// >= is the inverse of <.
   599  	lessThan, err := lt(arg1, arg2)
   600  	if err != nil {
   601  		return false, err
   602  	}
   603  	return !lessThan, nil
   604  }
   605  
   606  // HTML escaping.
   607  
   608  var (
   609  	htmlQuot = []byte("&#34;") // shorter than "&quot;"
   610  	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
   611  	htmlAmp  = []byte("&amp;")
   612  	htmlLt   = []byte("&lt;")
   613  	htmlGt   = []byte("&gt;")
   614  	htmlNull = []byte("\uFFFD")
   615  )
   616  
   617  // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
   618  func HTMLEscape(w io.Writer, b []byte) {
   619  	last := 0
   620  	for i, c := range b {
   621  		var html []byte
   622  		switch c {
   623  		case '\000':
   624  			html = htmlNull
   625  		case '"':
   626  			html = htmlQuot
   627  		case '\'':
   628  			html = htmlApos
   629  		case '&':
   630  			html = htmlAmp
   631  		case '<':
   632  			html = htmlLt
   633  		case '>':
   634  			html = htmlGt
   635  		default:
   636  			continue
   637  		}
   638  		w.Write(b[last:i])
   639  		w.Write(html)
   640  		last = i + 1
   641  	}
   642  	w.Write(b[last:])
   643  }
   644  
   645  // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
   646  func HTMLEscapeString(s string) string {
   647  	// Avoid allocation if we can.
   648  	if !strings.ContainsAny(s, "'\"&<>\000") {
   649  		return s
   650  	}
   651  	var b strings.Builder
   652  	HTMLEscape(&b, []byte(s))
   653  	return b.String()
   654  }
   655  
   656  // HTMLEscaper returns the escaped HTML equivalent of the textual
   657  // representation of its arguments.
   658  func HTMLEscaper(args ...any) string {
   659  	return HTMLEscapeString(evalArgs(args))
   660  }
   661  
   662  // JavaScript escaping.
   663  
   664  var (
   665  	jsLowUni = []byte(`\u00`)
   666  	hex      = []byte("0123456789ABCDEF")
   667  
   668  	jsBackslash = []byte(`\\`)
   669  	jsApos      = []byte(`\'`)
   670  	jsQuot      = []byte(`\"`)
   671  	jsLt        = []byte(`\u003C`)
   672  	jsGt        = []byte(`\u003E`)
   673  	jsAmp       = []byte(`\u0026`)
   674  	jsEq        = []byte(`\u003D`)
   675  )
   676  
   677  // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
   678  func JSEscape(w io.Writer, b []byte) {
   679  	last := 0
   680  	for i := 0; i < len(b); i++ {
   681  		c := b[i]
   682  
   683  		if !jsIsSpecial(rune(c)) {
   684  			// fast path: nothing to do
   685  			continue
   686  		}
   687  		w.Write(b[last:i])
   688  
   689  		if c < utf8.RuneSelf {
   690  			// Quotes, slashes and angle brackets get quoted.
   691  			// Control characters get written as \u00XX.
   692  			switch c {
   693  			case '\\':
   694  				w.Write(jsBackslash)
   695  			case '\'':
   696  				w.Write(jsApos)
   697  			case '"':
   698  				w.Write(jsQuot)
   699  			case '<':
   700  				w.Write(jsLt)
   701  			case '>':
   702  				w.Write(jsGt)
   703  			case '&':
   704  				w.Write(jsAmp)
   705  			case '=':
   706  				w.Write(jsEq)
   707  			default:
   708  				w.Write(jsLowUni)
   709  				t, b := c>>4, c&0x0f
   710  				w.Write(hex[t : t+1])
   711  				w.Write(hex[b : b+1])
   712  			}
   713  		} else {
   714  			// Unicode rune.
   715  			r, size := utf8.DecodeRune(b[i:])
   716  			if unicode.IsPrint(r) {
   717  				w.Write(b[i : i+size])
   718  			} else {
   719  				fmt.Fprintf(w, "\\u%04X", r)
   720  			}
   721  			i += size - 1
   722  		}
   723  		last = i + 1
   724  	}
   725  	w.Write(b[last:])
   726  }
   727  
   728  // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
   729  func JSEscapeString(s string) string {
   730  	// Avoid allocation if we can.
   731  	if strings.IndexFunc(s, jsIsSpecial) < 0 {
   732  		return s
   733  	}
   734  	var b strings.Builder
   735  	JSEscape(&b, []byte(s))
   736  	return b.String()
   737  }
   738  
   739  func jsIsSpecial(r rune) bool {
   740  	switch r {
   741  	case '\\', '\'', '"', '<', '>', '&', '=':
   742  		return true
   743  	}
   744  	return r < ' ' || utf8.RuneSelf <= r
   745  }
   746  
   747  // JSEscaper returns the escaped JavaScript equivalent of the textual
   748  // representation of its arguments.
   749  func JSEscaper(args ...any) string {
   750  	return JSEscapeString(evalArgs(args))
   751  }
   752  
   753  // URLQueryEscaper returns the escaped value of the textual representation of
   754  // its arguments in a form suitable for embedding in a URL query.
   755  func URLQueryEscaper(args ...any) string {
   756  	return url.QueryEscape(evalArgs(args))
   757  }
   758  
   759  // evalArgs formats the list of arguments into a string. It is therefore equivalent to
   760  //
   761  //	fmt.Sprint(args...)
   762  //
   763  // except that each argument is indirected (if a pointer), as required,
   764  // using the same rules as the default string evaluation during template
   765  // execution.
   766  func evalArgs(args []any) string {
   767  	ok := false
   768  	var s string
   769  	// Fast path for simple common case.
   770  	if len(args) == 1 {
   771  		s, ok = args[0].(string)
   772  	}
   773  	if !ok {
   774  		for i, arg := range args {
   775  			a, ok := printableValue(reflect.ValueOf(arg))
   776  			if ok {
   777  				args[i] = a
   778  			} // else let fmt do its thing
   779  		}
   780  		s = fmt.Sprint(args...)
   781  	}
   782  	return s
   783  }
   784  

View as plain text