package gemini import ( "bufio" "context" "crypto/tls" "errors" "fmt" "math/rand" "net" "strings" "sync" "sync/atomic" "time" ) type serverContextKeyType struct{} //nolint:gochecknoglobals // context keys need to be global. var serverContextKey = &serverContextKeyType{} // ServerFromCtx extracts a server from a context if present. If a server is not // present on the context the returned bool will be false. func ServerFromCtx(ctx context.Context) (*Server, bool) { srv, ok := ctx.Value(serverContextKey).(*Server) return srv, ok } // Server serves network requests using the Gemini protocol. type Server struct { // Handler is a handler that is called each time a new network requests is // received. Unlike Go's net/http panics in handlers will not be recovered // automatically. Handler Handler // LogHandler is an optional function that allows a custom logger to be // hooked into the server. Erroneous logs will be passed in with `isError` // set to true. LogHandler func(message string, isError bool) // BaseContext is a optional function that takes a listener and returns a // context. The context returned by BaseContext will be used to create all // other contexts in the request lifecycle. BaseContext func(net.Listener) context.Context // ConnContext is an optional function that takes a context and a net.Conn // and returns a context. Like BaseContext, the context returned by // ConnContext will be used to create all contexts in the request lifecycle // after the connection has been created. ConnContext func(ctx context.Context, c net.Conn) context.Context // TLSConfig is the TLS config to use for the server. TLSConfig *tls.Config listeners *listeners shutdownHooks *hooks inShutdown atomicBool } type hooks struct { fns []func() mu sync.RWMutex } func (h *hooks) register(fn func()) { h.mu.Lock() defer h.mu.Unlock() h.fns = append(h.fns, fn) } func (h *hooks) call() { var wg sync.WaitGroup for _, f := range h.fns { wg.Add(1) go func(f func()) { defer wg.Done() f() }(f) } wg.Wait() } type atomicBool int32 func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } func (b *atomicBool) set() { atomic.StoreInt32((*int32)(b), 1) } func (s *Server) closeAndLogOnError(closeFn func() error) { if err := closeFn(); err != nil && s.LogHandler != nil { s.LogHandler(err.Error(), true) } } // ListenAndServeTLS creates a listener and starts the server. If certFile and // keyFile are non-empty strings the key pair will be loaded and used. func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error { if s.shuttingDown() { return fmt.Errorf("%w", ErrServerShutdown) } if addr == "" { addr = ":1965" } l, err := net.Listen("tcp", addr) if err != nil { return fmt.Errorf("%w: %s", ErrStartingServer, err) } return s.ServeTLS(l, certFile, keyFile) } // ServeTLS starts a server with the provided listener, wrapping it in a TLS // listener. If certFile and keyFile are non-empty strings the key pair will be // loaded and used. func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error { if s.shuttingDown() { return fmt.Errorf("%w", ErrServerShutdown) } var tlsConfig *tls.Config if s.TLSConfig == nil { tlsConfig = &tls.Config{ MinVersion: tls.VersionTLS13, ClientAuth: tls.RequestClientCert, } } else { tlsConfig = s.TLSConfig.Clone() } if tlsConfig.MinVersion < tls.VersionTLS12 { return fmt.Errorf("%w: unsupported TLS version %q", ErrStartingServer, tlsConfig.MinVersion) } if tlsConfig.Certificates == nil && tlsConfig.GetCertificate == nil { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return fmt.Errorf("%w: %s", ErrStartingServer, err) } tlsConfig.Certificates = []tls.Certificate{cert} } return s.Serve(tls.NewListener(l, tlsConfig)) } // Serve start a server using the provided listener. The listener should support // TLS. func (s *Server) Serve(l net.Listener) error { if s.shuttingDown() { return fmt.Errorf("%w", ErrServerShutdown) } ln := &listener{inner: l} defer s.closeAndLogOnError(ln.Close) if s.listeners == nil { s.listeners = &listeners{} } s.listeners.add(ln) defer s.listeners.remove(ln) baseCtx := context.Background() if s.BaseContext != nil { baseCtx = s.BaseContext(l) if baseCtx == nil { return fmt.Errorf("%w: BaseContext is nil", ErrStartingServer) } } var tempDelay time.Duration // how long to sleep on accept failure ctx := context.WithValue(baseCtx, serverContextKey, s) for { rw, err := ln.Accept() if err != nil { if s.inShutdown.isSet() { return fmt.Errorf("%w", ErrServerShutdown) } var ne net.Error if errors.As(err, &ne) && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } if max := 1 * time.Second; tempDelay > max { tempDelay = max } time.Sleep(tempDelay) continue } return fmt.Errorf("%w: %s", ErrServing, err) } connCtx := ctx if s.ConnContext != nil { connCtx = s.ConnContext(connCtx, rw) if connCtx == nil { return fmt.Errorf("%w: ConnContext is nil", ErrServing) } } tempDelay = 0 go s.handle(connCtx, rw) } } func (s *Server) handle(ctx context.Context, c *tls.Conn) { if s.LogHandler != nil { s.LogHandler(fmt.Sprintf("started processing request from %s", c.RemoteAddr()), false) defer s.LogHandler(fmt.Sprintf("finished processing request from %s", c.RemoteAddr()), false) } if err := c.HandshakeContext(ctx); err != nil && s.LogHandler != nil { s.LogHandler(err.Error(), true) } w := &response{conn: c} defer s.closeAndLogOnError(c.Close) defer w.Failure(ctx, StatusNotFound, StatusNotFound.Description()) r, err := ReadRequest(bufio.NewReader(c)) if err != nil { if s.LogHandler != nil { s.LogHandler(err.Error(), true) } w.Failure(ctx, StatusBadRequest, err.Error()) return } if !strings.EqualFold(r.URI.Host, c.ConnectionState().ServerName) { w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description()) return } if r.URI.url.Scheme != "gemini" { w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description()) return } if r.URI.Port != "" { _, port, err := net.SplitHostPort(c.LocalAddr().String()) if err != nil { if s.LogHandler != nil { s.LogHandler(err.Error(), true) } w.Failure(ctx, StatusServerFailure, err.Error()) return } if r.URI.Port != port { w.Failure(ctx, StatusProxyRequestRefused, StatusProxyRequestRefused.Description()) return } } r.RemoteAddr = c.RemoteAddr().String() if connState := c.ConnectionState(); len(connState.PeerCertificates) > 0 { r.Subject = &connState.PeerCertificates[0].Subject } if s.Handler != nil { s.Handler.ServeGemini(ctx, w, r) } } func (s *Server) shuttingDown() bool { return s.inShutdown.isSet() } // RegisterOnShutdown adds a function which will be called when the server shuts // down. RegisterOnShutdown can be called more than once to stack functions. func (s *Server) RegisterOnShutdown(f func()) { if s.shutdownHooks == nil { s.shutdownHooks = &hooks{} } s.shutdownHooks.register(f) } const shutdownPollIntervalMax = 500 * time.Millisecond // Shutdown shuts the server down gracefully. The shutdown will stop waiting for // requests to finish if the context cancels. func (s *Server) Shutdown(ctx context.Context) error { s.inShutdown.set() if s.shutdownHooks != nil { s.shutdownHooks.call() } if s.listeners == nil { return nil } err := s.listeners.close() interval := pollInterval(time.Millisecond) timer := time.NewTimer(interval.next()) defer timer.Stop() for { if s.listeners.len() == 0 { return err } select { case <-ctx.Done(): return fmt.Errorf("%w: context done: %s", ErrServerShutdown, ctx.Err()) case <-timer.C: timer.Reset(interval.next()) } } } type pollInterval time.Duration func (p *pollInterval) next() time.Duration { const jitterPercent = 10 last := time.Duration(*p) interval := last + time.Duration(rand.Intn(int(last/jitterPercent))) //nolint:gosec // this does not need to be sercure. // Double and clamp for next time. last *= 2 if last > shutdownPollIntervalMax { last = shutdownPollIntervalMax } *p = pollInterval(last) return interval }