Source file src/text/template/funcs.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package template
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/url"
    12  	"reflect"
    13  	"strings"
    14  	"sync"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  // FuncMap is the type of the map defining the mapping from names to functions.
    20  // Each function must have either a single return value, or two return values of
    21  // which the second has type error. In that case, if the second (error)
    22  // return value evaluates to non-nil during execution, execution terminates and
    23  // Execute returns that error.
    24  //
    25  // Errors returned by Execute wrap the underlying error; call [errors.AsType] to
    26  // unwrap them.
    27  //
    28  // When template execution invokes a function with an argument list, that list
    29  // must be assignable to the function's parameter types. Functions meant to
    30  // apply to arguments of arbitrary type can use parameters of type interface{} or
    31  // of type [reflect.Value]. Similarly, functions meant to return a result of arbitrary
    32  // type can return interface{} or [reflect.Value].
    33  type FuncMap map[string]any
    34  
    35  // builtins returns the FuncMap.
    36  // It is not a global variable so the linker can dead code eliminate
    37  // more when this isn't called. See golang.org/issue/36021.
    38  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    39  func builtins() FuncMap {
    40  	return FuncMap{
    41  		"and":      and,
    42  		"call":     emptyCall,
    43  		"html":     HTMLEscaper,
    44  		"index":    index,
    45  		"slice":    slice,
    46  		"js":       JSEscaper,
    47  		"len":      length,
    48  		"not":      not,
    49  		"or":       or,
    50  		"print":    fmt.Sprint,
    51  		"printf":   fmt.Sprintf,
    52  		"println":  fmt.Sprintln,
    53  		"urlquery": URLQueryEscaper,
    54  
    55  		// Comparisons
    56  		"eq": eq, // ==
    57  		"ge": ge, // >=
    58  		"gt": gt, // >
    59  		"le": le, // <=
    60  		"lt": lt, // <
    61  		"ne": ne, // !=
    62  	}
    63  }
    64  
    65  // builtinFuncs lazily computes & caches the builtinFuncs map.
    66  var builtinFuncs = sync.OnceValue(func() map[string]reflect.Value {
    67  	funcMap := builtins()
    68  	m := make(map[string]reflect.Value, len(funcMap))
    69  	addValueFuncs(m, funcMap)
    70  	return m
    71  })
    72  
    73  // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
    74  func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
    75  	for name, fn := range in {
    76  		if !goodName(name) {
    77  			panic(fmt.Errorf("function name %q is not a valid identifier", name))
    78  		}
    79  		v := reflect.ValueOf(fn)
    80  		if v.Kind() != reflect.Func {
    81  			panic("value for " + name + " not a function")
    82  		}
    83  		if err := goodFunc(name, v.Type()); err != nil {
    84  			panic(err)
    85  		}
    86  		out[name] = v
    87  	}
    88  }
    89  
    90  // addFuncs adds to values the functions in funcs. It does no checking of the input -
    91  // call addValueFuncs first.
    92  func addFuncs(out, in FuncMap) {
    93  	for name, fn := range in {
    94  		out[name] = fn
    95  	}
    96  }
    97  
    98  // goodFunc reports whether the function or method has the right result signature.
    99  func goodFunc(name string, typ reflect.Type) error {
   100  	// We allow functions with 1 result or 2 results where the second is an error.
   101  	switch numOut := typ.NumOut(); {
   102  	case numOut == 1:
   103  		return nil
   104  	case numOut == 2 && typ.Out(1) == errorType:
   105  		return nil
   106  	case numOut == 2:
   107  		return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
   108  	default:
   109  		return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
   110  	}
   111  }
   112  
   113  // goodName reports whether the function name is a valid identifier.
   114  func goodName(name string) bool {
   115  	if name == "" {
   116  		return false
   117  	}
   118  	for i, r := range name {
   119  		switch {
   120  		case r == '_':
   121  		case i == 0 && !unicode.IsLetter(r):
   122  			return false
   123  		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
   124  			return false
   125  		}
   126  	}
   127  	return true
   128  }
   129  
   130  // findFunction looks for a function in the template, and global map.
   131  func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
   132  	if tmpl != nil && tmpl.common != nil {
   133  		tmpl.muFuncs.RLock()
   134  		defer tmpl.muFuncs.RUnlock()
   135  		if fn := tmpl.execFuncs[name]; fn.IsValid() {
   136  			return fn, false, true
   137  		}
   138  	}
   139  	if fn := builtinFuncs()[name]; fn.IsValid() {
   140  		return fn, true, true
   141  	}
   142  	return reflect.Value{}, false, false
   143  }
   144  
   145  // prepareArg checks if value can be used as an argument of type argType, and
   146  // converts an invalid value to appropriate zero if possible.
   147  func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
   148  	if !value.IsValid() {
   149  		if !canBeNil(argType) {
   150  			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
   151  		}
   152  		value = reflect.Zero(argType)
   153  	}
   154  	if value.Type().AssignableTo(argType) {
   155  		return value, nil
   156  	}
   157  	if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
   158  		value = value.Convert(argType)
   159  		return value, nil
   160  	}
   161  	return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
   162  }
   163  
   164  func intLike(typ reflect.Kind) bool {
   165  	switch typ {
   166  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   167  		return true
   168  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   169  		return true
   170  	}
   171  	return false
   172  }
   173  
   174  // indexArg checks if a reflect.Value can be used as an index, and converts it to int if possible.
   175  func indexArg(index reflect.Value, cap int) (int, error) {
   176  	var x int64
   177  	switch index.Kind() {
   178  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   179  		x = index.Int()
   180  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   181  		x = int64(index.Uint())
   182  	case reflect.Invalid:
   183  		return 0, fmt.Errorf("cannot index slice/array with nil")
   184  	default:
   185  		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
   186  	}
   187  	if x < 0 || int(x) < 0 || int(x) > cap {
   188  		return 0, fmt.Errorf("index out of range: %d", x)
   189  	}
   190  	return int(x), nil
   191  }
   192  
   193  // Indexing.
   194  
   195  // index returns the result of indexing its first argument by the following
   196  // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
   197  // indexed item must be a map, slice, or array.
   198  func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   199  	item = indirectInterface(item)
   200  	if !item.IsValid() {
   201  		return reflect.Value{}, fmt.Errorf("index of untyped nil")
   202  	}
   203  	for _, index := range indexes {
   204  		index = indirectInterface(index)
   205  		var isNil bool
   206  		if item, isNil = indirect(item); isNil {
   207  			return reflect.Value{}, fmt.Errorf("index of nil pointer")
   208  		}
   209  		switch item.Kind() {
   210  		case reflect.Array, reflect.Slice, reflect.String:
   211  			x, err := indexArg(index, item.Len())
   212  			if err != nil {
   213  				return reflect.Value{}, err
   214  			}
   215  			item = item.Index(x)
   216  		case reflect.Map:
   217  			index, err := prepareArg(index, item.Type().Key())
   218  			if err != nil {
   219  				return reflect.Value{}, err
   220  			}
   221  			if x := item.MapIndex(index); x.IsValid() {
   222  				item = x
   223  			} else {
   224  				item = reflect.Zero(item.Type().Elem())
   225  			}
   226  		case reflect.Invalid:
   227  			// the loop holds invariant: item.IsValid()
   228  			panic("unreachable")
   229  		default:
   230  			return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
   231  		}
   232  	}
   233  	return item, nil
   234  }
   235  
   236  // Slicing.
   237  
   238  // slice returns the result of slicing its first argument by the remaining
   239  // arguments. Thus "slice x 1 2" is, in Go syntax, x[1:2], while "slice x"
   240  // is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
   241  // argument must be a string, slice, or array.
   242  func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   243  	item = indirectInterface(item)
   244  	if !item.IsValid() {
   245  		return reflect.Value{}, fmt.Errorf("slice of untyped nil")
   246  	}
   247  	var isNil bool
   248  	if item, isNil = indirect(item); isNil {
   249  		return reflect.Value{}, fmt.Errorf("slice of nil pointer")
   250  	}
   251  	if len(indexes) > 3 {
   252  		return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
   253  	}
   254  	var cap int
   255  	switch item.Kind() {
   256  	case reflect.String:
   257  		if len(indexes) == 3 {
   258  			return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
   259  		}
   260  		cap = item.Len()
   261  	case reflect.Array, reflect.Slice:
   262  		cap = item.Cap()
   263  	default:
   264  		return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
   265  	}
   266  	// set default values for cases item[:], item[i:].
   267  	idx := [3]int{0, item.Len()}
   268  	for i, index := range indexes {
   269  		x, err := indexArg(index, cap)
   270  		if err != nil {
   271  			return reflect.Value{}, err
   272  		}
   273  		idx[i] = x
   274  	}
   275  	// given item[i:j], make sure i <= j.
   276  	if idx[0] > idx[1] {
   277  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
   278  	}
   279  	if len(indexes) < 3 {
   280  		return item.Slice(idx[0], idx[1]), nil
   281  	}
   282  	// given item[i:j:k], make sure i <= j <= k.
   283  	if idx[1] > idx[2] {
   284  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
   285  	}
   286  	return item.Slice3(idx[0], idx[1], idx[2]), nil
   287  }
   288  
   289  // Length
   290  
   291  // length returns the length of the item, with an error if it has no defined length.
   292  func length(item reflect.Value) (int, error) {
   293  	item, isNil := indirect(item)
   294  	if isNil {
   295  		return 0, fmt.Errorf("len of nil pointer")
   296  	}
   297  	switch item.Kind() {
   298  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
   299  		return item.Len(), nil
   300  	}
   301  	return 0, fmt.Errorf("len of type %s", item.Type())
   302  }
   303  
   304  // Function invocation
   305  
   306  func emptyCall(fn reflect.Value, args ...reflect.Value) reflect.Value {
   307  	panic("unreachable") // implemented as a special case in evalCall
   308  }
   309  
   310  // call returns the result of evaluating the first argument as a function.
   311  // The function must return 1 result, or 2 results, the second of which is an error.
   312  func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
   313  	fn = indirectInterface(fn)
   314  	if !fn.IsValid() {
   315  		return reflect.Value{}, fmt.Errorf("call of nil")
   316  	}
   317  	typ := fn.Type()
   318  	if typ.Kind() != reflect.Func {
   319  		return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
   320  	}
   321  
   322  	if err := goodFunc(name, typ); err != nil {
   323  		return reflect.Value{}, err
   324  	}
   325  	numIn := typ.NumIn()
   326  	var dddType reflect.Type
   327  	if typ.IsVariadic() {
   328  		if len(args) < numIn-1 {
   329  			return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
   330  		}
   331  		dddType = typ.In(numIn - 1).Elem()
   332  	} else {
   333  		if len(args) != numIn {
   334  			return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
   335  		}
   336  	}
   337  	argv := make([]reflect.Value, len(args))
   338  	for i, arg := range args {
   339  		arg = indirectInterface(arg)
   340  		// Compute the expected type. Clumsy because of variadics.
   341  		argType := dddType
   342  		if !typ.IsVariadic() || i < numIn-1 {
   343  			argType = typ.In(i)
   344  		}
   345  
   346  		var err error
   347  		if argv[i], err = prepareArg(arg, argType); err != nil {
   348  			return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
   349  		}
   350  	}
   351  	return safeCall(fn, argv)
   352  }
   353  
   354  // safeCall runs fun.Call(args), and returns the resulting value and error, if
   355  // any. If the call panics, the panic value is returned as an error.
   356  func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
   357  	defer func() {
   358  		if r := recover(); r != nil {
   359  			if e, ok := r.(error); ok {
   360  				err = e
   361  			} else {
   362  				err = fmt.Errorf("%v", r)
   363  			}
   364  		}
   365  	}()
   366  	ret := fun.Call(args)
   367  	if len(ret) == 2 && !ret[1].IsNil() {
   368  		return ret[0], ret[1].Interface().(error)
   369  	}
   370  	return ret[0], nil
   371  }
   372  
   373  // Boolean logic.
   374  
   375  func truth(arg reflect.Value) bool {
   376  	t, _ := isTrue(indirectInterface(arg))
   377  	return t
   378  }
   379  
   380  // and computes the Boolean AND of its arguments, returning
   381  // the first false argument it encounters, or the last argument.
   382  func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   383  	panic("unreachable") // implemented as a special case in evalCall
   384  }
   385  
   386  // or computes the Boolean OR of its arguments, returning
   387  // the first true argument it encounters, or the last argument.
   388  func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   389  	panic("unreachable") // implemented as a special case in evalCall
   390  }
   391  
   392  // not returns the Boolean negation of its argument.
   393  func not(arg reflect.Value) bool {
   394  	return !truth(arg)
   395  }
   396  
   397  // Comparison.
   398  
   399  // TODO: Perhaps allow comparison between signed and unsigned integers.
   400  
   401  var (
   402  	errBadComparisonType = errors.New("invalid type for comparison")
   403  	errNoComparison      = errors.New("missing argument for comparison")
   404  )
   405  
   406  type kind int
   407  
   408  const (
   409  	invalidKind kind = iota
   410  	boolKind
   411  	complexKind
   412  	intKind
   413  	floatKind
   414  	stringKind
   415  	uintKind
   416  )
   417  
   418  func basicKind(v reflect.Value) (kind, error) {
   419  	switch v.Kind() {
   420  	case reflect.Bool:
   421  		return boolKind, nil
   422  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   423  		return intKind, nil
   424  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   425  		return uintKind, nil
   426  	case reflect.Float32, reflect.Float64:
   427  		return floatKind, nil
   428  	case reflect.Complex64, reflect.Complex128:
   429  		return complexKind, nil
   430  	case reflect.String:
   431  		return stringKind, nil
   432  	}
   433  	return invalidKind, errBadComparisonType
   434  }
   435  
   436  // isNil returns true if v is the zero reflect.Value, or nil of its type.
   437  func isNil(v reflect.Value) bool {
   438  	if !v.IsValid() {
   439  		return true
   440  	}
   441  	switch v.Kind() {
   442  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
   443  		return v.IsNil()
   444  	}
   445  	return false
   446  }
   447  
   448  // canCompare reports whether v1 and v2 are both the same kind, or one is nil.
   449  // Called only when dealing with nillable types, or there's about to be an error.
   450  func canCompare(v1, v2 reflect.Value) bool {
   451  	k1 := v1.Kind()
   452  	k2 := v2.Kind()
   453  	if k1 == k2 {
   454  		return true
   455  	}
   456  	// We know the type can be compared to nil.
   457  	return k1 == reflect.Invalid || k2 == reflect.Invalid
   458  }
   459  
   460  // eq evaluates the comparison a == b || a == c || ...
   461  func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
   462  	arg1 = indirectInterface(arg1)
   463  	if len(arg2) == 0 {
   464  		return false, errNoComparison
   465  	}
   466  	k1, _ := basicKind(arg1)
   467  	for _, arg := range arg2 {
   468  		arg = indirectInterface(arg)
   469  		k2, _ := basicKind(arg)
   470  		truth := false
   471  		if k1 != k2 {
   472  			// Special case: Can compare integer values regardless of type's sign.
   473  			switch {
   474  			case k1 == intKind && k2 == uintKind:
   475  				truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
   476  			case k1 == uintKind && k2 == intKind:
   477  				truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
   478  			default:
   479  				if arg1.IsValid() && arg.IsValid() {
   480  					return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg.Type())
   481  				}
   482  			}
   483  		} else {
   484  			switch k1 {
   485  			case boolKind:
   486  				truth = arg1.Bool() == arg.Bool()
   487  			case complexKind:
   488  				truth = arg1.Complex() == arg.Complex()
   489  			case floatKind:
   490  				truth = arg1.Float() == arg.Float()
   491  			case intKind:
   492  				truth = arg1.Int() == arg.Int()
   493  			case stringKind:
   494  				truth = arg1.String() == arg.String()
   495  			case uintKind:
   496  				truth = arg1.Uint() == arg.Uint()
   497  			default:
   498  				if !canCompare(arg1, arg) {
   499  					return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
   500  				}
   501  				if isNil(arg1) || isNil(arg) {
   502  					truth = isNil(arg) == isNil(arg1)
   503  				} else {
   504  					if !arg.Type().Comparable() {
   505  						return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
   506  					}
   507  					truth = arg1.Interface() == arg.Interface()
   508  				}
   509  			}
   510  		}
   511  		if truth {
   512  			return true, nil
   513  		}
   514  	}
   515  	return false, nil
   516  }
   517  
   518  // ne evaluates the comparison a != b.
   519  func ne(arg1, arg2 reflect.Value) (bool, error) {
   520  	// != is the inverse of ==.
   521  	equal, err := eq(arg1, arg2)
   522  	return !equal, err
   523  }
   524  
   525  // lt evaluates the comparison a < b.
   526  func lt(arg1, arg2 reflect.Value) (bool, error) {
   527  	arg1 = indirectInterface(arg1)
   528  	k1, err := basicKind(arg1)
   529  	if err != nil {
   530  		return false, err
   531  	}
   532  	arg2 = indirectInterface(arg2)
   533  	k2, err := basicKind(arg2)
   534  	if err != nil {
   535  		return false, err
   536  	}
   537  	truth := false
   538  	if k1 != k2 {
   539  		// Special case: Can compare integer values regardless of type's sign.
   540  		switch {
   541  		case k1 == intKind && k2 == uintKind:
   542  			truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
   543  		case k1 == uintKind && k2 == intKind:
   544  			truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
   545  		default:
   546  			return false, fmt.Errorf("incompatible types for comparison: %v and %v", arg1.Type(), arg2.Type())
   547  		}
   548  	} else {
   549  		switch k1 {
   550  		case boolKind, complexKind:
   551  			return false, errBadComparisonType
   552  		case floatKind:
   553  			truth = arg1.Float() < arg2.Float()
   554  		case intKind:
   555  			truth = arg1.Int() < arg2.Int()
   556  		case stringKind:
   557  			truth = arg1.String() < arg2.String()
   558  		case uintKind:
   559  			truth = arg1.Uint() < arg2.Uint()
   560  		default:
   561  			panic("invalid kind")
   562  		}
   563  	}
   564  	return truth, nil
   565  }
   566  
   567  // le evaluates the comparison <= b.
   568  func le(arg1, arg2 reflect.Value) (bool, error) {
   569  	// <= is < or ==.
   570  	lessThan, err := lt(arg1, arg2)
   571  	if lessThan || err != nil {
   572  		return lessThan, err
   573  	}
   574  	return eq(arg1, arg2)
   575  }
   576  
   577  // gt evaluates the comparison a > b.
   578  func gt(arg1, arg2 reflect.Value) (bool, error) {
   579  	// > is the inverse of <=.
   580  	lessOrEqual, err := le(arg1, arg2)
   581  	if err != nil {
   582  		return false, err
   583  	}
   584  	return !lessOrEqual, nil
   585  }
   586  
   587  // ge evaluates the comparison a >= b.
   588  func ge(arg1, arg2 reflect.Value) (bool, error) {
   589  	// >= is the inverse of <.
   590  	lessThan, err := lt(arg1, arg2)
   591  	if err != nil {
   592  		return false, err
   593  	}
   594  	return !lessThan, nil
   595  }
   596  
   597  // HTML escaping.
   598  
   599  var (
   600  	htmlQuot = []byte("&#34;") // shorter than "&quot;"
   601  	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
   602  	htmlAmp  = []byte("&amp;")
   603  	htmlLt   = []byte("&lt;")
   604  	htmlGt   = []byte("&gt;")
   605  	htmlNull = []byte("\uFFFD")
   606  )
   607  
   608  // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
   609  func HTMLEscape(w io.Writer, b []byte) {
   610  	last := 0
   611  	for i, c := range b {
   612  		var html []byte
   613  		switch c {
   614  		case '\000':
   615  			html = htmlNull
   616  		case '"':
   617  			html = htmlQuot
   618  		case '\'':
   619  			html = htmlApos
   620  		case '&':
   621  			html = htmlAmp
   622  		case '<':
   623  			html = htmlLt
   624  		case '>':
   625  			html = htmlGt
   626  		default:
   627  			continue
   628  		}
   629  		w.Write(b[last:i])
   630  		w.Write(html)
   631  		last = i + 1
   632  	}
   633  	w.Write(b[last:])
   634  }
   635  
   636  // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
   637  func HTMLEscapeString(s string) string {
   638  	// Avoid allocation if we can.
   639  	if !strings.ContainsAny(s, "'\"&<>\000") {
   640  		return s
   641  	}
   642  	var b strings.Builder
   643  	HTMLEscape(&b, []byte(s))
   644  	return b.String()
   645  }
   646  
   647  // HTMLEscaper returns the escaped HTML equivalent of the textual
   648  // representation of its arguments.
   649  func HTMLEscaper(args ...any) string {
   650  	return HTMLEscapeString(evalArgs(args))
   651  }
   652  
   653  // JavaScript escaping.
   654  
   655  var (
   656  	jsLowUni = []byte(`\u00`)
   657  	hex      = []byte("0123456789ABCDEF")
   658  
   659  	jsBackslash = []byte(`\\`)
   660  	jsApos      = []byte(`\'`)
   661  	jsQuot      = []byte(`\"`)
   662  	jsLt        = []byte(`\u003C`)
   663  	jsGt        = []byte(`\u003E`)
   664  	jsAmp       = []byte(`\u0026`)
   665  	jsEq        = []byte(`\u003D`)
   666  )
   667  
   668  // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
   669  func JSEscape(w io.Writer, b []byte) {
   670  	last := 0
   671  	for i := 0; i < len(b); i++ {
   672  		c := b[i]
   673  
   674  		if !jsIsSpecial(rune(c)) {
   675  			// fast path: nothing to do
   676  			continue
   677  		}
   678  		w.Write(b[last:i])
   679  
   680  		if c < utf8.RuneSelf {
   681  			// Quotes, slashes and angle brackets get quoted.
   682  			// Control characters get written as \u00XX.
   683  			switch c {
   684  			case '\\':
   685  				w.Write(jsBackslash)
   686  			case '\'':
   687  				w.Write(jsApos)
   688  			case '"':
   689  				w.Write(jsQuot)
   690  			case '<':
   691  				w.Write(jsLt)
   692  			case '>':
   693  				w.Write(jsGt)
   694  			case '&':
   695  				w.Write(jsAmp)
   696  			case '=':
   697  				w.Write(jsEq)
   698  			default:
   699  				w.Write(jsLowUni)
   700  				t, b := c>>4, c&0x0f
   701  				w.Write(hex[t : t+1])
   702  				w.Write(hex[b : b+1])
   703  			}
   704  		} else {
   705  			// Unicode rune.
   706  			r, size := utf8.DecodeRune(b[i:])
   707  			if unicode.IsPrint(r) {
   708  				w.Write(b[i : i+size])
   709  			} else {
   710  				fmt.Fprintf(w, "\\u%04X", r)
   711  			}
   712  			i += size - 1
   713  		}
   714  		last = i + 1
   715  	}
   716  	w.Write(b[last:])
   717  }
   718  
   719  // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
   720  func JSEscapeString(s string) string {
   721  	// Avoid allocation if we can.
   722  	if strings.IndexFunc(s, jsIsSpecial) < 0 {
   723  		return s
   724  	}
   725  	var b strings.Builder
   726  	JSEscape(&b, []byte(s))
   727  	return b.String()
   728  }
   729  
   730  func jsIsSpecial(r rune) bool {
   731  	switch r {
   732  	case '\\', '\'', '"', '<', '>', '&', '=':
   733  		return true
   734  	}
   735  	return r < ' ' || utf8.RuneSelf <= r
   736  }
   737  
   738  // JSEscaper returns the escaped JavaScript equivalent of the textual
   739  // representation of its arguments.
   740  func JSEscaper(args ...any) string {
   741  	return JSEscapeString(evalArgs(args))
   742  }
   743  
   744  // URLQueryEscaper returns the escaped value of the textual representation of
   745  // its arguments in a form suitable for embedding in a URL query.
   746  func URLQueryEscaper(args ...any) string {
   747  	return url.QueryEscape(evalArgs(args))
   748  }
   749  
   750  // evalArgs formats the list of arguments into a string. It is therefore equivalent to
   751  //
   752  //	fmt.Sprint(args...)
   753  //
   754  // except that each argument is indirected (if a pointer), as required,
   755  // using the same rules as the default string evaluation during template
   756  // execution.
   757  func evalArgs(args []any) string {
   758  	ok := false
   759  	var s string
   760  	// Fast path for simple common case.
   761  	if len(args) == 1 {
   762  		s, ok = args[0].(string)
   763  	}
   764  	if !ok {
   765  		for i, arg := range args {
   766  			a, ok := printableValue(reflect.ValueOf(arg))
   767  			if ok {
   768  				args[i] = a
   769  			} // else let fmt do its thing
   770  		}
   771  		s = fmt.Sprint(args...)
   772  	}
   773  	return s
   774  }
   775  

View as plain text