💾 Archived View for source.community › ckaznocha › gemini › raw › main › server_test.go captured on 2024-05-26 at 15:03:57.

View Raw

More Information

⬅️ Previous capture (2021-12-17)

-=-=-=-=-=-=-

package gemini_test

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"source.community/ckaznocha/gemini"
)

// failOnErrIfServerRunning tells a client goroutine whether it should stop.
//
// It returns true without reporting anything when a receive from done
// succeeds immediately (the server has returned). Otherwise it returns true
// after recording msg as a test error when err is non-nil, and false when err
// is nil. A receive consumes a value that was sent on done, so callers that
// check several times should close done rather than send on it.
func failOnErrIfServerRunning(t *testing.T, done <-chan struct{}, err error, msg string) bool {
	t.Helper()

	stopped := false
	select {
	case <-done:
		stopped = true
	default:
	}

	if stopped {
		return true
	}

	if err == nil {
		return false
	}

	t.Error(msg)

	return true
}

// TestServer_ListenAndServeTLS starts a server with ListenAndServeTLS and
// drives it from a client goroutine. The client sends one request, checks the
// first reply line and then shuts the server down. ListenAndServeTLS is
// expected to return gemini.ErrServerShutdown unless wantErr is set.
func TestServer_ListenAndServeTLS(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	type args struct {
		addr     string
		certFile string
		keyFile  string
	}

	tests := []struct {
		name          string
		fields        fields
		args          args
		wantErr       bool
		shutdownFirst bool
	}{
		{
			name: "Starts a server on the given address",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			args: args{
				addr:     "localhost:8049",
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
		},
		{
			name:          "Exits if the server has already been shutdown",
			shutdownFirst: true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			if tt.shutdownFirst {
				s.Shutdown(context.Background())
			}

			clientDone := make(chan struct{}, 1)
			// serverDone is closed rather than sent on. Every
			// failOnErrIfServerRunning call receives from it, and a single
			// sent value would be consumed by the first call only.
			serverDone := make(chan struct{})

			dial := func() (*tls.Conn, error) {
				return tls.DialWithDialer(
					&net.Dialer{Timeout: 1 * time.Second},
					"tcp",
					tt.args.addr,
					&tls.Config{
						InsecureSkipVerify: true,
					},
				)
			}

			go func() {
				defer func() {
					if err := s.Shutdown(context.Background()); err != nil {
						t.Errorf("Server.ListenAndServeTLS(%v, %v) Unable to shutdown: error = %v", tt.args.certFile, tt.args.keyFile, err)
					}

					clientDone <- struct{}{}
				}()

				// ListenAndServeTLS opens its listener on the test goroutine.
				// That may not have happened yet when this goroutine first
				// dials, so retry for a bounded time until the dial succeeds
				// or the server has returned.
				conn, err := dial()
			retry:
				for attempt := 0; err != nil && attempt < 50; attempt++ {
					select {
					case <-serverDone:
						break retry
					case <-time.After(20 * time.Millisecond):
					}

					conn, err = dial()
				}

				if failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ListenAndServeTLS(%v, %v) Unable to dial: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}

				if _, err = fmt.Fprint(conn, "gemini://localhost.com/\r\n"); failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ListenAndServeTLS(%v, %v) Unable to write request: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}

				b := bufio.NewReader(conn)
				got, err := b.ReadBytes('\n')
				if failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ListenAndServeTLS(%v, %v) Unable to read response: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}

				if string(got) != "51 The requested resource could not be found but may be available in the future.\r\n" {
					t.Errorf("Server.ListenAndServeTLS(%v, %v) Incorrect response: got = %v", tt.args.certFile, tt.args.keyFile, string(got))
					// Release the connection; the test has already failed.
					_ = conn.Close()

					return
				}

				if err = conn.Close(); failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ListenAndServeTLS(%v, %v) Unable to close connection: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}
			}()

			if err := s.ListenAndServeTLS(
				tt.args.addr,
				tt.args.certFile,
				tt.args.keyFile,
			); !errors.Is(err, gemini.ErrServerShutdown) != tt.wantErr {
				t.Errorf("Server.ListenAndServeTLS(%v, %v) error = %v, wantErr %v", tt.args.certFile, tt.args.keyFile, err, tt.wantErr)
			}

			close(serverDone)
			<-clientDone
		})
	}
}

// TestServer_ServeTLS runs Server.ServeTLS on a local TCP listener and drives
// it from a client goroutine. The client sends one request, checks the first
// reply line and then shuts the server down. ServeTLS is expected to return
// gemini.ErrServerShutdown unless wantErr is set.
func TestServer_ServeTLS(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	type args struct {
		certFile string
		keyFile  string
	}

	tests := []struct {
		name          string
		fields        fields
		args          args
		wantErr       bool
		shutdownFirst bool
	}{
		{
			name: "errors when the cert and key files are empty",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			args: args{
				certFile: "",
				keyFile:  "",
			},
			wantErr: true,
		},
		{
			name: "serves with a nil Handler",
			args: args{
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
		},
		{
			name: "errors when TLSConfig sets MinVersion to TLS 1.1",
			fields: fields{
				TLSConfig: &tls.Config{MinVersion: tls.VersionTLS11},
			},
			args: args{
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
			wantErr: true,
		},
		{
			name:   "exits if the server has already been shutdown",
			fields: fields{},
			args: args{
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
			shutdownFirst: true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			l, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("unable to open listener: %s", err)
			}

			// Release the socket even if ServeTLS returns without closing it.
			// Closing an already-closed listener only returns an error, which
			// is ignored here.
			t.Cleanup(func() { _ = l.Close() })

			if tt.shutdownFirst {
				s.Shutdown(context.Background())
			}

			clientDone := make(chan struct{}, 1)
			// serverDone is closed rather than sent on, so that every
			// failOnErrIfServerRunning call observes it.
			serverDone := make(chan struct{})

			go func() {
				defer func() {
					if err := s.Shutdown(context.Background()); err != nil {
						t.Errorf("Server.ServeTLS(%v, %v) Unable to shutdown: error = %v", tt.args.certFile, tt.args.keyFile, err)
					}

					clientDone <- struct{}{}
				}()

				conn, err := tls.DialWithDialer(
					&net.Dialer{Timeout: 1 * time.Second},
					"tcp",
					l.Addr().String(),
					&tls.Config{
						InsecureSkipVerify: true,
					},
				)
				if failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ServeTLS(%v, %v) Unable to dial: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}

				if _, err = fmt.Fprint(conn, "gemini://localhost.com/\r\n"); failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ServeTLS(%v, %v) Unable to write request: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}

				b := bufio.NewReader(conn)
				got, err := b.ReadBytes('\n')
				if failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ServeTLS(%v, %v) Unable to read response: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}

				if string(got) != "51 The requested resource could not be found but may be available in the future.\r\n" {
					t.Errorf("Server.ServeTLS(%v, %v) Incorrect response: got = %v", tt.args.certFile, tt.args.keyFile, string(got))
					// Release the connection; the test has already failed.
					_ = conn.Close()

					return
				}

				if err = conn.Close(); failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.ServeTLS(%v, %v) Unable to close connection: error = %v", tt.args.certFile, tt.args.keyFile, err),
				) {
					return
				}
			}()

			if err := s.ServeTLS(l, tt.args.certFile, tt.args.keyFile); !errors.Is(err, gemini.ErrServerShutdown) != tt.wantErr {
				t.Errorf("Server.ServeTLS(%v, %v, %v) error = %v, wantErr %v", l, tt.args.certFile, tt.args.keyFile, err, tt.wantErr)
			}

			close(serverDone)
			<-clientDone
		})
	}
}

// TestServer_Serve runs Server.Serve on a local listener, which is wrapped in
// TLS unless wantNonTLS is set. A client goroutine sends one request (reqBody,
// or a default request line), compares the first reply line with
// wantResponse and then shuts the server down. Serve is expected to return
// gemini.ErrServerShutdown unless wantErr is set.
func TestServer_Serve(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	tests := []struct {
		fields        fields
		name          string
		reqBody       string
		wantResponse  string
		wantErr       bool
		shutdownFirst bool
		wantNonTLS    bool
	}{
		{
			name: "works",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			wantResponse: "51 The requested resource could not be found but may be available in the future.\r\n",
		},
		{
			name: "errors when its not a TLS connection",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			wantErr:    true,
			wantNonTLS: true,
		},
		{
			name: "errors when the server has already been shutdown",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			shutdownFirst: true,
		},
		{
			name: "errors base context is nil",
			fields: fields{
				Handler:     gemini.NewServeMux(),
				BaseContext: func(l net.Listener) context.Context { return nil },
			},
			wantErr: true,
		},
		{
			name: "errors conn context is nil",
			fields: fields{
				Handler:     gemini.NewServeMux(),
				ConnContext: func(ctx context.Context, c net.Conn) context.Context { return nil },
			},
			wantErr: true,
		},
		{
			name: "handles a too large request",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			wantResponse: "59 the request length exceeded the max size of 1024 bytes\r\n",
			reqBody:      "gemini://localhost/" + strings.Repeat("n", 1028-len("gemini://localhost/\r\n")) + "\r\n",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			l, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("unable to open listener: %s", err)
			}

			// Release the socket even if Serve returns without closing it.
			// Closing an already-closed listener only returns an error, which
			// is ignored here.
			t.Cleanup(func() { _ = l.Close() })

			if !tt.wantNonTLS {
				var cert tls.Certificate
				cert, err = tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem")
				if err != nil {
					t.Fatalf("unable to load key pair: %s", err)
				}

				config := &tls.Config{
					Certificates: []tls.Certificate{cert},
				}
				l = tls.NewListener(l, config)
			}

			if tt.shutdownFirst {
				s.Shutdown(context.Background())
			}

			clientDone := make(chan struct{}, 1)
			// serverDone is closed rather than sent on, so that every
			// failOnErrIfServerRunning call observes it.
			serverDone := make(chan struct{})

			// The client goroutine declares its own err. Assigning to the
			// outer err while this goroutine calls Serve would be a data race.
			go func() {
				defer func() {
					if err := s.Shutdown(context.Background()); err != nil {
						t.Errorf("Server.Serve(%v) Unable to shutdown: error = %v", l, err)
					}

					clientDone <- struct{}{}
				}()

				conn, err := tls.DialWithDialer(
					&net.Dialer{Timeout: 1 * time.Second},
					"tcp",
					l.Addr().String(),
					&tls.Config{
						InsecureSkipVerify: true,
					},
				)
				if failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.Serve(%v) Unable to dial: error = %v", l, err),
				) {
					return
				}

				body := "gemini://localhost.com/\r\n"
				if tt.reqBody != "" {
					body = tt.reqBody
				}

				if _, err = fmt.Fprint(conn, body); failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.Serve(%v) Unable to write request: error = %v", l, err),
				) {
					return
				}

				b := bufio.NewReader(conn)
				got, err := b.ReadBytes('\n')
				if failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.Serve(%v) Unable to read response: error = %v", l, err),
				) {
					return
				}

				if string(got) != tt.wantResponse {
					t.Errorf("Server.Serve(%v) Incorrect response: got = %v, want %s", l, string(got), tt.wantResponse)
					// Release the connection; the test has already failed.
					_ = conn.Close()

					return
				}

				if err = conn.Close(); failOnErrIfServerRunning(
					t,
					serverDone,
					err,
					fmt.Sprintf("Server.Serve(%v) Unable to close connection: error = %v", l, err),
				) {
					return
				}
			}()

			if err := s.Serve(l); !errors.Is(err, gemini.ErrServerShutdown) != tt.wantErr {
				t.Errorf("Server.Serve(%v) error = %v, wantErr %v", l, err, tt.wantErr)
			}

			close(serverDone)
			<-clientDone
		})
	}
}

// TestServer_RegisterOnShutdown registers tt.args.i hooks on a server and
// serves on a TLS listener. A separate goroutine shuts the server down shortly
// afterwards, and the test checks that every registered hook runs.
func TestServer_RegisterOnShutdown(t *testing.T) {
	t.Parallel()

	type args struct {
		i int
	}

	tests := []struct {
		name string
		args args
	}{
		{
			name: "no hooks",
			args: args{i: 0},
		},
		{
			name: "one hook",
			args: args{i: 1},
		},
		{
			name: "two hooks",
			args: args{i: 2},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			s := &gemini.Server{}
			var wg sync.WaitGroup
			wg.Add(tt.args.i)

			for i := 0; i < tt.args.i; i++ {
				s.RegisterOnShutdown(func() { wg.Done() })
			}

			l, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("unable to open listener: %s", err)
			}

			// Release the socket even if Serve returns without closing it.
			// Closing an already-closed listener only returns an error, which
			// is ignored here.
			t.Cleanup(func() { _ = l.Close() })

			cert, err := tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem")
			if err != nil {
				t.Fatalf("unable to load key pair: %s", err)
			}

			config := &tls.Config{
				Certificates: []tls.Certificate{cert},
			}

			l = tls.NewListener(l, config)

			// Shut down from another goroutine after a short delay. The result
			// is handed back so it can be reported, and so the goroutine is
			// joined before the test returns.
			shutdownErr := make(chan error, 1)
			go func() {
				time.Sleep(5 * time.Millisecond)
				shutdownErr <- s.Shutdown(context.Background())
			}()

			if err := s.Serve(l); err != nil && !errors.Is(err, gemini.ErrServerShutdown) {
				t.Errorf("Server.Serve(%v) error = %v", l, err)
			}

			if err := <-shutdownErr; err != nil {
				t.Errorf("Server.Shutdown() error = %v", err)
			}

			// Bound the wait so that a hook which never runs fails the test
			// instead of hanging it.
			hooksDone := make(chan struct{})
			go func() {
				wg.Wait()
				close(hooksDone)
			}()

			select {
			case <-hooksDone:
			case <-time.After(time.Second):
				t.Fatalf("not all of the %d registered shutdown hooks ran within 1s", tt.args.i)
			}
		})
	}
}

// TestServer_Shutdown is a table-driven scaffold that calls Server.Shutdown
// directly on a freshly configured server and compares the presence of an
// error with wantErr. The table is currently empty, so no subtests run.
func TestServer_Shutdown(t *testing.T) {
	t.Parallel()

	type serverConfig struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	cases := []struct {
		name    string
		ctx     context.Context
		config  serverConfig
		wantErr bool
	}{
		// TODO: Add test cases.
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			srv := &gemini.Server{
				Handler:     tc.config.Handler,
				LogHandler:  tc.config.LogHandler,
				BaseContext: tc.config.BaseContext,
				ConnContext: tc.config.ConnContext,
				TLSConfig:   tc.config.TLSConfig,
			}

			err := srv.Shutdown(tc.ctx)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("Server.Shutdown(%v) error = %v, wantErr %v", tc.ctx, err, tc.wantErr)
			}
		})
	}
}