package gemini_test

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"source.community/ckaznocha/gemini"
)

const (
	// serverTestRequest is the request line the test clients send unless a
	// test case supplies its own.
	serverTestRequest = "gemini://localhost.com/\r\n"

	// serverTestNotFound is the response line these tests expect for
	// serverTestRequest.
	serverTestNotFound = "51 The requested resource could not be found but may be available in the future.\r\n"

	// serverTestDialWindow bounds how long a test client keeps retrying a
	// failed dial while the server under test is still running.
	serverTestDialWindow = 2 * time.Second

	// serverTestIOTimeout bounds one request/response exchange so a
	// misbehaving server fails the test instead of hanging it.
	serverTestIOTimeout = 5 * time.Second
)

// failOnErrIfServerRunning reports whether the caller should stop talking to
// the server under test. Once done is readable (the server has returned and
// client-side errors are therefore expected) it returns true without
// reporting anything. Otherwise a non-nil err is reported with t.Error(msg)
// and true is returned; a nil err yields false.
func failOnErrIfServerRunning(t *testing.T, done <-chan struct{}, err error, msg string) bool {
	t.Helper()

	select {
	case <-done:
		return true
	default:
		if err != nil {
			t.Error(msg)

			return true
		}

		return false
	}
}

// dialTLSWhileServerRunning opens a TLS connection to addr with certificate
// verification disabled, so the testdata certificate is accepted as-is. A
// failed attempt is retried every 10ms until one succeeds, done is closed,
// or window has elapsed. This absorbs the race where the client dials before
// the server has started listening. When no connection can be made, the
// error from the last attempt is returned.
func dialTLSWhileServerRunning(addr string, done <-chan struct{}, window time.Duration) (*tls.Conn, error) {
	deadline := time.Now().Add(window)

	for {
		conn, err := tls.DialWithDialer(
			&net.Dialer{Timeout: 1 * time.Second},
			"tcp",
			addr,
			&tls.Config{
				InsecureSkipVerify: true,
			},
		)
		if err == nil {
			return conn, nil
		}

		if time.Now().After(deadline) {
			return nil, err
		}

		select {
		case <-done:
			return nil, err
		case <-time.After(10 * time.Millisecond):
		}
	}
}

// sendRequestReadResponse writes request to conn and reads back a single
// '\n'-terminated response line, all under serverTestIOTimeout. stop is true
// when the caller should abandon the exchange, either because a failure was
// reported or because the server has already returned. label prefixes every
// failure message.
func sendRequestReadResponse(
	t *testing.T,
	conn net.Conn,
	request, label string,
	serverDone <-chan struct{},
) (response string, stop bool) {
	t.Helper()

	if err := conn.SetDeadline(time.Now().Add(serverTestIOTimeout)); failOnErrIfServerRunning(
		t,
		serverDone,
		err,
		fmt.Sprintf("%s Unable to set deadline: error = %v", label, err),
	) {
		return "", true
	}

	if _, err := fmt.Fprint(conn, request); failOnErrIfServerRunning(
		t,
		serverDone,
		err,
		fmt.Sprintf("%s Unable to write request: error = %v", label, err),
	) {
		return "", true
	}

	line, err := bufio.NewReader(conn).ReadBytes('\n')
	if failOnErrIfServerRunning(
		t,
		serverDone,
		err,
		fmt.Sprintf("%s Unable to read response: error = %v", label, err),
	) {
		return "", true
	}

	return string(line), false
}

// runGeminiTestClient plays the client side of a server test:
//  1. It dials addr, retrying while the server is still starting.
//  2. It sends request and reads one response line, which must equal want.
//  3. It closes the connection.
//
// Failures are only reported while serverDone is still open. Once the
// server has returned, connection errors are expected and ignored. In every
// case it finally shuts s down, so the blocking server call in the test
// goroutine returns, and then closes clientDone. label prefixes every
// failure message, e.g. "Server.Serve(...)".
func runGeminiTestClient(
	t *testing.T,
	s *gemini.Server,
	addr, request, want, label string,
	serverDone <-chan struct{},
	clientDone chan<- struct{},
) {
	defer func() {
		if err := s.Shutdown(context.Background()); err != nil {
			t.Errorf("%s Unable to shutdown: error = %v", label, err)
		}

		close(clientDone)
	}()

	conn, err := dialTLSWhileServerRunning(addr, serverDone, serverTestDialWindow)
	if failOnErrIfServerRunning(
		t,
		serverDone,
		err,
		fmt.Sprintf("%s Unable to dial: error = %v", label, err),
	) {
		// The dial may have succeeded just as the server stopped; do not
		// leak that connection.
		if conn != nil {
			_ = conn.Close()
		}

		return
	}

	got, stop := sendRequestReadResponse(t, conn, request, label, serverDone)
	// Close unconditionally so failure paths do not leak the connection;
	// the close error only matters when the exchange itself succeeded.
	closeErr := conn.Close()

	if stop {
		return
	}

	if got != want {
		t.Errorf("%s Incorrect response: got = %q, want %q", label, got, want)

		return
	}

	failOnErrIfServerRunning(
		t,
		serverDone,
		closeErr,
		fmt.Sprintf("%s Unable to close connection: error = %v", label, closeErr),
	)
}

// TestServer_ListenAndServeTLS starts a server on a fixed address and checks
// that it answers a request, and that it exits when already shut down.
func TestServer_ListenAndServeTLS(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	type args struct {
		addr     string
		certFile string
		keyFile  string
	}

	tests := []struct {
		name          string
		fields        fields
		args          args
		wantErr       bool
		shutdownFirst bool
	}{
		{
			name: "Starts a server on the given address",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			args: args{
				addr:     "localhost:8049",
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
		},
		{
			name:          "Exits if the server has already been shutdown",
			shutdownFirst: true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			if tt.shutdownFirst {
				// Only the ListenAndServeTLS result matters here; the
				// client's deferred Shutdown is the call that is checked.
				_ = s.Shutdown(context.Background())
			}

			label := fmt.Sprintf("Server.ListenAndServeTLS(%v, %v)", tt.args.certFile, tt.args.keyFile)
			serverDone := make(chan struct{})
			clientDone := make(chan struct{})

			go runGeminiTestClient(
				t,
				s,
				tt.args.addr,
				serverTestRequest,
				serverTestNotFound,
				label,
				serverDone,
				clientDone,
			)

			// wantErr means "expect an error other than ErrServerShutdown".
			if err := s.ListenAndServeTLS(
				tt.args.addr,
				tt.args.certFile,
				tt.args.keyFile,
			); !errors.Is(err, gemini.ErrServerShutdown) != tt.wantErr {
				t.Errorf("Server.ListenAndServeTLS(%v, %v) error = %v, wantErr %v",
					tt.args.certFile, tt.args.keyFile, err, tt.wantErr)
			}

			// Closing (rather than sending on) serverDone lets every pending
			// and future check in the client observe that the server returned.
			close(serverDone)
			<-clientDone
		})
	}
}

// TestServer_ServeTLS serves TLS on a pre-opened listener using the given
// certificate files and checks both successful and failing configurations.
func TestServer_ServeTLS(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	type args struct {
		certFile string
		keyFile  string
	}

	tests := []struct {
		name          string
		fields        fields
		args          args
		wantErr       bool
		shutdownFirst bool
	}{
		{
			name: "errors when no certificate files are given",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			args: args{
				certFile: "",
				keyFile:  "",
			},
			wantErr: true,
		},
		{
			name: "serves when no handler is set",
			args: args{
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
		},
		{
			name: "errors with a TLS minimum version of 1.1",
			fields: fields{
				TLSConfig: &tls.Config{MinVersion: tls.VersionTLS11},
			},
			args: args{
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
			wantErr: true,
		},
		{
			name:   "exits if the server has already been shutdown",
			fields: fields{},
			args: args{
				certFile: "testdata/cert.pem",
				keyFile:  "testdata/key.pem",
			},
			shutdownFirst: true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			l, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("unable to open listener: %s", err)
			}
			// ServeTLS may return without closing l (e.g. on a configuration
			// error); the error from a second Close is irrelevant.
			defer func() { _ = l.Close() }()

			if tt.shutdownFirst {
				// Only the ServeTLS result matters here; the client's
				// deferred Shutdown is the call that is checked.
				_ = s.Shutdown(context.Background())
			}

			label := fmt.Sprintf("Server.ServeTLS(%v, %v)", tt.args.certFile, tt.args.keyFile)
			serverDone := make(chan struct{})
			clientDone := make(chan struct{})

			go runGeminiTestClient(
				t,
				s,
				l.Addr().String(),
				serverTestRequest,
				serverTestNotFound,
				label,
				serverDone,
				clientDone,
			)

			// wantErr means "expect an error other than ErrServerShutdown".
			if err := s.ServeTLS(l, tt.args.certFile, tt.args.keyFile); !errors.Is(err, gemini.ErrServerShutdown) != tt.wantErr {
				t.Errorf("Server.ServeTLS(%v, %v, %v) error = %v, wantErr %v",
					l, tt.args.certFile, tt.args.keyFile, err, tt.wantErr)
			}

			close(serverDone)
			<-clientDone
		})
	}
}

// TestServer_Serve serves on a listener the test prepares itself (TLS or
// plain TCP) and checks responses and error conditions.
func TestServer_Serve(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
	}

	tests := []struct {
		fields        fields
		name          string
		reqBody       string
		wantResponse  string
		wantErr       bool
		shutdownFirst bool
		wantNonTLS    bool
	}{
		{
			name: "works",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			wantResponse: serverTestNotFound,
		},
		{
			name: "errors when its not a TLS connection",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			wantErr:    true,
			wantNonTLS: true,
		},
		{
			name: "errors when the server has already been shutdown",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			shutdownFirst: true,
		},
		{
			name: "errors base context is nil",
			fields: fields{
				Handler:     gemini.NewServeMux(),
				BaseContext: func(l net.Listener) context.Context { return nil },
			},
			wantErr: true,
		},
		{
			name: "errors conn context is nil",
			fields: fields{
				Handler:     gemini.NewServeMux(),
				ConnContext: func(ctx context.Context, c net.Conn) context.Context { return nil },
			},
			wantErr: true,
		},
		{
			name: "handles a too large request",
			fields: fields{
				Handler: gemini.NewServeMux(),
			},
			wantResponse: "59 the request length exceeded the max size of 1024 bytes\r\n",
			reqBody:      "gemini://localhost/" + strings.Repeat("n", 1028-len("gemini://localhost/\r\n")) + "\r\n",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			l, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("unable to open listener: %s", err)
			}

			if !tt.wantNonTLS {
				cert, certErr := tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem")
				if certErr != nil {
					_ = l.Close()
					t.Fatalf("unable to load key pair: %s", certErr)
				}

				l = tls.NewListener(l, &tls.Config{
					Certificates: []tls.Certificate{cert},
				})
			}
			// Serve may return without closing l; the error from a second
			// Close is irrelevant.
			defer func() { _ = l.Close() }()

			if tt.shutdownFirst {
				// Only the Serve result matters here; the client's deferred
				// Shutdown is the call that is checked.
				_ = s.Shutdown(context.Background())
			}

			request := serverTestRequest
			if tt.reqBody != "" {
				request = tt.reqBody
			}

			label := fmt.Sprintf("Server.Serve(%v)", l)
			serverDone := make(chan struct{})
			clientDone := make(chan struct{})

			// The client works on its own locals, so nothing here is shared
			// with this goroutine while Serve runs.
			go runGeminiTestClient(
				t,
				s,
				l.Addr().String(),
				request,
				tt.wantResponse,
				label,
				serverDone,
				clientDone,
			)

			// wantErr means "expect an error other than ErrServerShutdown".
			if err := s.Serve(l); !errors.Is(err, gemini.ErrServerShutdown) != tt.wantErr {
				t.Errorf("Server.Serve(%v) error = %v, wantErr %v", l, err, tt.wantErr)
			}

			close(serverDone)
			<-clientDone
		})
	}
}

// TestServer_RegisterOnShutdown registers a number of callbacks and checks
// that every one of them runs once the server is shut down.
func TestServer_RegisterOnShutdown(t *testing.T) {
	t.Parallel()

	type args struct {
		i int
	}

	tests := []struct {
		name string
		args args
	}{
		{
			name: "no functions registered",
			args: args{i: 0},
		},
		{
			name: "one function registered",
			args: args{i: 1},
		},
		{
			name: "two functions registered",
			args: args{i: 2},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			s := &gemini.Server{}

			var wg sync.WaitGroup

			wg.Add(tt.args.i)

			for i := 0; i < tt.args.i; i++ {
				s.RegisterOnShutdown(func() { wg.Done() })
			}

			l, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("unable to open listener: %s", err)
			}

			cert, err := tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem")
			if err != nil {
				_ = l.Close()
				t.Fatalf("unable to load key pair: %s", err)
			}

			l = tls.NewListener(l, &tls.Config{
				Certificates: []tls.Certificate{cert},
			})
			// Serve may return without closing l; the error from a second
			// Close is irrelevant.
			defer func() { _ = l.Close() }()

			go func() {
				// Give Serve a moment to start before shutting it down. The
				// Shutdown error is deliberately ignored: this test only
				// checks that the registered callbacks run.
				time.Sleep(5 * time.Millisecond)

				_ = s.Shutdown(context.Background())
			}()

			if err := s.Serve(l); err != nil && !errors.Is(err, gemini.ErrServerShutdown) {
				t.Errorf("Server.Serve(%v) error = %v", l, err)
			}

			// Bound the wait so a callback that never runs fails the test
			// instead of hanging it until the global test timeout.
			called := make(chan struct{})

			go func() {
				wg.Wait()
				close(called)
			}()

			select {
			case <-called:
			case <-time.After(time.Second):
				t.Errorf("Server.RegisterOnShutdown() registered %d functions but not all were called", tt.args.i)
			}
		})
	}
}

// TestServer_Shutdown is a table for Server.Shutdown cases; it currently has
// no entries.
func TestServer_Shutdown(t *testing.T) {
	t.Parallel()

	type fields struct {
		Handler     gemini.Handler
		LogHandler  func(message string, isError bool)
		BaseContext func(net.Listener) context.Context
		ConnContext func(ctx context.Context, c net.Conn) context.Context
		TLSConfig   *tls.Config
		Addr        string
	}

	type args struct {
		ctx context.Context
	}

	tests := []struct {
		args    args
		fields  fields
		name    string
		wantErr bool
	}{
		// TODO: Add test cases.
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			s := &gemini.Server{
				Handler:     tt.fields.Handler,
				LogHandler:  tt.fields.LogHandler,
				BaseContext: tt.fields.BaseContext,
				ConnContext: tt.fields.ConnContext,
				TLSConfig:   tt.fields.TLSConfig,
			}

			if err := s.Shutdown(tt.args.ctx); (err != nil) != tt.wantErr {
				t.Errorf("Server.Shutdown(%v) error = %v, wantErr %v", tt.args.ctx, err, tt.wantErr)
			}
		})
	}
}