1  
     2  
     3  
     4  
     5  
     6  package sumdb
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"net/http"
    12  	"os"
    13  	"strings"
    14  
    15  	"golang.org/x/mod/internal/lazyregexp"
    16  	"golang.org/x/mod/module"
    17  	"golang.org/x/mod/sumdb/tlog"
    18  )
    19  
    20  
    21  
    22  type ServerOps interface {
    23  	
    24  	Signed(ctx context.Context) ([]byte, error)
    25  
    26  	
    27  	ReadRecords(ctx context.Context, id, n int64) ([][]byte, error)
    28  
    29  	
    30  	
    31  	Lookup(ctx context.Context, m module.Version) (int64, error)
    32  
    33  	
    34  	
    35  	ReadTileData(ctx context.Context, t tlog.Tile) ([]byte, error)
    36  }
    37  
    38  
    39  
    40  
    41  type Server struct {
    42  	ops ServerOps
    43  }
    44  
    45  
    46  func NewServer(ops ServerOps) *Server {
    47  	return &Server{ops: ops}
    48  }
    49  
    50  
    51  
    52  
    53  
    54  
    55  
    56  
    57  
    58  var ServerPaths = []string{
    59  	"/lookup/",
    60  	"/latest",
    61  	"/tile/",
    62  }
    63  
    64  var modVerRE = lazyregexp.New(`^[^@]+@v[0-9]+\.[0-9]+\.[0-9]+(-[^@]*)?(\+incompatible)?$`)
    65  
    66  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    67  	ctx := r.Context()
    68  
    69  	switch {
    70  	default:
    71  		http.NotFound(w, r)
    72  
    73  	case strings.HasPrefix(r.URL.Path, "/lookup/"):
    74  		mod := strings.TrimPrefix(r.URL.Path, "/lookup/")
    75  		if !modVerRE.MatchString(mod) {
    76  			http.Error(w, "invalid module@version syntax", http.StatusBadRequest)
    77  			return
    78  		}
    79  		i := strings.Index(mod, "@")
    80  		escPath, escVers := mod[:i], mod[i+1:]
    81  		path, err := module.UnescapePath(escPath)
    82  		if err != nil {
    83  			reportError(w, err)
    84  			return
    85  		}
    86  		vers, err := module.UnescapeVersion(escVers)
    87  		if err != nil {
    88  			reportError(w, err)
    89  			return
    90  		}
    91  		id, err := s.ops.Lookup(ctx, module.Version{Path: path, Version: vers})
    92  		if err != nil {
    93  			reportError(w, err)
    94  			return
    95  		}
    96  		records, err := s.ops.ReadRecords(ctx, id, 1)
    97  		if err != nil {
    98  			
    99  			http.Error(w, err.Error(), http.StatusInternalServerError)
   100  			return
   101  		}
   102  		if len(records) != 1 {
   103  			http.Error(w, "invalid record count returned by ReadRecords", http.StatusInternalServerError)
   104  			return
   105  		}
   106  		msg, err := tlog.FormatRecord(id, records[0])
   107  		if err != nil {
   108  			http.Error(w, err.Error(), http.StatusInternalServerError)
   109  			return
   110  		}
   111  		signed, err := s.ops.Signed(ctx)
   112  		if err != nil {
   113  			http.Error(w, err.Error(), http.StatusInternalServerError)
   114  			return
   115  		}
   116  		w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
   117  		w.Write(msg)
   118  		w.Write(signed)
   119  
   120  	case r.URL.Path == "/latest":
   121  		data, err := s.ops.Signed(ctx)
   122  		if err != nil {
   123  			http.Error(w, err.Error(), http.StatusInternalServerError)
   124  			return
   125  		}
   126  		w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
   127  		w.Write(data)
   128  
   129  	case strings.HasPrefix(r.URL.Path, "/tile/"):
   130  		t, err := tlog.ParseTilePath(r.URL.Path[1:])
   131  		if err != nil {
   132  			http.Error(w, "invalid tile syntax", http.StatusBadRequest)
   133  			return
   134  		}
   135  		if t.L == -1 {
   136  			
   137  			start := t.N << uint(t.H)
   138  			records, err := s.ops.ReadRecords(ctx, start, int64(t.W))
   139  			if err != nil {
   140  				reportError(w, err)
   141  				return
   142  			}
   143  			if len(records) != t.W {
   144  				http.Error(w, "invalid record count returned by ReadRecords", http.StatusInternalServerError)
   145  				return
   146  			}
   147  			var data []byte
   148  			for i, text := range records {
   149  				msg, err := tlog.FormatRecord(start+int64(i), text)
   150  				if err != nil {
   151  					http.Error(w, err.Error(), http.StatusInternalServerError)
   152  					return
   153  				}
   154  				
   155  				_, msg, _ = bytes.Cut(msg, []byte{'\n'})
   156  				data = append(data, msg...)
   157  			}
   158  			w.Header().Set("Content-Type", "text/plain; charset=UTF-8")
   159  			w.Write(data)
   160  			return
   161  		}
   162  
   163  		data, err := s.ops.ReadTileData(ctx, t)
   164  		if err != nil {
   165  			reportError(w, err)
   166  			return
   167  		}
   168  		w.Header().Set("Content-Type", "application/octet-stream")
   169  		w.Write(data)
   170  	}
   171  }
   172  
   173  
   174  
   175  
   176  
   177  
   178  func reportError(w http.ResponseWriter, err error) {
   179  	if os.IsNotExist(err) {
   180  		http.Error(w, err.Error(), http.StatusNotFound)
   181  		return
   182  	}
   183  	http.Error(w, err.Error(), http.StatusInternalServerError)
   184  }
   185  
View as plain text