💾 Archived View for source.community › ckaznocha › gemini › raw › main › server.go captured on 2024-05-12 at 15:29:11.

View Raw

More Information

⬅️ Previous capture (2021-12-17)

-=-=-=-=-=-=-

package gemini

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// serverContextKeyType is the unexported type of serverContextKey, so no other
// package can construct a value that collides with it as a context key.
type serverContextKeyType struct{}

// serverContextKey is the context key under which Serve stores the *Server
// that accepted a connection.
//
//nolint:gochecknoglobals // context keys need to be global.
var serverContextKey = new(serverContextKeyType)

// ServerFromCtx returns the *Server stored in ctx. Contexts derived from the
// one Serve builds for its connections normally carry the accepting server.
// The boolean result is false when ctx holds no *Server.
func ServerFromCtx(ctx context.Context) (*Server, bool) {
	if srv, ok := ctx.Value(serverContextKey).(*Server); ok {
		return srv, true
	}

	return nil, false
}

// Server serves network requests using the Gemini protocol.
type Server struct {
	// Handler is called each time a new network request is received. If it
	// is nil, requests that pass validation are answered with a not-found
	// failure. Unlike Go's net/http, panics in handlers are not recovered
	// automatically.
	Handler Handler

	// LogHandler is an optional function that allows a custom logger to be
	// hooked into the server. Error messages are passed in with `isError` set
	// to true; informational messages (such as the start and end of each
	// request) are passed in with `isError` set to false.
	LogHandler func(message string, isError bool)

	// BaseContext is an optional function that takes a listener and returns a
	// context. The context returned by BaseContext will be used to create all
	// other contexts in the request lifecycle. It must not return nil; if it
	// does, Serve returns an error.
	BaseContext func(net.Listener) context.Context

	// ConnContext is an optional function that takes a context and a net.Conn
	// and returns a context. Like BaseContext, the context returned by
	// ConnContext will be used to create all contexts in the request lifecycle
	// after the connection has been created. It must not return nil; if it
	// does, Serve stops and returns an error.
	ConnContext func(ctx context.Context, c net.Conn) context.Context

	// TLSConfig is the TLS config used by ServeTLS and ListenAndServeTLS. It
	// is cloned before use; when nil, ServeTLS falls back to a default
	// configuration.
	TLSConfig *tls.Config

	// Internal state: listeners tracks the listeners currently being served
	// so Shutdown can close them and wait for their Serve loops,
	// shutdownHooks holds the functions passed to RegisterOnShutdown, and
	// inShutdown is set once Shutdown has been called.
	listeners     *listeners
	shutdownHooks *hooks
	inShutdown    atomicBool
}

// hooks is a list of callbacks, guarded by mu, that are run together on
// shutdown.
type hooks struct {
	fns []func()
	mu  sync.RWMutex
}

// register appends fn to the list of hooks.
func (h *hooks) register(fn func()) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.fns = append(h.fns, fn)
}

// call runs every registered hook in its own goroutine and blocks until all of
// them have returned.
func (h *hooks) call() {
	// Snapshot under the read lock: ranging over h.fns unguarded races with
	// register, while holding the lock for the duration of the hooks would
	// deadlock any hook that itself calls register.
	h.mu.RLock()
	fns := make([]func(), len(h.fns))
	copy(fns, h.fns)
	h.mu.RUnlock()

	var wg sync.WaitGroup
	wg.Add(len(fns))

	for _, f := range fns {
		go func(f func()) {
			defer wg.Done()
			f()
		}(f)
	}

	wg.Wait()
}

// atomicBool is a boolean flag that is safe for concurrent use; its zero value
// is false. It is stored as an int32 so the sync/atomic functions can drive it.
type atomicBool int32

// isSet reports whether the flag currently holds a non-zero value.
func (b *atomicBool) isSet() bool {
	raw := (*int32)(b)

	return atomic.LoadInt32(raw) != 0
}

// set marks the flag as true. This type has no method to clear it again.
func (b *atomicBool) set() {
	raw := (*int32)(b)
	atomic.StoreInt32(raw, 1)
}

// closeAndLogOnError calls closeFn and, when it fails and a LogHandler is
// configured, reports the failure through LogHandler as an error message.
func (s *Server) closeAndLogOnError(closeFn func() error) {
	err := closeFn()
	if err == nil || s.LogHandler == nil {
		return
	}

	s.LogHandler(err.Error(), true)
}

// ListenAndServeTLS listens on the TCP network address addr (":1965", the
// default Gemini port, when addr is empty) and then calls ServeTLS with the
// new listener, certFile and keyFile. It always returns a non-nil error.
func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error {
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)
	}

	if addr == "" {
		addr = ":1965"
	}

	l, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrStartingServer, err)
	}

	// ServeTLS can fail before Serve takes ownership of l (rejected TLS
	// config, unreadable key pair, concurrent Shutdown), so this function,
	// which created l, closes it too. When Serve has already closed it, the
	// error from the redundant Close is irrelevant.
	defer func() { _ = l.Close() }()

	return s.ServeTLS(l, certFile, keyFile)
}

// ServeTLS wraps l in a TLS listener and serves it with Serve.
//
// The TLS configuration is a clone of s.TLSConfig. When s.TLSConfig is nil, a
// default is used that requires TLS 1.3 and requests (but does not require)
// client certificates. A zero MinVersion is raised to TLS 1.2. An explicit
// MinVersion below TLS 1.2 is rejected with an error wrapping
// ErrStartingServer.
//
// certFile and keyFile are only consulted when the configuration supplies
// neither Certificates nor GetCertificate. In that case the key pair is loaded
// from them, and failing to load it is an error. ServeTLS does not close l
// when it fails before calling Serve.
func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)
	}

	var tlsConfig *tls.Config
	if s.TLSConfig == nil {
		tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS13,
			ClientAuth: tls.RequestClientCert,
		}
	} else {
		tlsConfig = s.TLSConfig.Clone()
	}

	// A zero MinVersion means "use crypto/tls's default", which depending on
	// the Go release may permit versions below TLS 1.2. Pin it to the lowest
	// version accepted here rather than rejecting an otherwise valid config.
	if tlsConfig.MinVersion == 0 {
		tlsConfig.MinVersion = tls.VersionTLS12
	}

	if tlsConfig.MinVersion < tls.VersionTLS12 {
		return fmt.Errorf("%w: unsupported TLS version 0x%04x", ErrStartingServer, tlsConfig.MinVersion)
	}

	// len rather than a nil check, so an empty non-nil slice still triggers
	// loading the key pair instead of failing every handshake later.
	if len(tlsConfig.Certificates) == 0 && tlsConfig.GetCertificate == nil {
		cert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return fmt.Errorf("%w: %s", ErrStartingServer, err)
		}

		tlsConfig.Certificates = []tls.Certificate{cert}
	}

	return s.Serve(tls.NewListener(l, tlsConfig))
}

// Serve accepts connections on l and handles each one in its own goroutine.
// The listener should support TLS (see ServeTLS for wrapping a plain
// listener).
//
// Serve blocks until one of the following happens, and it always returns a
// non-nil error:
//   - the server is shut down: the error wraps ErrServerShutdown;
//   - BaseContext returns nil: the error wraps ErrStartingServer;
//   - ConnContext returns nil or Accept fails with a non-temporary error: the
//     error wraps ErrServing.
func (s *Server) Serve(l net.Listener) error {
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)
	}

	ln := &listener{inner: l}
	defer s.closeAndLogOnError(ln.Close)

	if s.listeners == nil {
		s.listeners = &listeners{}
	}

	s.listeners.add(ln)
	defer s.listeners.remove(ln)

	// Shutdown may have started after the check above but before ln was
	// registered. In that case it never closed ln and would wait for this
	// loop indefinitely. Re-checking once ln is tracked closes that window,
	// assuming listeners serializes add against close — NOTE(review): confirm.
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)
	}

	baseCtx := context.Background()
	if s.BaseContext != nil {
		baseCtx = s.BaseContext(l)
		if baseCtx == nil {
			return fmt.Errorf("%w: BaseContext is nil", ErrStartingServer)
		}
	}

	var tempDelay time.Duration // how long to sleep on accept failure

	ctx := context.WithValue(baseCtx, serverContextKey, s)

	for {
		rw, err := ln.Accept()
		if err != nil {
			if s.inShutdown.isSet() {
				return fmt.Errorf("%w", ErrServerShutdown)
			}

			var ne net.Error
			if errors.As(err, &ne) && ne.Temporary() {
				// Back off exponentially, from 5ms doubling up to a 1s cap,
				// while Accept keeps reporting temporary errors.
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}

				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}

				time.Sleep(tempDelay)

				continue
			}

			return fmt.Errorf("%w: %s", ErrServing, err)
		}

		connCtx := ctx

		if s.ConnContext != nil {
			connCtx = s.ConnContext(connCtx, rw)
			if connCtx == nil {
				// Serve is about to return, so nothing else will ever take
				// ownership of rw; close it rather than leak it.
				s.closeAndLogOnError(rw.Close)

				return fmt.Errorf("%w: ConnContext is nil", ErrServing)
			}
		}

		tempDelay = 0

		go s.handle(connCtx, rw)
	}
}

// handle serves a single connection. It proceeds in these steps:
//   - complete the TLS handshake;
//   - read one request;
//   - refuse requests whose host, scheme or port do not match this
//     connection;
//   - pass the rest to s.Handler.
//
// The connection is closed when handle returns.
func (s *Server) handle(ctx context.Context, c *tls.Conn) {
	if s.LogHandler != nil {
		s.LogHandler(fmt.Sprintf("started processing request from %s", c.RemoteAddr()), false)
		defer s.LogHandler(fmt.Sprintf("finished processing request from %s", c.RemoteAddr()), false)
	}

	// Registered before the handshake so the connection is released on every
	// return path, including a failed handshake.
	defer s.closeAndLogOnError(c.Close)

	if err := c.HandshakeContext(ctx); err != nil {
		// Without a completed handshake there is no usable session. Reading a
		// request or writing a failure response would only fail again and
		// produce a second, derived log entry.
		if s.LogHandler != nil {
			s.LogHandler(err.Error(), true)
		}

		return
	}

	w := &response{conn: c}

	// Fallback for paths that never write a response, e.g. a nil Handler.
	// NOTE(review): relies on response.Failure doing nothing once a response
	// has already been written — confirm.
	defer w.Failure(ctx, StatusNotFound, StatusNotFound.Description())

	r, err := ReadRequest(bufio.NewReader(c))
	if err != nil {
		if s.LogHandler != nil {
			s.LogHandler(err.Error(), true)
		}

		w.Failure(ctx, StatusBadRequest, err.Error())

		return
	}

	connState := c.ConnectionState()

	// Only serve requests addressed to the host the client named during the
	// handshake (SNI); anything else would be proxying.
	if !strings.EqualFold(r.URI.Host, connState.ServerName) {
		w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description())

		return
	}

	if r.URI.url.Scheme != "gemini" {
		w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description())

		return
	}

	// An explicit port in the request must match the port this connection
	// was accepted on.
	if r.URI.Port != "" {
		_, port, err := net.SplitHostPort(c.LocalAddr().String())
		if err != nil {
			if s.LogHandler != nil {
				s.LogHandler(err.Error(), true)
			}

			w.Failure(ctx, StatusServerFailure, err.Error())

			return
		}

		if r.URI.Port != port {
			w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description())

			return
		}
	}

	r.RemoteAddr = c.RemoteAddr().String()

	if len(connState.PeerCertificates) > 0 {
		r.Subject = &connState.PeerCertificates[0].Subject
	}

	if s.Handler != nil {
		s.Handler.ServeGemini(ctx, w, r)
	}
}

// shuttingDown reports whether Shutdown has been called on s.
func (s *Server) shuttingDown() bool {
	flag := &s.inShutdown

	return flag.isSet()
}

// RegisterOnShutdown adds f to the functions run when Shutdown is called. It
// may be called any number of times. Shutdown starts every registered function
// in its own goroutine, so they run concurrently rather than in registration
// order, and waits for all of them to return before closing the server's
// listeners.
func (s *Server) RegisterOnShutdown(f func()) {
	registry := s.shutdownHooks
	if registry == nil {
		registry = &hooks{}
		s.shutdownHooks = registry
	}

	registry.register(f)
}

// shutdownPollIntervalMax caps the delay between Shutdown's checks for Serve
// loops that have not yet exited.
const shutdownPollIntervalMax = time.Second / 2

// Shutdown gracefully stops the server. It works in four steps:
//  1. It marks the server as shutting down, so new calls to Serve, ServeTLS
//     and ListenAndServeTLS fail with ErrServerShutdown.
//  2. It runs the functions registered with RegisterOnShutdown and waits for
//     them.
//  3. It closes the tracked listeners.
//  4. It polls, with jittered exponential backoff, until every Serve call has
//     returned.
//
// Only the Serve loops are waited for; connections already being handled are
// not tracked here. If ctx is done first, Shutdown returns an error wrapping
// ErrServerShutdown. Otherwise it returns the error, if any, from closing the
// listeners.
func (s *Server) Shutdown(ctx context.Context) error {
	s.inShutdown.set()

	if registered := s.shutdownHooks; registered != nil {
		registered.call()
	}

	if s.listeners == nil {
		return nil
	}

	closeErr := s.listeners.close()

	backoff := pollInterval(time.Millisecond)

	wait := time.NewTimer(backoff.next())
	defer wait.Stop()

	for s.listeners.len() != 0 {
		select {
		case <-ctx.Done():
			return fmt.Errorf("%w: context done: %s", ErrServerShutdown, ctx.Err())
		case <-wait.C:
			wait.Reset(backoff.next())
		}
	}

	return closeErr
}

// pollInterval is the delay Shutdown waits between checks for Serve loops that
// are still running. Each call to next grows it exponentially.
type pollInterval time.Duration

// next returns the current interval plus up to jitterPercent percent of random
// jitter. It then doubles the stored interval for the following call, capping
// it at shutdownPollIntervalMax.
func (p *pollInterval) next() time.Duration {
	const jitterPercent = 10

	last := time.Duration(*p)

	interval := last
	// rand.Int63n panics on a non-positive bound, so intervals too short to
	// yield one (under 10ns) get no jitter. Int63n also avoids truncating the
	// bound through int on 32-bit platforms.
	if bound := int64(last * jitterPercent / 100); bound > 0 {
		interval += time.Duration(rand.Int63n(bound)) //nolint:gosec // jitter does not need to be secure.
	}

	// Double and clamp for next time.
	last *= 2
	if last > shutdownPollIntervalMax {
		last = shutdownPollIntervalMax
	}

	*p = pollInterval(last)

	return interval
}