package gemini

import (

// serverContextKeyType is the unexported type of the context key under which
// a *Server is stored. Using a dedicated type keeps the key from colliding
// with context keys defined by other packages.
type serverContextKeyType struct{}

// serverContextKey is the key Serve passes to context.WithValue and
// ServerFromCtx passes to ctx.Value.
//
//nolint:gochecknoglobals // context keys need to be global.
var serverContextKey = &serverContextKeyType{}

// ServerFromCtx extracts a server from a context if present. If a server is not
// present on the context the returned bool will be false.
//
// The lookup uses the package-private serverContextKey, which Serve attaches
// to the context it derives for accepted connections.
//
// NOTE(review): no closing brace for this function is present in this copy of
// the file; confirm nothing was lost.
func ServerFromCtx(ctx context.Context) (*Server, bool) {
	// The comma-ok assertion yields (nil, false) when the key is absent or the
	// stored value is not a *Server.
	srv, ok := ctx.Value(serverContextKey).(*Server)

	return srv, ok

// Server serves network requests using the Gemini protocol.
//
// NOTE(review): no closing brace for this struct is present in this copy of
// the file; confirm that no fields were lost.
type Server struct {
	// Handler is a handler that is called each time a new network request is
	// received. Unlike Go's net/http, panics in handlers will not be recovered
	// automatically. When Handler is nil, handle skips the ServeGemini call.
	Handler Handler

	// LogHandler is an optional function that allows a custom logger to be
	// hooked into the server. Erroneous logs will be passed in with `isError`
	// set to true. When nil, nothing is logged.
	LogHandler func(message string, isError bool)

	// BaseContext is an optional function that takes a listener and returns a
	// context. The context returned by BaseContext will be used to create all
	// other contexts in the request lifecycle. When unset, Serve starts from
	// context.Background(); a nil return value makes Serve fail with
	// ErrStartingServer.
	BaseContext func(net.Listener) context.Context

	// ConnContext is an optional function that takes a context and a net.Conn
	// and returns a context. Like BaseContext, the context returned by
	// ConnContext will be used to create all contexts in the request lifecycle
	// after the connection has been created. A nil return value makes Serve
	// fail with ErrServing.
	ConnContext func(ctx context.Context, c net.Conn) context.Context

	// TLSConfig is the TLS config to use for the server. ServeTLS works on a
	// clone of it; when nil, ServeTLS builds a config with a minimum version
	// of TLS 1.3 and ClientAuth set to tls.RequestClientCert.
	TLSConfig *tls.Config

	// listeners tracks the listeners being served. Serve allocates it lazily;
	// Shutdown closes it and polls its length.
	listeners     *listeners
	// shutdownHooks holds the functions registered through RegisterOnShutdown;
	// it is allocated lazily there.
	shutdownHooks *hooks
	// inShutdown is read by shuttingDown and by Serve's accept loop to tell a
	// shutdown apart from a genuine accept failure.
	inShutdown    atomicBool

// hooks is a concurrency-safe collection of callbacks that are all invoked
// together, such as the functions held in Server.shutdownHooks. The zero value
// is ready to use.
type hooks struct {
	// fns holds the registered callbacks in registration order.
	fns []func()
	// mu guards fns.
	mu sync.RWMutex
}

// register appends fn to the callbacks run by call. It is safe for concurrent
// use.
func (h *hooks) register(fn func()) {
	// The write lock must be taken before its release is deferred: unlocking a
	// mutex that is not held is a fatal runtime error.
	h.mu.Lock()
	defer h.mu.Unlock()

	h.fns = append(h.fns, fn)
}

// call runs every registered callback in its own goroutine and blocks until
// all of them have returned. With no callbacks registered it returns
// immediately.
func (h *hooks) call() {
	// Snapshot under the read lock and release it before running anything, so
	// a callback may itself call register without deadlocking.
	h.mu.RLock()
	fns := make([]func(), len(h.fns))
	copy(fns, h.fns)
	h.mu.RUnlock()

	var wg sync.WaitGroup

	// Account for every goroutine up front; each one calls Done exactly once.
	wg.Add(len(fns))

	for _, f := range fns {
		go func(f func()) {
			defer wg.Done()

			f()
		}(f)
	}

	wg.Wait()
}

// atomicBool is a boolean flag backed by an int32 that is only accessed
// through sync/atomic: zero means unset, any non-zero value means set.
type atomicBool int32

// isSet reports whether the flag has been raised.
func (b *atomicBool) isSet() bool {
	raw := (*int32)(b)

	return atomic.LoadInt32(raw) != 0
}

// set raises the flag by atomically storing 1.
func (b *atomicBool) set() {
	raw := (*int32)(b)

	atomic.StoreInt32(raw, 1)
}

// closeAndLogOnError invokes closeFn and, when it returns an error and a
// LogHandler is configured, reports the error text through the LogHandler with
// isError set to true. Without a LogHandler the error is discarded. Serve and
// handle use it to defer Close calls.
//
// NOTE(review): no closing braces for this function are present in this copy
// of the file; confirm nothing was lost.
func (s *Server) closeAndLogOnError(closeFn func() error) {
	if err := closeFn(); err != nil && s.LogHandler != nil {
		s.LogHandler(err.Error(), true)

// ListenAndServeTLS creates a TCP listener on addr and starts the server by
// handing the listener, certFile and keyFile to ServeTLS.
//
// An empty addr defaults to ":1965". The method returns ErrServerShutdown
// (wrapped) when the server is already shutting down, and ErrStartingServer
// (wrapped) followed by the listen error's text when the listener cannot be
// created. The listen error itself is formatted with %s, so it is not
// reachable through errors.Is / errors.As.
//
// NOTE(review): the closing braces of the if-blocks and of the function are
// not present in this copy of the file; confirm nothing was lost.
func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error {
	// Refuse to start once shutdown has been flagged.
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)

	// Fall back to port 1965 on all interfaces when no address is given.
	if addr == "" {
		addr = ":1965"

	l, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("%w: %s", ErrStartingServer, err)

	return s.ServeTLS(l, certFile, keyFile)

// ServeTLS starts a server with the provided listener, wrapping it in a TLS
// listener and passing the result to Serve.
//
// The TLS configuration is a clone of s.TLSConfig or, when that is nil, a
// default with MinVersion TLS 1.3 and ClientAuth tls.RequestClientCert. A
// configuration whose MinVersion is below TLS 1.2 is rejected with
// ErrStartingServer; this includes a caller-supplied config that leaves
// MinVersion at zero. The key pair in certFile/keyFile is loaded only when the
// effective configuration has neither Certificates nor GetCertificate set.
//
// NOTE(review): the closing braces of the nested blocks and of the function
// are not present in this copy of the file; confirm nothing was lost.
func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
	// Refuse to start once shutdown has been flagged.
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)

	var tlsConfig *tls.Config
	if s.TLSConfig == nil {
		tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS13,
			ClientAuth: tls.RequestClientCert,
	} else {
		// Clone so the caller's config is never mutated below.
		tlsConfig = s.TLSConfig.Clone()

	// NOTE(review): MinVersion is an integer, and fmt's %q verb renders an
	// integer as a single-quoted character literal rather than a number, so
	// this message is unlikely to be readable. Consider %#x or a version name.
	if tlsConfig.MinVersion < tls.VersionTLS12 {
		return fmt.Errorf("%w: unsupported TLS version %q", ErrStartingServer, tlsConfig.MinVersion)

	// Only read the key pair from disk when the config supplies no certificate
	// source of its own.
	if tlsConfig.Certificates == nil && tlsConfig.GetCertificate == nil {
		cert, err := tls.LoadX509KeyPair(certFile, keyFile)
		if err != nil {
			return fmt.Errorf("%w: %s", ErrStartingServer, err)

		tlsConfig.Certificates = []tls.Certificate{cert}

	return s.Serve(tls.NewListener(l, tlsConfig))

// Serve starts a server using the provided listener. The listener should
// support TLS.
//
// Serve wraps l in the package's listener type, derives a base context
// (BaseContext, or context.Background() when unset) carrying the server under
// serverContextKey, and then accepts connections in a loop, starting one
// goroutine running handle per connection. It returns ErrServerShutdown when
// an accept fails while the shutdown flag is set, ErrStartingServer when
// BaseContext returns nil, and ErrServing for other accept failures or when
// ConnContext returns nil. The wrapped listener is closed when Serve returns.
//
// NOTE(review): this copy of the function appears to be missing lines. No
// closing braces are present, nothing visible adds ln to s.listeners even
// though its removal is deferred, and the temporary-error branch computes
// tempDelay without any visible sleep or continue. Confirm against the
// original before relying on the behaviour described here.
func (s *Server) Serve(l net.Listener) error {
	if s.shuttingDown() {
		return fmt.Errorf("%w", ErrServerShutdown)

	ln := &listener{inner: l}
	defer s.closeAndLogOnError(ln.Close)

	// The listener registry is allocated on first use.
	if s.listeners == nil {
		s.listeners = &listeners{}

	defer s.listeners.remove(ln)

	baseCtx := context.Background()
	if s.BaseContext != nil {
		baseCtx = s.BaseContext(l)
		if baseCtx == nil {
			return fmt.Errorf("%w: BaseContext is nil", ErrStartingServer)

	var tempDelay time.Duration // how long to sleep on accept failure

	// Make the server retrievable through ServerFromCtx in every request.
	ctx := context.WithValue(baseCtx, serverContextKey, s)

	for {
		rw, err := ln.Accept()
		if err != nil {
			// An accept error during shutdown is expected, not a failure.
			if s.inShutdown.isSet() {
				return fmt.Errorf("%w", ErrServerShutdown)

			// Back-off for temporary errors: start at 5ms, double on each
			// consecutive failure, and cap at one second.
			var ne net.Error
			if errors.As(err, &ne) && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2

				if max := 1 * time.Second; tempDelay > max {
					tempDelay = max

			// NOTE(review): presumably the loop sleeps for tempDelay and
			// retries here; no such statements are visible.



			return fmt.Errorf("%w: %s", ErrServing, err)

		connCtx := ctx

		if s.ConnContext != nil {
			connCtx = s.ConnContext(connCtx, rw)
			if connCtx == nil {
				return fmt.Errorf("%w: ConnContext is nil", ErrServing)

		// A successful accept resets the back-off.
		tempDelay = 0

		go s.handle(connCtx, rw)

// handle processes a single accepted TLS connection: it performs the TLS
// handshake, reads one request, validates it against the connection, and
// passes it to s.Handler.
//
// A request is refused with StatusProxyRequestRefused when its host does not
// match the TLS server name (case-insensitively), when its scheme is not
// "gemini", or when it names a port different from the local port of the
// connection. A request that cannot be read is answered with
// StatusBadRequest. On success the request is annotated with the remote
// address and, when the client presented certificates, the subject of the
// first one.
//
// NOTE(review): this copy of the function appears to be missing lines. No
// closing braces are present and there is no visible return after any of the
// w.Failure calls; as written, execution would continue past a failed
// ReadRequest and use r. Confirm against the original.
func (s *Server) handle(ctx context.Context, c *tls.Conn) {
	if s.LogHandler != nil {
		s.LogHandler(fmt.Sprintf("started processing request from %s", c.RemoteAddr()), false)
		// The deferred call's message is formatted now, at defer time.
		defer s.LogHandler(fmt.Sprintf("finished processing request from %s", c.RemoteAddr()), false)

	// A handshake error is only logged here; no early exit is visible.
	if err := c.HandshakeContext(ctx); err != nil && s.LogHandler != nil {
		s.LogHandler(err.Error(), true)

	w := &response{conn: c}

	// Deferred calls run in reverse order: the fallback failure response
	// first, then the connection close.
	// NOTE(review): presumably response.Failure does nothing once a response
	// has been written; that type is not visible here.
	defer s.closeAndLogOnError(c.Close)
	defer w.Failure(ctx, StatusNotFound, StatusNotFound.Description())

	r, err := ReadRequest(bufio.NewReader(c))
	if err != nil {
		if s.LogHandler != nil {
			s.LogHandler(err.Error(), true)

		w.Failure(ctx, StatusBadRequest, err.Error())


	// The requested host must be the one named during the TLS handshake.
	if !strings.EqualFold(r.URI.Host, c.ConnectionState().ServerName) {
		w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description())


	// Only the gemini scheme is served.
	if r.URI.url.Scheme != "gemini" {
		w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description())


	// An explicit port in the request must match the port the connection
	// arrived on.
	if r.URI.Port != "" {
		_, port, err := net.SplitHostPort(c.LocalAddr().String())
		if err != nil {
			if s.LogHandler != nil {
				s.LogHandler(err.Error(), true)

			w.Failure(ctx, StatusServerFailure, err.Error())


		if r.URI.Port != port {
			w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description())


	r.RemoteAddr = c.RemoteAddr().String()

	// Expose the subject of the first peer certificate, when one was sent.
	if connState := c.ConnectionState(); len(connState.PeerCertificates) > 0 {
		r.Subject = &connState.PeerCertificates[0].Subject

	if s.Handler != nil {
		s.Handler.ServeGemini(ctx, w, r)

// shuttingDown reports whether the server's inShutdown flag is set. The
// Serve-family entry points call it first and refuse to start when it returns
// true.
//
// NOTE(review): no closing brace for this function is present in this copy of
// the file; confirm nothing was lost.
func (s *Server) shuttingDown() bool {
	return s.inShutdown.isSet()

// RegisterOnShutdown adds a function which will be called when the server shuts
// down. RegisterOnShutdown can be called more than once to stack functions.
//
// The hook collection is allocated on first use.
//
// NOTE(review): in this copy of the file f is never passed on to
// s.shutdownHooks, and no closing braces are present. Presumably a call to the
// collection's register method belongs after the nil check; confirm against
// the original. The lazy allocation is also not guarded by any lock visible
// here.
func (s *Server) RegisterOnShutdown(f func()) {
	if s.shutdownHooks == nil {
		s.shutdownHooks = &hooks{}

// shutdownPollIntervalMax is the upper bound that pollInterval.next applies to
// the base polling interval while it doubles between polls.
const shutdownPollIntervalMax = 500 * time.Millisecond

// Shutdown shuts the server down gracefully. The shutdown will stop waiting for
// requests to finish if the context cancels.
//
// When no listeners were ever created it returns nil. Otherwise it closes the
// listeners and polls, on a pollInterval schedule starting at one millisecond,
// until the listener registry is empty. It then returns the error from closing
// the listeners. If ctx is done first, it returns ErrServerShutdown (wrapped)
// together with the text of ctx.Err().
//
// NOTE(review): this copy of the function appears to be missing lines. No
// closing braces are present, nothing visible sets s.inShutdown, the body of
// the shutdownHooks branch is empty, and the timer is never reset after it
// fires (so a second wait on timer.C would block until ctx is done). Confirm
// against the original.
func (s *Server) Shutdown(ctx context.Context) error {

	// NOTE(review): presumably the registered shutdown hooks are run here.
	if s.shutdownHooks != nil {

	// Nothing is being served, so there is nothing to wait for.
	if s.listeners == nil {
		return nil

	// Remember the close error; it is reported once the registry is empty.
	err := s.listeners.close()

	interval := pollInterval(time.Millisecond)

	timer := time.NewTimer(interval.next())
	defer timer.Stop()

	for {
		if s.listeners.len() == 0 {
			return err
		select {
		case <-ctx.Done():
			return fmt.Errorf("%w: context done: %s", ErrServerShutdown, ctx.Err())
		case <-timer.C:

type pollInterval time.Duration

func (p *pollInterval) next() time.Duration {
	const jitterPercent = 10

	last := time.Duration(*p)
	interval := last + time.Duration(rand.Intn(int(last/jitterPercent))) //nolint:gosec // this does not need to be sercure.
	// Double and clamp for next time.
	last *= 2
	if last > shutdownPollIntervalMax {
		last = shutdownPollIntervalMax

	*p = pollInterval(last)

	return interval