
Golang Server Graceful Stop

When building a full-stack HTTP server I came across the following problem. I needed to run an HTTP server, an HTTPS server and a third service that queries the database and sends user notifications. I wanted the program to synchronise these goroutines and to be able to stop gracefully. I didn’t want it to keep running the HTTP server if the HTTPS server failed to start or if the third service failed to connect to the database. So I wrote server code that synchronises the services’ goroutines and stops all of them if any one of them returns an error. It also stops them when an interrupt signal is received, for example when you hit Ctrl+C in the terminal. Here is a little more detail on what each service does:

  • HTTP Server

    • Redirect to the HTTPS service
    • Serve the directory for the Certbot (ACME) challenge
  • HTTPS Server

    • Serve the application
  • Other service

    • Run a periodic task to query the database and notify users if needed

I wanted to avoid using a reverse proxy such as nginx: first, to reduce the attack surface and the software dependencies, and second, to keep the whole application contained within my own code. Furthermore, some reverse proxies don’t support HTTP/2 end to end; nginx, for example, talks HTTP/1.x to the upstream application.

To achieve this I started writing the server code, but I had a few problems getting all the goroutines to synchronise. Eventually I came up with something like the code below. To explain it, I created a server that follows some best practices, but for the sake of demonstration I kept a flat file structure.

[Image: file structure in the directory]

The main file:

package main

import (
	"context"
	"log/slog"
	"os"
)

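func main() {
	// HTTP on :8081 and HTTPS on :8082 (:80 and :443 in production).
	serv := NewServer(":8081", ":8082")
	// Run blocks until every service has stopped, either because a signal
	// arrived or because one of the services returned an error.
	if err := serv.Run(context.Background()); err != nil {
		slog.Error(err.Error())
		os.Exit(1)
	}
	slog.Info("Server stopped gracefully.")
}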

The main function is pretty simple. It creates a new server and passes it two addresses, in this example the strings :8081 and :8082, which are the ports for the HTTP and HTTPS services. In production you would pass :80 and :443, and you would define them in a config file or in environment variables. I then call the Run method, passing it a root context that should be propagated to every function in your code. I am using slog for logging, which is awesome; I may write a post about slog in the future. Run blocks: the code after the call only executes once all the services have stopped, either because of an error or because the program was asked to shut down.

So let’s analyse the server file code:

package main

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"golang.org/x/sync/errgroup"
	"math/big"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"
)

type Server struct {
	httpAddr  string
	httpsAddr string
}

func NewServer(httpAddr, httpsAddr string) *Server {
	return &Server{httpAddr: httpAddr, httpsAddr: httpsAddr}
}

func (s *Server) Run(ctx context.Context) error {
	// A signal channel only receives signals once it is registered with signal.Notify.
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(signalChan)

	// os.Interrupt is SIGINT on Unix, and os.Kill (SIGKILL) cannot be caught,
	// so these two signals cover Ctrl+C as well as kill/docker stop.
	ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
	defer cancel()

	// Create an errgroup for managing multiple goroutines
	g, gctx := errgroup.WithContext(ctx)

	// Start two web services in separate goroutines
	g.Go(func() error { return httpServer(gctx, s.httpAddr) })
	g.Go(func() error { return httpsServer(gctx, s.httpsAddr) })

	// Listen for OS interrupts and cancel context
	go func() {
		select {
		case <-signalChan: // An OS signal was received
			fmt.Println("Received interrupt signal, shutting down...")
			cancel() // Cancel the context
		case <-gctx.Done(): // Services are already stopping; let this goroutine exit
		}
	}()

	// Wait for all goroutines to exit
	if err := g.Wait(); err != nil {
		fmt.Println("Error:", err)
		return err
	}

	fmt.Println("All services stopped. Exiting.")
	return nil
}

Let’s start with the NewServer function, which is basically a constructor for Server. It receives two strings: the first is the HTTP server address and the second is the HTTPS server address, as we have seen in the main function. It then returns a pointer to a Server struct.

Now let’s analyse the Run method. Run is a method on the Server type, and the Server struct is where the server’s dependencies are stored (dependency injection). In this example the only dependencies are the addresses the services listen on. Usually a logger and other dependencies would be included as well, but I kept the example simple.

func (s *Server) Run(ctx context.Context) error {
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(signalChan)
	ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
	defer cancel()
[...]

To be notified of OS signals we need a channel. I created one called signalChan, a buffered channel of type os.Signal; note that a channel like this only receives anything after it has been registered with signal.Notify. I then call signal.NotifyContext() from the “os/signal” package, which is designed for exactly this purpose. It takes a parent context, which comes from our main function, and returns a new context that is cancelled (its Done channel is closed) as soon as one of the listed signals arrives. The signals worth listing are os.Interrupt (the same signal as syscall.SIGINT on Unix, sent by Ctrl+C) and syscall.SIGTERM (sent by kill or docker stop); os.Kill (SIGKILL) cannot be caught by any process, so listing it has no effect. NotifyContext also returns a stop function, here called cancel, and the defer cancel() makes sure it is called when Run returns, which also releases the signal registration.

Moving on with the code.

// Create an errgroup for managing multiple goroutines
g, gctx := errgroup.WithContext(ctx)

// Start two web services in separate goroutines
g.Go(func() error { return httpServer(gctx, s.httpAddr) })
g.Go(func() error { return httpsServer(gctx, s.httpsAddr) })

// Listen for OS interrupts and cancel context
go func() {
	select {
	case <-signalChan: // An OS signal was received
		fmt.Println("Received interrupt signal, shutting down...")
		cancel() // Cancel the context
	case <-gctx.Done(): // Services are already stopping; let this goroutine exit
	}
}()

// Wait for all goroutines to exit
if err := g.Wait(); err != nil {
	fmt.Println("Error:", err)
	return err
}

fmt.Println("All services stopped. Exiting.")
return nil
}

In the above code we use an errgroup, which adds functionality on top of sync.WaitGroup. I use it to make sure that if one of the goroutines returns an error, all the other goroutines are stopped. This shows how simple concurrency is in Go. In the first line I call errgroup.WithContext(ctx), passing the context that is now also cancelled by the OS signals we selected. It returns a pointer to a Group (*Group) and a derived context, which I called gctx. To start a goroutine within the group you call g.Go(), passing it a function that returns an error. This is how the group knows when to stop the others: the first goroutine to return a non-nil error cancels gctx, so every goroutine that watches gctx.Done() can shut itself down, and g.Wait() blocks until all of them have returned and then reports that first error. So I started two goroutines with the same structure:

// Start two web services in separate goroutines
g.Go(func() error { return httpServer(gctx, s.httpAddr) })
g.Go(func() error { return httpsServer(gctx, s.httpsAddr) })

In these lines I start two goroutines. They are very similar, so there is no need to explain both. As I said before, g.Go() needs a function that returns an error. I could simply call g.Go(func() error { return http.ListenAndServe(s.httpAddr, handler) }), where handler is the handler for the HTTP listener, but then nothing would stop that server when gctx is cancelled; a separate function is also more organised, because it keeps each server’s configuration contained. So what I have is an anonymous function, func() error {}, that returns whatever the named function returns, func() error { return functionName() }, which means functionName must itself return an error. To my httpServer function I pass gctx, the errgroup context, and s.httpAddr, the address the server will listen on.

func httpServer(ctx context.Context, addr string) error {
	mux := http.NewServeMux()
	// Method-qualified patterns such as "GET /" need Go 1.22 or later.
	mux.HandleFunc("GET /", httpHandler)
	mux.Handle("GET /.well-known/acme-challenge/", http.StripPrefix("/.well-known/acme-challenge/", http.FileServer(http.Dir("/challenge/.well-known/acme-challenge/"))))

	httpServer := &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 10 * time.Second,
		IdleTimeout:  15 * time.Second,
	}

	// Buffered so the ListenAndServe goroutine can always send and exit.
	// It is deliberately never closed: a late send on a closed channel panics.
	errChan := make(chan error, 1)

	go func() {
		fmt.Println("Starting server on", addr)
		// Forward real errors; ErrServerClosed only means Shutdown was called.
		if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			errChan <- err
		}
	}()

	select {
	case <-ctx.Done():
		fmt.Println("Shutting down HTTP server on", addr)
		httpServer.SetKeepAlivesEnabled(false)
		// ctx is already cancelled here, so Shutdown gets its own deadline;
		// with ctx it would give up at once on any in-flight request.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		return httpServer.Shutdown(shutdownCtx) // Gracefully shutdown server
	case err := <-errChan:
		return err
	}
}

Let’s unpack this code! The first three lines create a ServeMux and register the routes and their handlers. There are two handlers in the example: the first is the catch-all for requests arriving at the HTTP listener, and the second is the file server for the Let’s Encrypt Certbot challenge. I will explain httpHandler later. The mux.Handle("GET /.well-known/acme-challenge/", http.StripPrefix("/.well-known/acme-challenge/", http.FileServer(http.Dir("/challenge/.well-known/acme-challenge/")))) line handles GET requests to “/.well-known/acme-challenge/”, the path required by Certbot’s HTTP-01 challenge, by serving the files in the directory “/challenge/.well-known/acme-challenge/”. Because this pattern is more specific than "GET /", the ServeMux sends challenge requests to it rather than to the catch-all. Why this directory?! Because in my production server I use Docker and a Docker Compose file, in which I mount a folder called “challenge” at the root of the container and point Certbot at it. Certbot then writes the challenge files inside the “challenge” folder, which produces the path “/challenge/.well-known/acme-challenge/”. I could also have used the golang.org/x/crypto/acme/autocert package, but I thought Certbot was more appropriate, since it is the ACME client Let’s Encrypt recommends.

OK, let’s move on to the HTTP server options!

httpServer := &http.Server{
	Addr:         addr,
	Handler:      mux,
	ReadTimeout:  5 * time.Second,
	WriteTimeout: 10 * time.Second,
	IdleTimeout:  15 * time.Second,
}

In this section we create the HTTP server with the mux and the address passed as an argument; nothing really complex. I just want to highlight that, from a security perspective, you should always set the timeouts (ReadTimeout, WriteTimeout and IdleTimeout) to mitigate slow-client DoS attacks such as Slowloris and to keep your server running smoothly. In net/http a zero timeout means no timeout at all, and zero is the default, so a client that sends its request very slowly can hold a connection open indefinitely. As we know from the security world, default settings are rarely secure. (If ReadHeaderTimeout is not set, net/http applies ReadTimeout to reading the headers as well.)

Next I create a buffered channel through which the goroutine running ListenAndServe() can send its error back to the function, and a select statement that waits for whichever happens first. Either the context passed as a parameter is done, meaning that a termination signal was received or that the context was cancelled because another service failed, in which case we start a graceful shutdown of the HTTP server; or an error arrives on errChan from ListenAndServe(), for example because the port is already in use. In both cases the function returns the result to the errgroup: nil after a clean shutdown, or an error, which makes the errgroup cancel the other goroutines! Note that ctx is already cancelled inside the ctx.Done() branch; passing it straight to Shutdown makes Shutdown return context.Canceled immediately whenever a request is still in flight, so it is safer to give Shutdown a fresh context with its own timeout.

That’s all for the HTTP server function. The HTTPS function is similar, but it also needs a certificate chain, a private key and a TLS config. Here is an example:

// server.CertFile should point to the full certificate chain (Certbot's
// fullchain.pem) and server.KeyFile to the private key (privkey.pem).
certPair, err := tls.LoadX509KeyPair(server.CertFile, server.KeyFile)
if err != nil {
	return fmt.Errorf("failed to load key pair: %w", err)
}

tlsConfig := &tls.Config{
	Certificates: []tls.Certificate{certPair},
	MinVersion:   tls.VersionTLS12,
	MaxVersion:   tls.VersionTLS13,
	// Only TLS 1.0-1.2 suites are configurable; Go always enables its own
	// TLS 1.3 suites. Plain RSA key exchange (TLS_RSA_WITH_*) is left out
	// because it provides no forward secrecy.
	CipherSuites: []uint16{
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
	},
	CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256},
	// No ClientSessionCache: it is only used by TLS clients, not servers.
}

httpServer := &http.Server{
	Addr:              addr,
	Handler:           mux,
	TLSConfig:         tlsConfig,
	ReadHeaderTimeout: 5 * time.Second, // keep <= ReadTimeout, or headers get longer than the whole request
	ReadTimeout:       5 * time.Second,
	WriteTimeout:      10 * time.Second,
	IdleTimeout:       15 * time.Second,
}

// Empty file names: the certificate is already in TLSConfig.Certificates.
if err := httpServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
	errChan <- err
}

OK, so we now have an errgroup with two goroutines that both stop if one of them fails to start, and a group context that is cancelled when a signal is received. We also have a mux, or router, for HTTP and HTTPS requests. Great! But what keeps Run alive while the services run? If Run returned right after starting them, main would exit and the goroutines would end with it. So we add a goroutine that waits on signalChan and calls cancel() when a signal arrives (strictly speaking, NotifyContext already cancels ctx on a signal, so this goroutine mainly logs the event), and then we call g.Wait(), which blocks until every goroutine in the group has returned:

// Listen for OS interrupts and cancel context
go func() {
	select {
	case <-signalChan: // An OS signal was received
		fmt.Println("Received interrupt signal, shutting down...")
		cancel() // Cancel the context
	case <-gctx.Done(): // Services are already stopping; let this goroutine exit
	}
}()

// Wait for all goroutines to exit
if err := g.Wait(); err != nil {
	fmt.Println("Error:", err)
	return err
}

If all the goroutines return without error (meaning they received a signal and shut down gracefully), the lines fmt.Println("All services stopped. Exiting.") and return nil execute, returning control to the main function.

This will output something like this:

[Image: terminal output with signal interruption]

If one of the goroutines fails to start, the whole program stops. In this example I bind the HTTP server to a port that is already in use: the program terminates with an error, and as you can see below it reports “exit status 1” instead of exiting gracefully. The other goroutines are stopped as well, which is what we want, because the program should not keep running without all of its services.

[Image: unsuccessful exit]

httpHandler is an ordinary handler function (it matches the http.HandlerFunc signature) that processes the HTTP request. For this example it just prints “Hello World” plus a random string, to differentiate the output between refreshes. The generateRandomString() code can be found in my GitHub repository, link below.

func httpHandler(w http.ResponseWriter, r *http.Request) {
	_, _ = fmt.Fprintf(w, "Hello World %s", generateRandomString(10))
}
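The repository’s generateRandomString() isn’t reproduced in the post. Judging from the crypto/rand and math/big imports in the server file, it may look something like the following, but treat this as a hypothetical reconstruction: it picks each character uniformly from a fixed alphabet using rand.Int:

```go
package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// generateRandomString returns n characters drawn uniformly from letters.
// Hypothetical reconstruction, not the code from the repository.
func generateRandomString(n int) string {
	b := make([]byte, n)
	limit := big.NewInt(int64(len(letters)))
	for i := range b {
		idx, err := rand.Int(rand.Reader, limit) // uniform in [0, len(letters))
		if err != nil {
			panic(err) // the operating system's random source is unavailable
		}
		b[i] = letters[idx.Int64()]
	}
	return string(b)
}

func main() {
	fmt.Println("Hello World", generateRandomString(10))
}
```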

Now we have boilerplate code that we can extend just by adding more services with g.Go(func() error { return functionName() }).

That’s it for today! I have included all the code from this post in my GitHub repository.

I hope you liked it! Comment, share, or make a pull request 😃
