💾 Archived View for source.community › ckaznocha › gemini › raw › main › mux.go captured on 2021-12-17 at 13:26:06.

View Raw

More Information

-=-=-=-=-=-=-

package gemini

import (
	"context"
	"strings"
	"sync"
)

// ServeMux is a Gemini request multiplexer. It will match requests to handlers
// based on the URI path. The longest match will be the one returned. Patterns
// ending with `/` will be matched exactly. Patterns without a trailing `/` will
// be treated as a prefix match.
type ServeMux struct {
	// nodeToPattern records, for each node a handler was registered on, the
	// pattern string that was passed to Handle.
	nodeToPattern map[*muxNode]string
	// nodes is the root of the tree of registered patterns.
	nodes         *muxNode
	// mu is write-locked by Handle and read-locked by Handler.
	mu            sync.RWMutex
}

// NewServeMux allocates a ServeMux with an empty pattern tree and an empty
// node-to-pattern index, so handlers can be registered on it right away.
func NewServeMux() *ServeMux {
	// The root node carries no handler of its own until one is registered.
	root := &muxNode{
		handler: nil,
		nodes:   make(map[string]*muxNode),
	}

	return &ServeMux{
		nodeToPattern: make(map[*muxNode]string),
		nodes:         root,
	}
}

// Handle registers handler for pattern. The pattern is inserted into the tree
// and the node it lands on is remembered together with the pattern string so
// that Handler can report which pattern matched.
func (sm *ServeMux) Handle(pattern string, handler Handler) {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	node := sm.nodes.insert(pattern, handler)
	sm.nodeToPattern[node] = pattern
}

// HandleFunc registers a plain function as the handler for pattern by
// converting it to a HandlerFunc and passing it to Handle.
func (sm *ServeMux) HandleFunc(pattern string, handler func(context.Context, ResponseWriter, *Request)) {
	var adapted Handler = HandlerFunc(handler)

	sm.Handle(pattern, adapted)
}

// Handler looks up the Handler registered for the path of the given Request.
// It returns the matched Handler together with the pattern it was registered
// under.
//
// If the request path does not begin with "/", one is prepended and written
// back to r.URI.Path before the lookup. When no registered handler matches,
// the returned Handler replies with a StatusNotFound failure and the returned
// pattern is empty.
func (sm *ServeMux) Handler(r *Request) (h Handler, pattern string) {
	sm.mu.RLock()
	defer sm.mu.RUnlock()

	if path := r.URI.Path; !strings.HasPrefix(path, "/") {
		r.URI.Path = "/" + path
	}

	if node := sm.nodes.find(r.URI.Path); node.handler != nil {
		return node.handler, sm.nodeToPattern[node]
	}

	notFound := func(ctx context.Context, w ResponseWriter, _ *Request) {
		w.Failure(ctx, StatusNotFound, StatusNotFound.Description())
	}

	return HandlerFunc(notFound), ""
}

// ServeGemini implements the Handler interface by resolving the Handler for
// the request via sm.Handler and delegating the request to it. The matched
// pattern is not needed here and is discarded.
func (sm *ServeMux) ServeGemini(ctx context.Context, w ResponseWriter, r *Request) {
	matched, _ := sm.Handler(r)

	matched.ServeGemini(ctx, w, r)
}

// muxNode is a single node in the tree of registered patterns. Each edge of
// the tree is labelled with one segment of a pattern, as split by nextIndex.
type muxNode struct {
	// handler is the Handler stored on this node by insert. It is nil for
	// nodes that only exist as intermediate steps towards a longer pattern.
	handler Handler
	// nodes holds the children of this node, keyed by pattern segment.
	nodes   map[string]*muxNode
}

// insert stores handler at the node addressed by pattern, creating any
// missing nodes along the way, and returns that node. The pattern is consumed
// one segment at a time using nextIndex; an empty pattern addresses the
// receiver itself.
func (mn *muxNode) insert(pattern string, handler Handler) *muxNode {
	cur := mn

	for rest := pattern; rest != ""; {
		cut := nextIndex(rest)
		segment := rest[:cut]

		next, found := cur.nodes[segment]
		if !found {
			next = &muxNode{nodes: map[string]*muxNode{}}
			cur.nodes[segment] = next
		}

		cur, rest = next, rest[cut:]
	}

	cur.handler = handler

	return cur
}

// find walks the tree one pattern segment at a time (segments are split by
// nextIndex) and returns the node that best matches pattern.
//
// The result is the deepest node on the walked path that has a handler, so a
// request falls back to the longest registered prefix even when the walk
// continues through intermediate nodes that carry no handler of their own.
// If no node on the path has a handler, the node where the walk stopped is
// returned; its handler is nil, which callers treat as "no match". The
// returned node is never nil.
func (mn *muxNode) find(pattern string) *muxNode {
	var (
		cur  = mn
		best *muxNode
	)

	for {
		// Remember the most specific node that can actually serve a request.
		if cur.handler != nil {
			best = cur
		}

		// The whole pattern has been consumed; nothing further to descend into.
		if pattern == "" {
			break
		}

		idx := nextIndex(pattern)

		child, ok := cur.nodes[pattern[:idx]]
		if !ok {
			break
		}

		cur, pattern = child, pattern[idx:]
	}

	if best == nil {
		return cur
	}

	return best
}

// nextIndex returns the length of the first segment of pattern: the index of
// the first '/' found after the first byte, or len(pattern) when there is no
// such '/'. The first byte is skipped so that a leading '/' stays attached to
// its segment ("/foo/bar" yields 4, i.e. "/foo").
//
// An empty pattern yields 0 rather than panicking on the pattern[1:] slice.
func nextIndex(pattern string) int {
	if pattern == "" {
		return 0
	}

	idx := strings.IndexByte(pattern[1:], '/')
	if idx == -1 {
		return len(pattern)
	}

	// idx is relative to pattern[1:]; shift it back into pattern's own indexing.
	return idx + 1
}