package gemini import ( "context" "strings" "sync" ) // ServeMux is a Gemini request multiplexer. It will match requests to handlers // based on the URI path. The longest match will be the one returned. Patterns // ending with `/` will be matched exactly. Patterns without a trailing `/` will // be treated as a prefix match. type ServeMux struct { nodeToPattern map[*muxNode]string nodes *muxNode mu sync.RWMutex } // NewServeMux returns a new ServeMux ready to be used. func NewServeMux() *ServeMux { return &ServeMux{ nodeToPattern: map[*muxNode]string{}, nodes: &muxNode{ handler: nil, nodes: map[string]*muxNode{}, }, } } // Handle adds a new pattern/Handler pair to the ServeMux. func (sm *ServeMux) Handle(pattern string, handler Handler) { sm.mu.Lock() defer sm.mu.Unlock() sm.nodeToPattern[sm.nodes.insert(pattern, handler)] = pattern } // HandleFunc adds a new pattern/HandleFunc pair to the ServeMux. func (sm *ServeMux) HandleFunc(pattern string, handler func(context.Context, ResponseWriter, *Request)) { sm.Handle(pattern, HandlerFunc(handler)) } // Handler looks up a matching Handler based on a Request. It returns the patter // that matched in addition to the Hander. func (sm *ServeMux) Handler(r *Request) (h Handler, pattern string) { sm.mu.RLock() defer sm.mu.RUnlock() if !strings.HasPrefix(r.URI.Path, "/") { r.URI.Path = "/" + r.URI.Path } n := sm.nodes.find(r.URI.Path) if n.handler == nil { return HandlerFunc(func(c context.Context, rw ResponseWriter, r *Request) { rw.Failure(c, StatusNotFound, StatusNotFound.Description()) }), "" } return n.handler, sm.nodeToPattern[n] } // ServeGemini implements the Handler interface. func (sm *ServeMux) ServeGemini(ctx context.Context, w ResponseWriter, r *Request) { handlr, _ := sm.Handler(r) handlr.ServeGemini(ctx, w, r) } type muxNode struct { handler Handler nodes map[string]*muxNode } func (mn *muxNode) insert(pattern string, handler Handler) *muxNode { if pattern == "" { mn.handler = handler return mn } idx := nextIndex(pattern) child, ok := mn.nodes[pattern[:idx]] if !ok { child = &muxNode{nodes: map[string]*muxNode{}} mn.nodes[pattern[:idx]] = child } return child.insert(pattern[idx:], handler) } func (mn *muxNode) find(pattern string) *muxNode { if child, ok := mn.nodes[pattern]; ok { return child } idx := nextIndex(pattern) if child, ok := mn.nodes[pattern[:idx]]; ok { return child.find(pattern[idx:]) } return mn } func nextIndex(pattern string) int { idx := strings.IndexRune(pattern[1:], '/') if idx == -1 { return len(pattern) } return idx + 1 }