Building a DNS Resolver in Golang: A Step-by-Step Guide

My One-Week Journey to Build a DNS Resolver

Featured on Hashnode

Most of you are probably wondering: a DNS what?

A DNS Resolver is a crucial component that your operating system uses to resolve a domain name into an IP address (for example, dns.google.com -> 8.8.8.8). Your browser can't connect to a domain name directly; it needs an IP address to make a network call to the domain. For a detailed explanation of how DNS works, you can refer to my previous blog post here.

I decided to build this after studying DNS further and discovering a Coding Challenge about it. If you haven't heard of Coding Challenges, it is a series of weekly challenges where you build a complete application based on a real-world tool. This challenge consists of building a simple DNS Resolver that can resolve a domain name and return its IP address.

So with that out of the way, let's see how it was built.

Step 0

In this step, I chose the programming language (Golang, I choose you!) and read up on how a resolver works. I picked Golang because it was the only language I was comfortable using for low-level tasks like byte encoding and socket handling, which this project needed a lot of. I also had to study how a DNS message is created, encoded and decoded, all of which is specified in RFC 1035. With that out of the way, let's see how a DNS Resolver works.

Fun Fact: If you go through the RFC, you will notice that it calls 8-bit units octets instead of simply calling them bytes. That is likely because, although a byte is 8 bits on virtually every modern computer, that has not always been true: some older architectures, including some supercomputers, used bytes of other sizes, such as 6 or 9 bits.

Step 1

A DNS Resolver sends a DNS Message over UDP to different nameservers. The message format is defined in Section 4.1 of the RFC. Each message is divided into 5 sections:

  • Header

  • Question

  • Answer

  • Authority

  • Additional

Of these, the header is always present, and a query also carries a question. The header holds the metadata of the message (whether it is a request or a response, whether it has any errors, the number of questions and so on), while the question holds the actual domain name that needs to be resolved.

Header

The header has the following format

  1. ID: A 16-bit ID of the query

  2. Flags: A 16-bit field after the ID which holds multiple flags (these determine whether the message is a response, whether it has an error, etc.)

  3. QDCOUNT: A 16-bit integer which denotes the number of questions.

  4. ANCOUNT: A 16-bit integer which denotes the number of answers.

  5. NSCOUNT: A 16-bit integer which denotes the number of authorities.

  6. ARCOUNT: A 16-bit integer which denotes the number of additionals.

In total, the header takes 96 bits, i.e. 12 bytes. It is defined in Section 4.1.1.

This part fascinated me: the RFC authors packed 8 different fields into a single 16-bit integer to keep the header small, something we rarely care about in today's software but which matters a lot in low-level designs.
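Since every response starts with this same 12-byte header, reading it back is just six big-endian 16-bit reads. Here is a small sketch of that inverse direction. HeaderFromBytes is a hypothetical helper, not something from my repo, and it redefines Header so the snippet compiles on its own.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// Header mirrors the six 16-bit fields of a DNS header (RFC 1035, Section 4.1.1).
type Header struct {
	ID, Flags, QDCount, ANCount, NSCount, ARCount uint16
}

// HeaderFromBytes is a hypothetical decoder: it reads the first 12 bytes of a
// DNS message as six big-endian uint16 values, in header order.
func HeaderFromBytes(data []byte) (Header, error) {
	if len(data) < 12 {
		return Header{}, errors.New("message shorter than the 12-byte header")
	}
	return Header{
		ID:      binary.BigEndian.Uint16(data[0:2]),
		Flags:   binary.BigEndian.Uint16(data[2:4]),
		QDCount: binary.BigEndian.Uint16(data[4:6]),
		ANCount: binary.BigEndian.Uint16(data[6:8]),
		NSCount: binary.BigEndian.Uint16(data[8:10]),
		ARCount: binary.BigEndian.Uint16(data[10:12]),
	}, nil
}

func main() {
	// Example header: ID 0x0016, flags 0x8180, 1 question, 1 answer.
	raw := []byte{0x00, 0x16, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}
	h, err := HeaderFromBytes(raw)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", h)
}
```

The length check matters because a truncated or garbage UDP packet would otherwise make the slice expressions panic.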

Question

The question has the following format

  1. QNAME: A variable length space of memory which holds the domain name which needs to be resolved.

  2. QTYPE: A 16-bit integer denoting the record type like A, NS, CNAME, etc.

  3. QCLASS: A 16-bit integer that denotes the class of the query.

Overall, a Question does not have a fixed size, because QNAME is variable-length. It is defined in Section 4.1.2.

Creating the Message

To send this data, we need to convert the Header and the Question into their byte representations and send them over the network. The other sections can be left out, since they are only filled in by the response.

Let's look at the code.

First, I needed to define structs for the Header and the Question, along with a way to convert them into bytes. Below is my Header implementation, using unsigned 16-bit integers for the fields.

type Header struct {
    ID uint16 // ID is a 16-bit identifier assigned by the program that generates any kind of query.
    Flags uint16 // Flags contains various control flags for the DNS message.
    QDCount uint16 // QDCount specifies the number of entries in the question section.
    ANCount uint16 // ANCount specifies the number of resource records in the answer section.
    NSCount uint16 // NSCount specifies the number of name server resource records in the authority section.
    ARCount uint16 // ARCount specifies the number of resource records in the additional records section.
}

Next was converting the Header to its byte representation. I used the built-in packages bytes and encoding/binary for that.

// ToBytes converts the Header to its byte representation.
func (h *Header) ToBytes() []byte {
    buf := new(bytes.Buffer)

    binary.Write(buf, binary.BigEndian, h.ID)
    binary.Write(buf, binary.BigEndian, h.Flags)
    binary.Write(buf, binary.BigEndian, h.QDCount)
    binary.Write(buf, binary.BigEndian, h.ANCount)
    binary.Write(buf, binary.BigEndian, h.NSCount)
    binary.Write(buf, binary.BigEndian, h.ARCount)

    return buf.Bytes()
}

To build that Flags value, I also created a HeaderFlag struct used only to pack the individual flags into the 16-bit integer. Here's the implementation.

// HeaderFlag represents the individual flags in the DNS header.
type HeaderFlag struct {
    QR bool // QR indicates whether the message is a query (0) or a response (1).
    Opcode uint8 // Opcode specifies the kind of query in the message.
    AA bool // AA indicates whether the responding name server is an authority for the domain name in question section.
    TC bool // TC indicates whether the message was truncated.
    RD bool // RD indicates whether recursion is desired.
    RA bool // RA indicates whether recursion is available in the name server.
    Z uint8 // Z is reserved for future use.
    RCode uint8 // RCode specifies the response code.
}


// GenerateFlag generates the 16-bit flag value from the individual flag components.
func (hf *HeaderFlag) GenerateFlag() uint16 {
    qr := uint16(boolToInt(hf.QR))
    opcode := uint16(hf.Opcode)
    aa := uint16(boolToInt(hf.AA))
    tc := uint16(boolToInt(hf.TC))
    rd := uint16(boolToInt(hf.RD))
    ra := uint16(boolToInt(hf.RA))
    z := uint16(hf.Z)
    rcode := uint16(hf.RCode)

    return uint16(qr<<15 | opcode<<11 | aa<<10 | tc<<9 | rd<<8 | ra<<7 | z<<4 | rcode)
}

The GenerateFlag() function needs some bitwise operations to build the flag (who knew they would be useful beyond college exams?). It shifts each field to its bit position and combines them with a bitwise OR to produce the final flag value.
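Reading a response needs the opposite operation: shift each field back down and mask it to its width (4 bits for Opcode, 3 for Z and 4 for RCode, per RFC 1035). Below is a sketch of a hypothetical ParseFlag helper that is not part of the original code. For example, 0x0100 is a plain query with only RD set, and 0x8180 is a typical successful response with QR, RD and RA set.

```go
package main

import "fmt"

// HeaderFlag mirrors the struct above so this sketch compiles on its own.
type HeaderFlag struct {
	QR     bool
	Opcode uint8
	AA     bool
	TC     bool
	RD     bool
	RA     bool
	Z      uint8
	RCode  uint8
}

// ParseFlag is a hypothetical inverse of GenerateFlag. In Go, >> and & share
// the same precedence and associate left, so flag>>11&0xF means (flag>>11)&0xF.
func ParseFlag(flag uint16) HeaderFlag {
	return HeaderFlag{
		QR:     flag>>15&0x1 == 1,
		Opcode: uint8(flag >> 11 & 0xF),
		AA:     flag>>10&0x1 == 1,
		TC:     flag>>9&0x1 == 1,
		RD:     flag>>8&0x1 == 1,
		RA:     flag>>7&0x1 == 1,
		Z:      uint8(flag >> 4 & 0x7),
		RCode:  uint8(flag & 0xF),
	}
}

func main() {
	fmt.Printf("%+v\n", ParseFlag(0x8180))
}
```

Checking RCode (non-zero means an error, such as 3 for NXDOMAIN) and TC right after decoding is a cheap way to bail out before parsing the rest of the response.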

Next is the Question implementation.

type Question struct {
    Name string // This is a domain name
    QName string // This is the converted domain name based on the RFC 1035 document
    QType uint16 // The question type
    QClass uint16 // The question class
}

Here, I store both Name and QName in the same struct, because the name has to be encoded into a specific format before it can be sent over the network. If Name is dns.google.com, then QName is \x03dns\x06google\x03com\x00: each dot-separated label is prefixed with a byte holding its length, the dots are dropped, and the encoded name ends with a zero byte (\x00).

// encodeName encodes the domain name to the format specified in RFC 1035.
func encodeName(name string) string {
    domainParts := strings.Split(name, ".")
    qname := ""
    for _, part := range domainParts {
        newDomainPart := string(byte(len(part))) + part
        qname += newDomainPart
    }
    return qname + "\x00"
}

I call this function every time a new Question is created to get the encoded name. Then the question needs to be converted to bytes.

// ToBytes converts the Question to its byte representation.
func (q *Question) ToBytes() []byte {
    buf := new(bytes.Buffer)  
    buf.Write([]byte(q.QName))
    binary.Write(buf, binary.BigEndian, q.QType)
    binary.Write(buf, binary.BigEndian, q.QClass)
    return buf.Bytes()
}

Combining these, a DNS Message implementation looks like this.

type DNSMessage struct {
    Header Header
    Questions []Question
    Answers []ResourceRecord
    AuthorityRRs []ResourceRecord
    AdditionalRRs []ResourceRecord
}

// NewDNSMessage creates a new DNSMessage with the given header, questions, and resource records.
// It returns a pointer to the created DNSMessage.
func NewDNSMessage(header Header, questions []Question, records ...[]ResourceRecord) *DNSMessage {

    answers := make([]ResourceRecord, 0)
    authorityRRs := make([]ResourceRecord, 0)
    additionalRRs := make([]ResourceRecord, 0)

    if len(records) > 0 {
        answers = records[0]
    }

    if len(records) > 1 {
        authorityRRs = records[1]
    }

    if len(records) > 2 {
        additionalRRs = records[2]
    }

    return &DNSMessage{
        Header: header,
        Questions: questions,
        Answers: answers,
        AuthorityRRs: authorityRRs,
        AdditionalRRs: additionalRRs,
    }
}

This takes the header and questions (plus any optional resource records) and creates a new DNS Message to send over the network.
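The post doesn't show the final serialization step. Since a query only needs the header and the question, one plausible approach is to append Question.ToBytes() to Header.ToBytes(). To keep this sketch self-contained, it skips the structs and writes the same fields directly. The function name encodeQuery and the ID value are illustrative assumptions. The flags are 0x0100 (only RD set), which is what GenerateFlag returns for a HeaderFlag with just RD: true.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"strings"
)

// encodeQuery is a hypothetical helper that serializes a single-question
// query: a 12-byte header followed by QNAME, QTYPE and QCLASS.
func encodeQuery(id uint16, domain string, qtype uint16) []byte {
	msg := make([]byte, 12, 512)
	binary.BigEndian.PutUint16(msg[0:2], id)
	binary.BigEndian.PutUint16(msg[2:4], 0x0100) // flags: RD=1, everything else 0
	binary.BigEndian.PutUint16(msg[4:6], 1)      // QDCOUNT = 1
	// ANCOUNT, NSCOUNT and ARCOUNT stay 0 in a query.

	// QNAME: length-prefixed labels followed by a zero byte.
	for _, label := range strings.Split(strings.TrimSuffix(domain, "."), ".") {
		msg = append(msg, byte(len(label)))
		msg = append(msg, label...)
	}
	msg = append(msg, 0)

	msg = binary.BigEndian.AppendUint16(msg, qtype) // e.g. 1 for an A record
	msg = binary.BigEndian.AppendUint16(msg, 1)     // QCLASS 1 = IN (Internet)
	return msg
}

func main() {
	fmt.Printf("% x\n", encodeQuery(0x0016, "dns.google.com", 1))
}
```

For dns.google.com this gives 32 bytes: 12 for the header, 16 for QNAME and 4 for QTYPE and QCLASS. (binary.BigEndian.AppendUint16 needs Go 1.19 or newer.)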

Step 2

In this step, we had to send the message over the network to the nameserver and read back the response. I created a client to do that.

// Client represents a UDP client for sending DNS queries.
type Client struct {
    ipAddress string
    port int
}

// Query sends a message to the given ip address and port and returns the response.
func (c *Client) Query(message []byte) ([]byte, error) {
    // Create a UDP connection
    ipType, err := c.ipType()
    var addr string
    if err != nil {
        return nil, fmt.Errorf("failed to get the IP type: %v", err)
    }

    if ipType == "ipv4" {
        addr = fmt.Sprintf("%s:%d", c.ipAddress, c.port)
    } else if ipType == "ipv6" {
        addr = fmt.Sprintf("[%s]:%d", c.ipAddress, c.port)
    }

    conn, err := net.Dial("udp", addr)
    if err != nil {
        return nil, fmt.Errorf("failed to connect to the DNS server: %v", err)
    }

    defer conn.Close()

    // Set a timeout for the connection
    conn.SetDeadline(time.Now().Add(5 * time.Second))

    // Send a message
    _, err = conn.Write(message)
    if err != nil {
        return nil, fmt.Errorf("failed to send the DNS message: %v", err)
    }

    // Receive the response
    buf := make([]byte, 1024)

    // Read the response
    n, err := conn.Read(buf)

    if err != nil {
        return nil, fmt.Errorf("failed to read the response: %v", err)
    }

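    response := buf[:n]

    // A valid DNS response has at least the 12-byte header
    if len(response) < 12 {
        return nil, fmt.Errorf("the response is too short to be a DNS message")
    }

    // Check if the response ID matches the request ID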
    if !IDMatcher(message[:2], response[:2]) {
        return nil, fmt.Errorf("the response ID does not match the request ID")
    }

    return response, nil
}

The Query() function takes the message to send and sends it to the ipAddress and port defined in the Client. It opens a UDP socket, sets a 5-second timeout, writes the message and then reads the response from the socket.

We also check that the ID of the response matches the ID of the request, to verify that the response belongs to this specific request.
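Query relies on two helpers that aren't listed here, ipType and IDMatcher. As a hedged alternative for the first, Go's net.JoinHostPort already wraps IPv6 literals in square brackets, which would make the IPv4/IPv6 branch unnecessary. The ID check boils down to comparing the first two bytes. Both bodies below are sketches, not the repo's code, and serverAddr is a made-up name.

```go
package main

import (
	"bytes"
	"fmt"
	"net"
	"strconv"
)

// serverAddr builds a "host:port" string. net.JoinHostPort adds square
// brackets when the host contains a colon (an IPv6 literal).
func serverAddr(ip string, port int) string {
	return net.JoinHostPort(ip, strconv.Itoa(port))
}

// IDMatcher is a sketch of the ID check: a response belongs to a request
// when their 16-bit IDs (the first two bytes) are identical.
func IDMatcher(requestID, responseID []byte) bool {
	return len(requestID) == 2 && bytes.Equal(requestID, responseID)
}

func main() {
	fmt.Println(serverAddr("8.8.8.8", 53))
	fmt.Println(serverAddr("2001:4860:4860::8888", 53))
	fmt.Println(IDMatcher([]byte{0x12, 0x34}, []byte{0x12, 0x34}))
}
```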

Step 3

This was probably the most interesting step, and certainly the one that took me the longest. Here, we have to parse the response to get the actual data we can use. That means parsing all the resource records sent back to the resolver, each of which has a variable length. Here is my ResourceRecord struct and its parser:

type ResourceRecord struct {
    Name string // The domain name of the resource record
    Type uint16 // The type of the resource record
    Class uint16 // The class of the resource record
    TTL uint32 // The time to live of the resource record
    RDLength uint16 // The length of the resource data
    RData []byte // The resource data
    RDataParsed string // The parsed resource data
}

// ResourceRecordFromBytes creates a ResourceRecord from a byte slice.
func ResourceRecordFromBytes(data []byte, messageBufs ...*bytes.Buffer) *ResourceRecord {
    buf := bytes.NewBuffer(data)
    var messageBuf *bytes.Buffer
    if messageBufs != nil {
        messageBuf = messageBufs[0]
    }

    name := appendFromBufferUntilNull(buf)
    nameLength := len(name) - 1
    decodedName, err := DecodeName(string(name), messageBuf)

    if err != nil {
        fmt.Printf("Failed to decode the name: %v\n", err)
    }

    typ := binary.BigEndian.Uint16(data[nameLength : nameLength+2])
    class := binary.BigEndian.Uint16(data[nameLength+2 : nameLength+4])
    ttl := binary.BigEndian.Uint32(data[nameLength+4 : nameLength+8])
    rdLength := binary.BigEndian.Uint16(data[nameLength+8 : nameLength+10])
    rData := data[nameLength+10 : nameLength+10+int(rdLength)] // 10 is the length of the fields before RData
    rDataParsed, _ := parseRData(typ, rData, messageBuf)

    return &ResourceRecord{
        Name: decodedName,
        Type: typ,
        Class: class,
        TTL: ttl,
        RDLength: rdLength,
        RData: rData,
        RDataParsed: rDataParsed,
    }
}

You might wonder why the function takes a second argument. That has to do with the message compression used in the response. A domain name normally ends with a zero byte (\x00), but to reduce the size of the response, if a domain (or a trailing part of it) has already appeared earlier in the message, later occurrences of that domain or its suffix use a pointer instead.

The pointer takes 2 bytes: the two most significant of its 16 bits are set to 11, and the remaining 14 bits hold the offset. The offset is the number of bytes from the start of the whole message (the first byte of the header).

That is why I pass the response as a buffer to the parser, so that it can be used to decode the domain. That happens here: decodedName, err := DecodeName(string(name), messageBuf)
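A very common pointer is the byte pair 0xC0 0x0C. 0xC0 sets both top bits, and the remaining 14 bits give offset 12, which is exactly where the question's QNAME begins, right after the 12-byte header. A tiny sketch of that decoding, with a function name chosen for illustration:

```go
package main

import "fmt"

// pointerOffset reports whether two bytes form a compression pointer (top two
// bits set) and, if so, returns the 14-bit offset they encode.
func pointerOffset(b0, b1 byte) (int, bool) {
	if b0&0xC0 != 0xC0 {
		return 0, false // an ordinary label length, not a pointer
	}
	return int(b0&0x3F)<<8 | int(b1), true
}

func main() {
	off, ok := pointerOffset(0xC0, 0x0C)
	fmt.Println(off, ok) // 12 true
}
```

Note that the low 6 bits of the first byte are part of the offset too, so pointers into messages longer than 255 bytes still work.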

Here's the implementation of DecodeName:

// DecodeName decodes the encoded domain name to its original format.
func DecodeName(qname string, messageBufs ...*bytes.Buffer) (string, error) {
    encoded := []byte(qname)
    var result bytes.Buffer
    var messageBuf *bytes.Buffer
    if messageBufs != nil {
        messageBuf = messageBufs[0]
    }

    for i := 0; i < len(encoded); {
        length := int(encoded[i])
        if length == 0 {
            break
        }

        if encoded[i]>>6 == 0b11 && messageBuf != nil {
            // Check if the name is a pointer. Parse the pointer, get the offset and parse the name from the offset.
            // See https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.4 for more information
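            if i+1 >= len(encoded) {
                return "", fmt.Errorf("truncated compression pointer")
            }
            // The offset is the lower 14 bits of the 2-byte pointer
            offset := int(encoded[i]&0b00111111)<<8 | int(encoded[i+1])
            messageBytes := messageBuf.Bytes()
            if offset >= len(messageBytes) {
                return "", fmt.Errorf("compression pointer out of range")
            }
            messageBytes = messageBytes[offset:]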
            name := appendFromBufferUntilNull(bytes.NewBuffer(messageBytes))
            n, _ := DecodeName(string(name))
            name = []byte(n)
            length = len(name)
            if result.Len() > 0 {
                result.WriteByte('.')
            }
            result.Write(name)
            i += length
            break
        }
        i++

        if i+length > len(encoded) {
            return "", fmt.Errorf("invalid encoded domain name")
        }
        if result.Len() > 0 {
            result.WriteByte('.')
        }
        result.Write(encoded[i : i+length])
        i += length
    }

    return result.String(), nil
}

This code might look a bit complex, but it basically iterates over the encoded name label by label while checking for a pointer. When it finds one, it reads the name at the offset the pointer refers to from the message buffer. It can return right after that, because the RFC only allows a pointer as the last element of a name, so nothing follows it. One caveat: the name at the target offset can itself end with another pointer, and the inner DecodeName call above is made without the message buffer, so such chained pointers are not handled. If anyone can improve this part, feel free to contribute.
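One possible improvement is sketched below as a hypothetical decodeNameAt function. It decodes a name directly from the full message by offset and follows a pointer whenever it meets one, including a pointer inside the name it jumped to. It caps the number of jumps so a malformed packet with a pointer loop can't hang the resolver. It also returns the offset just past the name at its original position, which is what the caller needs to find TYPE, CLASS and the rest of the record. The jump limit of 16 is an arbitrary choice.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// decodeNameAt decodes the (possibly compressed) domain name starting at
// msg[off]. It returns the dotted name and the offset of the first byte
// after the name at its original position.
func decodeNameAt(msg []byte, off int) (string, int, error) {
	var labels []string
	next := -1 // offset after the name in place; fixed at the first pointer
	jumps := 0
	for {
		if off >= len(msg) {
			return "", 0, errors.New("name runs past end of message")
		}
		b := msg[off]
		switch {
		case b == 0: // zero-length label: end of name
			if next < 0 {
				next = off + 1
			}
			return strings.Join(labels, "."), next, nil
		case b&0xC0 == 0xC0: // compression pointer
			if off+1 >= len(msg) {
				return "", 0, errors.New("truncated pointer")
			}
			jumps++
			if jumps > 16 { // arbitrary cap to break pointer loops
				return "", 0, errors.New("too many compression pointers")
			}
			if next < 0 {
				next = off + 2
			}
			off = int(b&0x3F)<<8 | int(msg[off+1])
		case b&0xC0 != 0: // 01 and 10 prefixes are not plain labels
			return "", 0, fmt.Errorf("unsupported label type 0x%02x", b)
		default: // ordinary label of length b
			end := off + 1 + int(b)
			if end > len(msg) {
				return "", 0, errors.New("label runs past end of message")
			}
			labels = append(labels, string(msg[off+1:end]))
			off = end
		}
	}
}

func main() {
	// 12-byte header, then "www.example.com" at offset 12, then "cdn" plus a
	// pointer to offset 16 (the "example.com" suffix) at offset 29.
	msg := make([]byte, 12)
	msg = append(msg, "\x03www\x07example\x03com\x00"...)
	msg = append(msg, "\x03cdn\xc0\x10"...)
	name, next, err := decodeNameAt(msg, 29)
	fmt.Println(name, next, err)
}
```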

To wrap up the resolver, I created a Resolve function that resolves the IP address of a domain name using all the parts mentioned above.

Final Result

Here is the resolver in action after doing all that:

As you can see from the image, the resolver works fine for finding the IPv4 addresses and also supports NS and CNAME records. But we can do better than that.
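The IPv4 results above come from A records. RFC 1035 (Section 3.4.1) defines their RDATA as a 32-bit address, so rendering it is short. NS and CNAME RDATA are domain names, possibly compressed, and need the name decoder instead. The parseRData helper isn't shown in the post, so here is a minimal sketch for the A case only, with a hypothetical parseARecord name:

```go
package main

import (
	"fmt"
	"net"
)

const typeA = 1 // TYPE value for an A (IPv4 host address) record

// parseARecord is a hypothetical helper: the RDATA of an A record is the 4 raw
// bytes of the IPv4 address, so RDLENGTH must be exactly 4.
func parseARecord(typ uint16, rdata []byte) (string, error) {
	if typ != typeA {
		return "", fmt.Errorf("not an A record (type %d)", typ)
	}
	if len(rdata) != net.IPv4len {
		return "", fmt.Errorf("A record RDATA must be 4 bytes, got %d", len(rdata))
	}
	return net.IP(rdata).String(), nil
}

func main() {
	ip, err := parseARecord(typeA, []byte{8, 8, 8, 8})
	fmt.Println(ip, err)
}
```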

Adding Cache

You thought this was over, but no, it's me, cache.

Okay, I shouldn't have done that.

Anyway, this was not part of the coding challenge, but I thought, why not? A DNS Resolver usually caches DNS records for the TTL (time to live) it receives in the response. That is why changes to DNS records take time to propagate: the old records stay cached in resolvers until their TTL expires.

There were a few reasons I chose SQLite for caching:

  • I didn't want to run a separate server just to maintain a cache layer for my CLI Resolver.

  • I needed something simple and lightweight.

  • I wanted it to be fast and persistent.

SQLite ticks all these boxes, so I went ahead and created a SQLite client that exposes a few APIs for cache operations.

type CacheClient struct {
    db *sql.DB
}

// NewClient creates a new CacheClient instance.
func NewClient(cachePaths ...string) (*CacheClient, error) {
    var cachePath string
    if len(cachePaths) > 0 {
        cachePath = cachePaths[0]
    } else {
        dir, err := os.Getwd()
        if err != nil {
            return nil, err
        }
        cachePath = dir + "/cache.db"
    }
    db, err := sql.Open("sqlite3", cachePath)
    if err != nil {
        return nil, err
    }

    err = createTable(db)
    if err != nil {
        return nil, err
    }

    client := &CacheClient{
        db: db,
    }

    go client.ClearExpiredRecords()

    return client, nil
}



// createTable creates the table in the database if it doesn't exist.
func createTable(db *sql.DB) error {
    _, err := db.Exec(`
        CREATE TABLE IF NOT EXISTS dns_records (
        id INTEGER PRIMARY KEY,
        domain TEXT NOT NULL,
        type INTEGER NOT NULL,
        address TEXT NOT NULL,
        ttl INTEGER NOT NULL,
        created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
        expired_at DATETIME
        );
    `)
    return err
}

// ClearExpiredRecords deletes all the expired records from the cache.
// It is called automatically every time the client is created.
func (client *CacheClient) ClearExpiredRecords() error {
    _, err := client.db.Exec(`DELETE FROM dns_records WHERE expired_at < ?`, time.Now())
    return err
}

// Exposed APIs (implementations omitted for brevity)
func (client *CacheClient) Get(domain string) ([]dns.ResourceRecord, error) {}

func (client *CacheClient) Insert(domain string, recordType uint16, address string, ttl int) error {}

func (client *CacheClient) Delete(domain string) error {}

func (client *CacheClient) Close() error {}

The Cache Client creates an embedded SQLite database file, creates the table if it doesn't exist and clears expired records from the cache, so I don't have to handle that setup later.

With this cache layer, every time I resolve a domain, the records get cached so that later lookups can be served from the cache. I also added a --no-cache option to the CLI to skip the cache.
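The post doesn't show how the option is wired up. Assuming the CLI uses Go's standard flag package (the repo may use something else), a boolean --no-cache option could look like the sketch below. The flag package accepts both -no-cache and --no-cache, and it stops parsing at the first non-flag argument, so the option has to come before the domain. The options type and parseArgs are hypothetical names.

```go
package main

import (
	"flag"
	"fmt"
)

// options holds the parsed command-line settings for this sketch.
type options struct {
	noCache bool
	domain  string
}

// parseArgs parses arguments such as ["--no-cache", "dns.google.com"].
func parseArgs(args []string) (options, error) {
	fs := flag.NewFlagSet("dns-resolver", flag.ContinueOnError)
	noCache := fs.Bool("no-cache", false, "skip the SQLite cache and always query nameservers")
	if err := fs.Parse(args); err != nil {
		return options{}, err
	}
	if fs.NArg() != 1 {
		return options{}, fmt.Errorf("expected exactly one domain, got %d", fs.NArg())
	}
	return options{noCache: *noCache, domain: fs.Arg(0)}, nil
}

func main() {
	opts, err := parseArgs([]string{"--no-cache", "dns.google.com"})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("resolving %s (cache disabled: %v)\n", opts.domain, opts.noCache)
}
```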

Final Thoughts

In conclusion, building a DNS Resolver in Golang was a rewarding experience that deepened my understanding of DNS and low-level programming. I had a ton of fun learning and implementing this in Golang. There are still a few things that can be implemented, so if you want to see the full codebase or contribute, feel free to check out the repo: https://github.com/Harsh-2909/dns-resolver-go

Some articles I found helpful while building this: