18th October 2019

Let's implement zlib.decompress()

Ever wondered how lossless compression in PNG and ZIP files is achieved? In this article, we will attempt to gain some insight into it by writing a decompressor for the zlib format, which PNG uses for compression! (ZIP files store the same DEFLATE-compressed data, just without the zlib wrapper.)

Bit Reader

The zlib/DEFLATE spec is interesting in that it usually deals with streams of bits rather than bytes. To make it easier to implement later algorithms, we shall first implement a BitReader class which will handle the details of extracting individual bits and bytes from a bytestring:

class BitReader:
    def __init__(self, mem):
        self.mem = mem
        self.pos = 0
        self.b = 0
        self.numbits = 0

    def read_byte(self):
        self.numbits = 0 # discard unread bits
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        # shift bit out of byte
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

    def read_bytes(self, n):
        # read bytes as an integer in little-endian
        o = 0
        for i in range(n):
            o |= self.read_byte() << (8 * i)
        return o

For the DEFLATE spec, the rules for extracting bits from a bytestring are as follows:

  • Given a byte, the first bit from DEFLATE's perspective (0th bit) is the byte's least significant bit, while the last bit from DEFLATE's perspective (7th bit) is the byte's most significant bit. e.g. for the byte 0x9d, the bits from first to last are 1, 0, 1, 1, 1, 0, 0, 1.

    This is implemented in read_bit(), which does this by first reading a byte (with read_byte()), shifting a bit out and then returning it. This continues for each call to read_bit() until all 8 bits have been shifted out, at which point it will read the next byte. Let's try it out:

    r = BitReader(b'\x9d')
    print(r.read_bit()) # 1
    print(r.read_bit()) # 0
    print(r.read_bit()) # 1
    print(r.read_bit()) # 1
    print(r.read_bit()) # 1
    print(r.read_bit()) # 0
    print(r.read_bit()) # 0
    print(r.read_bit()) # 1
    print(r.read_bit()) # IndexError: index out of range
  • The DEFLATE spec sometimes requires reading an unsigned integer I made up of n bits from the stream. In this case, the 1st bit that is read from the stream is the least significant bit, while the n-th bit read from the stream is the most significant bit. For example, consider the integer 299, which is 0b100101011 in binary (9 bits). Assuming that all the other bits are zero, this would be encoded as the two bytes b'\x2b\x01'.

    This is implemented in read_bits(self, n), which calls read_bit() n times and places each bit at the correct bit position. Let's try it out:

    r = BitReader(b'\x2b\x01')
    print(r.read_bits(9)) # 299
  • The DEFLATE spec also sometimes requires reading an unsigned integer I made up of n bytes from the stream. In this case, any bits not yet consumed by read_bits()/read_bit() are discarded, and n bytes are read. These n bytes are in little-endian order, that is, least significant byte first. For example, the integer 13926 is encoded as b'\x66\x36'.

    This is implemented in read_bytes(self, n), which calls read_byte() n times and shifts each byte into the correct bit position. Let's try it out:

    r = BitReader(b'\x66\x36')
    print(r.read_bytes(2)) # 13926
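As a cross-check of these three rules, here is a minimal sketch using a hypothetical helper, peek_bits(), which is not part of the article's code and assumes the whole bytestring comfortably fits in one Python integer. Reading the bytes as a single little-endian integer places stream bit k (bit k % 8 of byte k // 8) at integer bit k, which is exactly the order read_bit() and read_bits() produce.

```python
def peek_bits(data: bytes, pos: int, n: int) -> int:
    """Return the n-bit value starting at stream bit `pos`, LSB-first.

    Treating `data` as one little-endian integer puts stream bit k at
    integer bit k, matching DEFLATE's bit numbering.
    """
    value = int.from_bytes(data, 'little')
    return (value >> pos) & ((1 << n) - 1)

# The three examples from the text:
print([peek_bits(b'\x9d', i, 1) for i in range(8)])  # [1, 0, 1, 1, 1, 0, 0, 1]
print(peek_bits(b'\x2b\x01', 0, 9))                  # 299
print(peek_bits(b'\x66\x36', 0, 16))                 # 13926
```

Unlike read_bytes(), this helper never skips to the next byte boundary; it only models the bit numbering.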

The zlib container

Now that we have BitReader, let's start from the top. From the beginning of the file/bytestring, the zlib format is laid out as follows:

  • 1 byte: CMF --- Compression Method and compression info fields
  • 1 byte: FLG --- Compression flags fields
  • Variable number of bytes: The DEFLATE data
  • 4 bytes: ADLER32 -- Adler-32 checksum over the original uncompressed data. For this exercise, we will simply ignore it.
  • End of file/bytestring

As we can see, the zlib format is a very lightweight container format over the DEFLATE format, adding only 6 extra bytes on top.

The CMF fields are as follows:

  • Bits 0 to 3: CM --- Compression Method. This identifies the compression method in the file. Only CM=8 is defined by the zlib spec, which basically says that the zlib file contains data compressed with DEFLATE. Any other value of CM is not supported, and we should error out.
  • Bits 4 to 7: CINFO --- Compression info. This is the base-2 logarithm of the LZ77 window size, minus eight. In other words, the window size is $2^{\text{CINFO}+8}$. The maximum window size allowed by the spec is 32768, which is $2^{15}$. In other words, CINFO must be $\le 7$, and any other value should be treated as an error.
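A small sketch of the CMF checks, using a hypothetical parse_cmf() helper that returns the compression method together with the window size $2^{\text{CINFO}+8}$. It raises ValueError rather than the generic Exception used later in decompress(); that choice is ours.

```python
def parse_cmf(cmf):
    """Split a zlib CMF byte into (CM, window size in bytes)."""
    cm = cmf & 0x0F            # bits 0-3: compression method
    cinfo = (cmf >> 4) & 0x0F  # bits 4-7: log2(window size) - 8
    if cm != 8:
        raise ValueError('only CM=8 (DEFLATE) is defined')
    if cinfo > 7:
        raise ValueError('CINFO > 7 would mean a window larger than 32768')
    return cm, 1 << (cinfo + 8)

print(parse_cmf(0x78))  # (8, 32768)
print(parse_cmf(0x08))  # (8, 256)
```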

The FLG fields are as follows:

  • Bits 0 to 4: FCHECK --- Used as part of the checksum. See below.
  • Bit 5: FDICT --- If set, a DICT dictionary identifier is present immediately after the FLG byte. The dictionary is a sequence of bytes which is known beforehand to both the compressor and decompressor that can be used to achieve greater compression ratios. The zlib spec does not define any preset dictionaries and leaves it up to the implementor. The file formats that we are interested in (e.g. PNG, ZIP files) also do not specify any preset dictionaries. As such, we should not need to handle preset dictionaries at all, and we should error out if the FDICT bit is set.
  • Bits 6 to 7: FLEVEL --- Compression level. This indicates whether the original data was compressed with the fastest/fast/default/max-compression compression level. It's not needed for de-compression at all, and is only there to indicate if recompression might be worthwhile. For our purposes, we can simply ignore it.

In addition, the CMF and FLG bytes also serve as a checksum: $\text{CMF} \times 256 + \text{FLG}$ must be a multiple of 31. The FCHECK bits (bits 0 to 4 of FLG) are chosen so that this holds whenever the CMF and FLG bytes are uncorrupted. This makes sense: with FCHECK set to zero, the remainder of $\text{CMF} \times 256 + \text{FLG}$ modulo 31 is always in the range 0..30, and 5 bits can store values 0..31 inclusive, so 5 bits are sufficient to hold a value that, when added, makes the total a multiple of 31.
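To see the checksum from the compressor's side, here is a sketch of how FCHECK could be chosen. make_flg() is a hypothetical helper; the loop compares it against the headers emitted by Python's zlib module. When the remainder is already 0 it picks FCHECK = 0, although FCHECK = 31 would pass the check too, so another encoder could legitimately differ in that one case.

```python
import zlib

def make_flg(cmf, flevel, fdict=0):
    """Build a FLG byte whose FCHECK bits make CMF*256 + FLG a multiple of 31."""
    flg = (flevel << 6) | (fdict << 5)  # bits 5-7; FCHECK (bits 0-4) still zero
    remainder = (cmf * 256 + flg) % 31  # always 0..30
    fcheck = (31 - remainder) % 31      # 0..30, fits in 5 bits
    return flg | fcheck

# Compare against the headers Python's zlib module actually writes.
for level in range(10):
    cmf, flg = zlib.compress(b'example', level)[:2]
    assert make_flg(cmf, flg >> 6, (flg >> 5) & 1) == flg
print(hex(make_flg(0x78, 2)))  # 0x9c
```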

As such, let's implement our top-level decompress() function:

def decompress(input):
    r = BitReader(input)
    CMF = r.read_byte()
    CM = CMF & 15 # Compression method
    if CM != 8: # only CM=8 is supported
        raise Exception('invalid CM')
    CINFO = (CMF >> 4) & 15 # Compression info
    if CINFO > 7:
        raise Exception('invalid CINFO')
    FLG = r.read_byte()
    if (CMF * 256 + FLG) % 31 != 0:
        raise Exception('CMF+FLG checksum failed')
    FDICT = (FLG >> 5) & 1 # preset dictionary?
    if FDICT:
        raise Exception('preset dictionary not supported')
    out = inflate(r) # decompress DEFLATE data
    ADLER32 = r.read_bytes(4) # Adler-32 checksum (stored MSB-first, unlike read_bytes()'s byte order; for this exercise, we ignore it)
    return out
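The decompress() above deliberately skips the checksum. If you want to verify it, the following sketch computes Adler-32 in plain Python (the algorithm from RFC 1950) and compares it with the stream's trailer. Note that the trailer is stored most-significant byte first, whereas read_bytes() assembles integers least-significant byte first, so the value read in decompress() would need its byte order reversed before comparing. adler32() and check_trailer() are hypothetical helpers; zlib.adler32 is only used as a reference.

```python
import zlib

MOD_ADLER = 65521  # largest prime smaller than 2**16

def adler32(data):
    """Adler-32 checksum of `data`, computed byte by byte as in RFC 1950."""
    a, b = 1, 0
    for byte in data:
        a = (a + byte) % MOD_ADLER  # running sum of the bytes, starting at 1
        b = (b + a) % MOD_ADLER     # running sum of the successive values of a
    return (b << 16) | a

def check_trailer(compressed, decompressed):
    """Compare the last 4 bytes of a zlib stream with our own Adler-32."""
    # The trailer is stored most-significant byte first (big-endian).
    stored = int.from_bytes(compressed[-4:], 'big')
    return stored == adler32(decompressed)

data = b'Hello World!'
print(adler32(data) == zlib.adler32(data))       # True
print(check_trailer(zlib.compress(data), data))  # True
```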

Let's say we have some zlib compressed data like this:

import zlib
x = zlib.compress(b'Hello World!')

We should then be able to decompress it with our implemented decompress() function as follows:

print(decompress(x)) # Supposed to print b'Hello World!'

Of course, it does not presently work because it calls inflate() to handle the DEFLATE data, but we haven't implemented inflate() yet. We shall implement it in the following sections.

DEFLATE Blocks

A DEFLATE stream consists of a series of blocks. Each block begins with 3 header bits containing the following data:

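| Bits        | Name   | Meaning                                              |
|-------------|--------|------------------------------------------------------|
| first bit   | BFINAL | Set if and only if this is the last block of data.  |
| next 2 bits | BTYPE  | Specifies how the data are compressed, listed below. |

BTYPE takes the following values:

  • 00 - no compression
  • 01 - compressed with fixed Huffman codes
  • 10 - compressed with dynamic Huffman codes
  • 11 - reserved (error)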

As such, let's define the main function of our library, inflate(), which will inflate a compressed DEFLATE bitstream.

def inflate(r):
    BFINAL = 0
    out = [] # list of integers 0..255, representing decompressed bytes
    while not BFINAL:
        BFINAL = r.read_bit()
        BTYPE = r.read_bits(2)
        if BTYPE == 0:
            inflate_block_no_compression(r, out)
        elif BTYPE == 1:
            inflate_block_fixed(r, out)
        elif BTYPE == 2:
            inflate_block_dynamic(r, out)
        else:
            raise Exception('invalid BTYPE')
    return bytes(out) # return decompressed bytes as bytestring

We can't run it yet because it refers to missing functions inflate_block_no_compression(), inflate_block_fixed() and inflate_block_dynamic() (you can stub them out for now). In the rest of this article we will fill them in one by one.

Non-compressed blocks

Let's start out with something very simple -- non-compressed blocks. After the BFINAL and BTYPE bits, any bits of input up to the next byte boundary are ignored. Then, the rest of the block consists of:

  • 2 bytes: LEN --- the number of data bytes in the block.
  • 2 bytes: NLEN -- the one's complement of LEN.
  • LEN bytes: The literal data

As such, implementing inflate_block_no_compression() is trivial:

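def inflate_block_no_compression(r, o):
    LEN = r.read_bytes(2)
    NLEN = r.read_bytes(2)
    if LEN != (~NLEN & 0xFFFF): # NLEN must be the one's complement of LEN
        raise Exception('LEN/NLEN mismatch')
    o.extend(r.read_byte() for _ in range(LEN))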

As mentioned when we were implementing BitReader, read_bytes() will discard any unread bits and move to the next byte boundary, which is what we want in this case.

At this point, our decompress() function will work for zlib bytestrings that have no compression! Let's try it out:

import zlib
x = zlib.compress(b'Hello World!', level=0) # level 0 means no compression
print(decompress(x)) # b'Hello World!'
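To exercise inflate_block_no_compression() on inputs of your choosing, including streams with more than one non-compressed block, you can also assemble a zlib stream by hand. The sketch below is a hypothetical helper, make_stored_zlib(); it assumes a fixed header of 0x78 0x01 (CM = 8, 32K window, FLEVEL = 0) and splits the data into blocks of at most 65535 bytes, the largest value LEN can hold. Python's zlib.decompress() serves as the reference decoder.

```python
import zlib

def make_stored_zlib(data):
    """Wrap `data` in a zlib stream built only from non-compressed blocks."""
    out = bytearray(b'\x78\x01')  # CMF: CM=8, CINFO=7; FLG: FLEVEL=0, FCHECK=1
    # LEN is 16 bits wide, so split into chunks of at most 65535 bytes.
    chunks = [data[i:i + 65535] for i in range(0, len(data), 65535)] or [b'']
    for i, chunk in enumerate(chunks):
        # Header byte: BFINAL in bit 0, BTYPE=00 in bits 1-2, then five
        # zero padding bits up to the byte boundary.
        out.append(1 if i == len(chunks) - 1 else 0)
        out += len(chunk).to_bytes(2, 'little')              # LEN
        out += (~len(chunk) & 0xFFFF).to_bytes(2, 'little')  # NLEN
        out += chunk                                         # literal data
    out += zlib.adler32(data).to_bytes(4, 'big')             # ADLER32
    return bytes(out)

stream = make_stored_zlib(b'Hello World!')
print(zlib.decompress(stream))  # b'Hello World!'
```

Since these streams contain nothing but non-compressed blocks, the article's decompress() should accept them as well.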

Now, let's move on to implement inflate_block_fixed() and inflate_block_dynamic() so that we can handle compressed data. To do that, we must first understand how DEFLATE does compression.

Huffman coding

The first half of DEFLATE's compression story is huffman coding. Here's an example to show how huffman coding works:

Suppose we only allow strings with the characters A, B, C and D. Formally, the set $\{A, B, C, D\}$ is called the alphabet. An example of a valid string which meets our conditions is "BBBACD". An example of an invalid string is "BBBACDF", since it contains the letter "F" which is not in our alphabet.

Now, let's say we want to encode the string "BBBACD" into binary. One simple way to do it would be to map the characters of the alphabet to integers. Since there are 4 characters in the alphabet, we will map them to the range 0..3 inclusive. The maximum integer is 3, which requires 2 bits (0b11 in binary) to represent, and so we encode these integers with 2 bits each.

We map A to 0 (binary: 0b00), B to 1 (binary: 0b01), C to 2 (binary: 0b10) and D to 3 (binary: 0b11). The string "BBBACD" encoded in our simple encoding is thus 0b010101001011, which is $6 \times 2 = 12$ bits in total.

How could we compress the string "BBBACD" by encoding it with fewer bits? One way is to use a variable-length encoding which gives shorter codes to more frequent characters.

A huffman code is a variable-length code table for encoding symbols of an alphabet into strings of bits. An example huffman code table for encoding the alphabet $\{A, B, C, D\}$ is:

| Symbol | Codeword |
|--------|----------|
| A      | 0b01     |
| B      | 0b1      |
| C      | 0b000    |
| D      | 0b001    |

One important constraint of a huffman code is that no codeword can be a prefix of another codeword. For example, we can't assign B the codeword 0b0 because 0b0 is a prefix of A's codeword (0b01). Likewise, we can't assign C the codeword 0b100 because B's codeword, 0b1, is a prefix of 0b100. This constraint is needed to prevent ambiguity during decoding. Suppose B had the codeword 0b0. Then decoding 0b001 would be ambiguous: should we decode it to "BA" or to "D"?
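The prefix constraint is easy to check mechanically. A minimal sketch, assuming codewords are written as strings of '0'/'1' characters rather than the (integer, length) pairs used elsewhere in the article; is_prefix_free() is a hypothetical helper. After sorting the codewords lexicographically, a codeword that is a prefix of another always sorts immediately before one of its extensions, so comparing neighbours is enough.

```python
def is_prefix_free(table):
    """Return True if no codeword in `table` (symbol -> bit string) prefixes another."""
    codes = sorted(table.values())
    # A prefix sorts directly before at least one string that extends it,
    # so only adjacent pairs need to be compared.
    return all(not b.startswith(a) for a, b in zip(codes, codes[1:]))

print(is_prefix_free({'A': '01', 'B': '1', 'C': '000', 'D': '001'}))  # True
print(is_prefix_free({'A': '01', 'B': '0', 'C': '000', 'D': '001'}))  # False
```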

Now, we can see from the table that B can be encoded with 1 bit, A with 2 bits, and C and D with 3 bits. As such, the string "BBBACD" will be encoded as 0b11101000001, which is 11 bits in total. This is shorter than the simple encoding that we used earlier, and so compression has been achieved.
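Encoding is then just concatenation of codewords. A small sketch, again writing codewords as '0'/'1' strings (huffman_encode() is a hypothetical helper), which reproduces the 11-bit result above and the 12-bit result of the simple 2-bit encoding:

```python
CODE_TABLE = {'A': '01', 'B': '1', 'C': '000', 'D': '001'}    # huffman code above
TWO_BIT_TABLE = {'A': '00', 'B': '01', 'C': '10', 'D': '11'}  # simple encoding

def huffman_encode(text, table):
    """Concatenate the codeword of every symbol into one '0'/'1' string."""
    return ''.join(table[ch] for ch in text)

bits = huffman_encode('BBBACD', CODE_TABLE)
print(bits, len(bits))                               # 11101000001 11
print(len(huffman_encode('BBBACD', TWO_BIT_TABLE)))  # 12
```

int(bits, 2) turns the string back into the integer 0b11101000001 used in the text.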

How would we encode 0b11101000001 into bytes? The DEFLATE spec specifies that huffman-coded bits should be encoded in such a way that when reading it back with BitReader, the first bit that read_bit() returns is the most significant bit (0b1), the next bit that read_bit() returns is the second-most significant bit (0b1) etc. To help us out later, let's implement code_to_bytes(code, n) which will help us do just that:

def code_to_bytes(code, n):
    # Encodes a code that is `n` bits long into bytes that is conformant with DEFLATE spec
    out = [0]
    numbits = 0
    for i in range(n-1, -1, -1):
        if numbits >= 8:
            out.append(0)
            numbits = 0
        out[-1] |= (1 if code & (1 << i) else 0) << numbits
        numbits += 1
    return bytes(out)

Testing it out:

r = BitReader(code_to_bytes(0b11101000001, 11))
print([r.read_bit() for _ in range(11)]) # [1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1]

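Next, given a huffman code table like the one above, how could we write a program that can decode any arbitrary valid bitstream? One way would be to first convert the table into a huffman tree. A huffman tree is a binary tree whose leaves are the symbols of the alphabet. For a given internal node, the left edge represents the bit 0 while the right edge represents the bit 1. In the equivalent huffman tree for the table above, the root's right child (bit 1) is the leaf B, while its left child (bit 0) is an internal node. That node's right child (bits 01) is the leaf A, and its left child (bits 00) is another internal node whose left and right children are the leaves C (000) and D (001).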

From the huffman tree, we can derive the codeword of a symbol by starting from the root and walking down the edges to that symbol's leaf, noting down whether we took the left or right edge at each step. For example, to get to D, we start from the root node, take the left edge (0), then the left edge (0), and finally the right edge (1), which gives us the code 0b001, D's codeword in the table above!

Likewise, to decode a symbol from a bitstream, we can:

  1. Start from the root node.
  2. Read a bit from the bitstream. If it is a 0, take the left edge, otherwise (if it is a 1), take the right edge.
  3. Repeat Step 2 until we reach a leaf node. That is the decoded symbol.

As we can see, having a huffman tree makes it easy to write an algorithm that decodes symbols from a bitstream, taking one step down the tree per bit read.

Now let's implement this. We start with the implementation of a HuffmanTree that has an insert() operation, which allows us to insert codeword -> symbol mappings into the tree (assuming that they obey the "no codeword can be a prefix of another codeword" constraint).

class Node:
    def __init__(self):
        self.symbol = ''
        self.left = None
        self.right = None

class HuffmanTree:
    def __init__(self):
        self.root = Node()

    def insert(self, codeword, n, symbol):
        # Insert an entry into the tree mapping `codeword` of len `n` to `symbol`
        node = self.root
        for i in range(n-1, -1, -1):
            b = codeword & (1 << i)
            if b:
                next_node = node.right
                if next_node is None:
                    node.right = Node()
                    next_node = node.right
            else:
                next_node = node.left
                if next_node is None:
                    node.left = Node()
                    next_node = node.left
            node = next_node
        node.symbol = symbol

We can now construct our example huffman tree as follows:

t = HuffmanTree()
t.insert(0b01, 2, 'A')
t.insert(0b1, 1, 'B')
t.insert(0b000, 3, 'C')
t.insert(0b001, 3, 'D')

Then, we can write a decode_symbol() function that decodes a symbol from a bitstream using a specified huffman tree. This is implemented using the algorithm described above:

def decode_symbol(r, t):
    "Decodes one symbol from bitstream `r` using HuffmanTree `t`"
    node = t.root
    while node.left or node.right:
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol

Now, let's try decoding 0b11101000001 back into "BBBACD":

r = BitReader(code_to_bytes(0b11101000001, 11))
print(decode_symbol(r, t)) # 'B'
print(decode_symbol(r, t)) # 'B'
print(decode_symbol(r, t)) # 'B'
print(decode_symbol(r, t)) # 'A'
print(decode_symbol(r, t)) # 'C'
print(decode_symbol(r, t)) # 'D'
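A tree is not the only way to decode a prefix code. As an alternative sketch (not the approach this article builds on), decode_bits() below accumulates bits into a candidate codeword and looks it up in a reversed table; because the code is prefix-free, the first match is always the right one. It assumes the bitstream is given as a '0'/'1' string.

```python
def decode_bits(bits, table):
    """Decode a '0'/'1' string using a prefix-code table (symbol -> codeword)."""
    lookup = {code: sym for sym, code in table.items()}  # codeword -> symbol
    out, current = [], ''
    for bit in bits:
        current += bit
        if current in lookup:        # a complete codeword has been read
            out.append(lookup[current])
            current = ''
    if current:
        raise ValueError('bitstream ended in the middle of a codeword')
    return ''.join(out)

print(decode_bits('11101000001', {'A': '01', 'B': '1', 'C': '000', 'D': '001'}))  # BBBACD
```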

LZ77 and the Literal/Length and Distance Alphabet

One might imagine that the DEFLATE spec simply uses Huffman Coding as-is, but with an alphabet that consists of the values 0..255 inclusive, which are the values that a byte can take. That way, bytes which occur frequently in a file will be given shorter codewords, thus compressing the file. This is only half-correct --- the DEFLATE spec has an additional trick up its sleeve which achieves better compression ratios on files that contain repeated data: LZ77 compression.

Suppose we have a file that consists of the pattern "BBA" repeated over and over again, i.e. "BBABBABBABBA...". Even if we encode B and A with one bit each, the size of the compressed file will still grow linearly with the size of the original file.

We can compress this file even further by employing LZ77 compression. With LZ77 compression, we first write the literal values "BBA" into the output compressed file. Next, rather than writing "BBA" again, we write a <length, backward distance> marker which specifies that the (decompressed) string of the given length, starting the given backward distance before the current position, should be duplicated. For example, the string "BBABBA" (length 6) can be encoded as "BBA <3,3>", the string "BBABBABBA" (length 9) can be encoded as "BBA <6,3>", and the string "BBABBABBABBABBABBA" (length 18) can be encoded as "BBA <15,3>"! Note that the length may exceed the distance: the copy is performed one byte at a time, so it can re-read bytes it has just written. This means that a long run of repeated data costs a single short marker rather than space proportional to the run (DEFLATE caps one marker at a length of 258, so very long runs need several markers).
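To make the overlapping copy concrete, here is a sketch of expanding a token list. lz77_expand() and its token format (single characters for literals, (length, distance) tuples for markers) are assumptions for illustration only. Copying one character at a time is what lets <15,3> follow just three literals.

```python
def lz77_expand(tokens):
    """Expand literals and (length, distance) markers into a string."""
    out = []
    for tok in tokens:
        if isinstance(tok, tuple):
            length, dist = tok
            # Copy one character at a time: later iterations may read
            # characters produced by earlier iterations of this same copy.
            for _ in range(length):
                out.append(out[-dist])
        else:
            out.append(tok)  # literal character
    return ''.join(out)

print(lz77_expand(['B', 'B', 'A', (3, 3)]))   # BBABBA
print(lz77_expand(['B', 'B', 'A', (15, 3)]))  # BBABBABBABBABBABBA
```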

Now, using LZ77 compression as well means that a DEFLATE stream would contain both literal byte values and <length, backward distance> pairs. So, how does DEFLATE encode them and distinguish between them at de-compression time? It does this by defining two alphabets --- the Literal/Length alphabet and the Distance alphabet:

  • The Literal/Length alphabet consists of the values 0..285 inclusive. Values 0..255 are the corresponding literal byte values. Value 256 is the end-of-block marker. Values 257..285 are used to encode the "length" portion of <length, backward distance> pairs, although it does not encode the length literally. See below.
  • The Distance alphabet consists of the values 0..29 inclusive. They are used to encode the "backward distance" portion of <length, backward distance> pairs, although they do not encode them literally. See below.

These two alphabets each have their own huffman trees, which we shall call the Literal/Length tree and the Distance tree.

Values 257..285 of the Literal/Length alphabet are used to represent length values, sometimes in conjunction with extra bits that come after the symbol:

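| Symbol | Extra bits | Length(s) |
|--------|------------|-----------|
| 257    | 0          | 3         |
| 258    | 0          | 4         |
| 259    | 0          | 5         |
| 260    | 0          | 6         |
| 261    | 0          | 7         |
| 262    | 0          | 8         |
| 263    | 0          | 9         |
| 264    | 0          | 10        |
| 265    | 1          | 11,12     |
| 266    | 1          | 13,14     |
| 267    | 1          | 15,16     |
| 268    | 1          | 17,18     |
| 269    | 2          | 19-22     |
| 270    | 2          | 23-26     |
| 271    | 2          | 27-30     |
| 272    | 2          | 31-34     |
| 273    | 3          | 35-42     |
| 274    | 3          | 43-50     |
| 275    | 3          | 51-58     |
| 276    | 3          | 59-66     |
| 277    | 4          | 67-82     |
| 278    | 4          | 83-98     |
| 279    | 4          | 99-114    |
| 280    | 4          | 115-130   |
| 281    | 5          | 131-162   |
| 282    | 5          | 163-194   |
| 283    | 5          | 195-226   |
| 284    | 5          | 227-257   |
| 285    | 0          | 258       |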

Example 1: Let's say we call decode_symbol() with the literal/length huffman tree and it returns 42. Since it is within the range 0..255, it is a literal byte and so we should write byte 42 to the output stream.

Example 2: Let's say we call decode_symbol() with the literal/length huffman tree and it returns 256. Since 256 is the end-of-block marker, we should terminate parsing of the block.

Example 3(a): Let's say we call decode_symbol() with the literal/length huffman tree and it returns 257. According to the table, symbol 257 has 0 extra bits. Therefore we know immediately that we have reached a <length, backward distance> pair with length 3. We then expect the backward distance of the pair to follow and we should next call decode_symbol() with the distance tree to decode the backward distance. (See Example 3(b) for continuation)

Example 4(a): Let's say we call decode_symbol() with the literal/length huffman tree and it returns 275. According to the table, symbol 275 has 3 extra bits. We thus call read_bits(3) to read those 3 extra bits. Let's say that read_bits(3) returns 2. Therefore, we know that we have reached a <length, backward distance> pair with length 51 + 2 = 53. We then expect the backward distance of the pair to follow and we should next call decode_symbol() with the distance tree to decode the backward distance. (See Example 4(b) for continuation)

Similarly to the literal/length alphabet, the distance alphabet is used to represent backward distance values, sometimes in conjunction with extra bits that come after the symbol:

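| Symbol | Extra bits | Distance(s) |
|--------|------------|-------------|
| 0      | 0          | 1           |
| 1      | 0          | 2           |
| 2      | 0          | 3           |
| 3      | 0          | 4           |
| 4      | 1          | 5,6         |
| 5      | 1          | 7,8         |
| 6      | 2          | 9-12        |
| 7      | 2          | 13-16       |
| 8      | 3          | 17-24       |
| 9      | 3          | 25-32       |
| 10     | 4          | 33-48       |
| 11     | 4          | 49-64       |
| 12     | 5          | 65-96       |
| 13     | 5          | 97-128      |
| 14     | 6          | 129-192     |
| 15     | 6          | 193-256     |
| 16     | 7          | 257-384     |
| 17     | 7          | 385-512     |
| 18     | 8          | 513-768     |
| 19     | 8          | 769-1024    |
| 20     | 9          | 1025-1536   |
| 21     | 9          | 1537-2048   |
| 22     | 10         | 2049-3072   |
| 23     | 10         | 3073-4096   |
| 24     | 11         | 4097-6144   |
| 25     | 11         | 6145-8192   |
| 26     | 12         | 8193-12288  |
| 27     | 12         | 12289-16384 |
| 28     | 13         | 16385-24576 |
| 29     | 13         | 24577-32768 |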

Example 3(b): After decoding the length of the <length, backward distance> pair in Example 3(a) and getting a length of 3, we then call decode_symbol() with the distance tree to decode the backward distance. Let's say decode_symbol() returns 2. According to the table, symbol 2 has 0 extra bits. Therefore we know immediately that we have a backward distance of 3, and so our <length, backward distance> pair is <3, 3>. As such, we should copy 3 bytes, starting 3 bytes back from the current end of the output stream, onto the end of the output stream.

Example 4(b): After decoding the length of the <length, backward distance> pair in Example 4(a) and getting a length of 53, we then call decode_symbol() with the distance tree to decode the backward distance. Let's say decode_symbol() returns 12. According to the table, symbol 12 has 5 extra bits. We thus call read_bits(5) to read those extra bits. Let's say read_bits(5) returns 30. Therefore, the backward distance is 65 + 30 = 95. As such, we should copy 53 bytes, starting 95 bytes back from the current end of the output stream, onto the end of the output stream.

Now let's implement it.

Firstly, we have the tables:

LengthExtraBits = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0]
LengthBase = [3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258]
DistanceExtraBits = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13]
DistanceBase = [1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577]

LengthExtraBits is the "Extra bits" column of the Literal/Length table, and LengthBase is the corresponding base integer of the "Length(s)" column of the Literal/Length table. Likewise, DistanceExtraBits is the "Extra bits" column of the Distance table, and DistanceBase is the corresponding base integer of the "Distance(s)" column of the Distance table.
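The four tables follow a regular pattern, which makes them easy to double-check. In the sketch below (the variable names and the bases() helper are our own), the extra-bit counts are generated by formula, and each base value is the previous base plus the number of values the previous symbol can express, $2^{\text{extra}}$. Symbol 285 is the exception: it encodes exactly 258 with no extra bits, so its base is appended by hand. (Symbol 284 with all five extra bits set would also reach 258, but the table above caps it at 257.)

```python
# Length symbols 257..284 are indices 0..27; symbol 285 (index 28) has no extra bits.
LENGTH_EXTRA = [max(0, i // 4 - 1) for i in range(28)] + [0]
DISTANCE_EXTRA = [max(0, i // 2 - 1) for i in range(30)]

def bases(first, extra):
    """Each base is the previous base plus 2**(previous symbol's extra bits)."""
    out = [first]
    for e in extra[:-1]:
        out.append(out[-1] + (1 << e))
    return out

LENGTH_BASE = bases(3, LENGTH_EXTRA[:28]) + [258]  # symbol 285 is special-cased
DISTANCE_BASE = bases(1, DISTANCE_EXTRA)

# Worked examples 4(a) and 4(b): length symbol 275 + 2, distance symbol 12 + 30.
print(LENGTH_BASE[275 - 257] + 2)  # 53
print(DISTANCE_BASE[12] + 30)      # 95
```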

And next, we have the decode function:

def inflate_block_data(r, literal_length_tree, distance_tree, out):
    while True:
        sym = decode_symbol(r, literal_length_tree)
        if sym <= 255: # Literal byte
            out.append(sym)
        elif sym == 256: # End of block
            return
        else: # <length, backward distance> pair
            sym -= 257
            length = r.read_bits(LengthExtraBits[sym]) + LengthBase[sym]
            dist_sym = decode_symbol(r, distance_tree)
            dist = r.read_bits(DistanceExtraBits[dist_sym]) + DistanceBase[dist_sym]
            for _ in range(length):
                out.append(out[-dist])

Encoding the literal/length and distance huffman trees

So, given the literal/length and distance huffman trees, we can de-compress (or inflate) a block of data using our newly-implemented inflate_block_data() function. However, how are these trees even stored in the first place?

One way would be to just store the codeword -> symbol mappings as-is. However, the DEFLATE spec uses a different method which allows huffman trees to be stored more compactly in the file by placing additional restrictions on them. In particular, the DEFLATE spec requires the huffman trees to be canonical huffman trees.

To recap, a huffman code is required to be a prefix code -- no codeword can be a prefix of another codeword.

On top of that, canonical huffman codes add the following requirements:

  1. The first symbol whose codeword has bit length 1 must have the codeword of 0b0.
  2. All codes of a given bit length have lexicographically consecutive values, in the same order as the symbols they represent.
  3. Shorter codes lexicographically precede longer codes.

For example, near the beginning of this article we had this example huffman code defined over the alphabet $\{A, B, C, D\}$:

| Symbol | Codeword |
|--------|----------|
| A      | 0b01     |
| B      | 0b1      |
| C      | 0b000    |
| D      | 0b001    |

This is not a valid canonical huffman code, and thus cannot be encoded in the DEFLATE file. This is because:

  • The first symbol whose codeword has bit length 1, B, does not have a codeword of 0b0, violating rule 1.
  • B has a shorter code compared to A, but it is lexicographically greater compared to A, violating rule 3.

The canonical huffman code with the same code lengths (B: 1 bit, A: 2 bits, C and D: 3 bits each), which meets all the conditions of being a canonical huffman code, is:

| Symbol | Codeword |
|--------|----------|
| A      | 0b10     |
| B      | 0b0      |
| C      | 0b110    |
| D      | 0b111    |

Why require the huffman code to be canonical? Because any canonical huffman code can be represented uniquely and compactly by (1) a list of the symbols of its alphabet (in the alphabet's order), and (2) the bit length of the codeword for each symbol of the alphabet. Furthermore, if the alphabet is already known by the decoder, then the alphabet list is not required at all, and only the bit length list (which we shall call bl) needs to be provided.

For example, let's try to derive the canonical huffman code in the table above with just its alphabet and its bit length list bl.

To start off, the corresponding alphabet and bit length list bl is:

alphabet = 'ABCD'
bl = [2, 1, 3, 3]

We also compute MAX_BITS, which is the size of the longest codeword:

MAX_BITS = max(bl)

To derive the huffman tree, we first compute the number of codes for each code length.

bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS+1)]
print(bl_count) # [0, 1, 1, 2]

Next, we compute next_code such that next_code[n] is the smallest codeword with code length n.

next_code = [0, 0]
for bits in range(2, MAX_BITS+1):
    next_code.append((next_code[bits-1] + bl_count[bits-1]) << 1)
print(next_code) # [0, 0, 2, 6]

Sanity check: next_code[1] is 0 (binary: 0b0) which is correct since B's codeword is 0b0. next_code[2] is 2 (binary: 0b10) which is correct since A's codeword is 0b10. next_code[3] is 6 (binary: 0b110) which is correct since C's codeword is 0b110. So everything seems fine.

Finally, we can compute the code for each symbol of our alphabet, and enter it into our huffman tree, taking care to ignore symbols of our alphabet whose codeword has a bit length of 0 (meaning the symbol is not used in the compressed data):

t = HuffmanTree()
for c, bitlen in zip(alphabet, bl):
    if bitlen != 0: # ignore if bit length is 0
        t.insert(next_code[bitlen], bitlen, c)
        next_code[bitlen] += 1

Now, let's try decoding 0b00010110111, which is "BBBACD" in our canonical code:

r = BitReader(code_to_bytes(0b00010110111, 11))
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # A
print(decode_symbol(r, t)) # C
print(decode_symbol(r, t)) # D

So everything seems correct!

What happens if the alphabet is as usual, but C and D are not used at all? In that case, the DEFLATE spec specifies that C and D should be specified as having bit length zero. In other words,

alphabet = 'ABCD'
bl = [2, 1, 0, 0]

Furthermore, the DEFLATE spec allows trailing zeroes in the bl list to be omitted. In other words, this is also valid:

alphabet = 'ABCD'
bl = [2, 1]

To wrap up, here is a reusable bl_list_to_tree() function which constructs a huffman tree from a bit length list, implemented using all the algorithms we have covered in this section:

def bl_list_to_tree(bl, alphabet):
    MAX_BITS = max(bl)
    bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS+1)]
    next_code = [0, 0]
    for bits in range(2, MAX_BITS+1):
        next_code.append((next_code[bits-1] + bl_count[bits-1]) << 1)
    t = HuffmanTree()
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            next_code[bitlen] += 1
    return t

Usage:

t = bl_list_to_tree([2, 1, 3, 3], 'ABCD')
r = BitReader(code_to_bytes(0b00010110111, 11))
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # A
print(decode_symbol(r, t)) # C
print(decode_symbol(r, t)) # D
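If you would rather inspect the codewords than a tree, the same next_code procedure can emit a {symbol: codeword} table instead. A sketch, where canonical_codes() is a hypothetical helper returning codewords as '0'/'1' strings; it reproduces the canonical table from earlier, including the case with trailing zeroes omitted:

```python
def canonical_codes(bl, alphabet):
    """Return {symbol: codeword} ('0'/'1' strings) for the canonical code with lengths `bl`."""
    max_bits = max(bl, default=0)
    # Number of codes of each length (length 0 means "unused" and is not counted).
    bl_count = [0] * (max_bits + 1)
    for length in bl:
        if length:
            bl_count[length] += 1
    # Smallest codeword of each length, as in the article's next_code list.
    next_code = [0] * (max_bits + 1)
    code = 0
    for bits in range(1, max_bits + 1):
        code = (code + bl_count[bits - 1]) << 1
        next_code[bits] = code
    table = {}
    for sym, length in zip(alphabet, bl):
        if length:
            table[sym] = f'{next_code[length]:0{length}b}'
            next_code[length] += 1
    return table

print(canonical_codes([2, 1, 3, 3], 'ABCD'))  # {'A': '10', 'B': '0', 'C': '110', 'D': '111'}
print(canonical_codes([2, 1], 'ABCD'))        # {'A': '10', 'B': '0'}
```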

However, the DEFLATE spec adds yet another twist. Rather than just specifying the code lengths directly, for "even greater compactness", the code length sequences themselves are compressed using a Huffman code! The alphabet for the code lengths is as follows:

| Symbol | Meaning |
|--------|---------|
| 0 - 15 | Represent code lengths of 0 - 15 |
| 16     | Copy the previous code length 3 - 6 times. The next 2 bits indicate repeat length (0 = 3, ..., 3 = 6) |
| 17     | Repeat a code length of 0 for 3 - 10 times. (3 bits of length) |
| 18     | Repeat a code length of 0 for 11 - 138 times. (7 bits of length) |
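The repeat symbols are easiest to understand on already-decoded input. In the sketch below, expand_code_lengths() is a hypothetical helper that takes a list of (symbol, extra_value) pairs, where extra_value is the integer that would be read from the symbol's extra bits (ignored for symbols 0..15). It performs the same expansion a decoder must do after huffman-decoding each symbol.

```python
def expand_code_lengths(tokens):
    """Expand (symbol, extra_value) pairs from the code length alphabet into a bl list."""
    bl = []
    for sym, extra in tokens:
        if sym <= 15:
            bl.append(sym)                     # literal code length
        elif sym == 16:
            if not bl:
                raise ValueError('symbol 16 with no previous code length')
            bl.extend([bl[-1]] * (3 + extra))  # 2 extra bits: repeat 3..6 times
        elif sym == 17:
            bl.extend([0] * (3 + extra))       # 3 extra bits: 3..10 zeros
        elif sym == 18:
            bl.extend([0] * (11 + extra))      # 7 extra bits: 11..138 zeros
        else:
            raise ValueError(f'invalid code length symbol {sym}')
    return bl

bl = expand_code_lengths([(8, 0), (16, 1), (17, 0), (18, 5)])
print(bl[:6], len(bl))  # [8, 8, 8, 8, 8, 0] 24
```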

To make this clear: we actually now have three huffman codes in play here --- the code length huffman code, the literal/length huffman code and the distance huffman code. For clarity let's give them mathematical symbols as follows:

  • Let $\Sigma_{CL}$ be the code length alphabet (the table directly above).
  • Let $T_{CL}$ be the code length huffman tree.
  • Let $BL_{CL}$ be the code length list of $T_{CL}$.
  • Let $T_{LL}$ be the literal/length huffman tree.
  • Let $BL_{LL}$ be the code length list of $T_{LL}$.
  • Let $T_{D}$ be the distance huffman tree.
  • Let $BL_{D}$ be the code length list of $T_{D}$.
  • Let $C$ be $BL_{LL}$ followed by $BL_{D}$, encoded using the code length alphabet $\Sigma_{CL}$ and then compressed using the code length huffman tree $T_{CL}$.

The code length, literal/length and distance huffman trees are stored as bits in the DEFLATE file as follows:

  • 5 bits: HLIT, the number of literal/length codes ($|BL_{LL}|$) minus 257. In other words, $|BL_{LL}| = \text{HLIT} + 257$.
  • 5 bits: HDIST, the number of distance codes ($|BL_{D}|$) minus 1. In other words, $|BL_{D}| = \text{HDIST} + 1$.
  • 4 bits: HCLEN, the number of code length codes stored in the next field, minus 4. In other words, $\text{HCLEN} + 4$ of the 19 entries of $BL_{CL}$ are stored.
  • (HCLEN + 4) x 3 bits: code lengths for the code length alphabet given just above, in the order: 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15. In other words, these are entries of $BL_{CL}$, but in a different order; we will need to rearrange them back. Code length symbols beyond the first HCLEN + 4 in this order have a code length of 0.
  • Variable number of bits follows: $C$, that is, the code length list for the literal/length alphabet ($BL_{LL}$) followed by the code length list for the distance alphabet ($BL_{D}$), encoded using the code length alphabet $\Sigma_{CL}$ and the code length huffman tree $T_{CL}$.
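The unusual transmission order appears to be chosen so that code length symbols which are often unused end up at the back, where HCLEN can cut them off. The sketch below shows both directions with hypothetical helpers. The encoder half, which trims trailing zeros down to the minimum of 4 entries, is an assumption about what a compressor might do; the decoder half performs the rearrangement a decoder needs, treating missing entries as length 0.

```python
CODE_LENGTH_ORDER = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]

def pack_code_length_lengths(bl_cl):
    """Encoder side (hypothetical): 19 lengths -> (HCLEN, lengths in transmit order)."""
    ordered = [bl_cl[sym] for sym in CODE_LENGTH_ORDER]
    n = len(ordered)
    # Trailing zeros can be dropped, but at least 4 entries are always sent.
    while n > 4 and ordered[n - 1] == 0:
        n -= 1
    return n - 4, ordered[:n]

def unpack_code_length_lengths(hclen, transmitted):
    """Decoder side: undo the permutation; entries that were not sent stay 0."""
    bl_cl = [0] * 19
    for sym, length in zip(CODE_LENGTH_ORDER, transmitted[:hclen + 4]):
        bl_cl[sym] = length
    return bl_cl

bl_cl = [0] * 19
bl_cl[0], bl_cl[8], bl_cl[18] = 2, 2, 1
hclen, sent = pack_code_length_lengths(bl_cl)
print(hclen, sent)                                       # 1 [0, 0, 1, 2, 2]
print(unpack_code_length_lengths(hclen, sent) == bl_cl)  # True
```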

To wrap everything up, let's implement a function decode_trees() which will read the relevant bits from the input BitReader stream and decode them into the literal/length and distance huffman trees:

CodeLengthCodesOrder = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]

def decode_trees(r):
    # The number of literal/length codes
    HLIT = r.read_bits(5) + 257

    # The number of distance codes
    HDIST = r.read_bits(5) + 1

    # The number of code length codes
    HCLEN = r.read_bits(4) + 4

    # Read code lengths for the code length alphabet
    code_length_tree_bl = [0 for _ in range(19)]
    for i in range(HCLEN):
        code_length_tree_bl[CodeLengthCodesOrder[i]] = r.read_bits(3)

    # Construct code length tree
    code_length_tree = bl_list_to_tree(code_length_tree_bl, range(19))

    # Read literal/length + distance code length list
    bl = []
    while len(bl) < HLIT + HDIST:
        sym = decode_symbol(r, code_length_tree)
        if 0 <= sym <= 15: # literal value
            bl.append(sym)
        elif sym == 16:
            # copy the previous code length 3..6 times.
            # the next 2 bits indicate repeat length ( 0 = 3, ..., 3 = 6 )
            prev_code_length = bl[-1]
            repeat_length = r.read_bits(2) + 3
            bl.extend(prev_code_length for _ in range(repeat_length))
        elif sym == 17:
            # repeat code length 0 for 3..10 times. (3 bits of length)
            repeat_length = r.read_bits(3) + 3
            bl.extend(0 for _ in range(repeat_length))
        elif sym == 18:
            # repeat code length 0 for 11..138 times. (7 bits of length)
            repeat_length = r.read_bits(7) + 11
            bl.extend(0 for _ in range(repeat_length))
        else:
            raise Exception('invalid symbol')

    # Construct trees
    literal_length_tree = bl_list_to_tree(bl[:HLIT], range(286))
    distance_tree = bl_list_to_tree(bl[HLIT:], range(30))
    return literal_length_tree, distance_tree
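The run-length scheme behind symbols 16, 17 and 18 can be separated from the bit reading. The sketch below assumes a hypothetical pre-decoded input: a list of (symbol, extra) pairs, where extra is the value of the extra bits already read. It also illustrates a rule from RFC 1951, section 3.2.7: repeat codes may cross from the literal/length lengths into the distance lengths, which is why decode_trees() decodes all HLIT + HDIST lengths into one list before splitting it. Unlike decode_trees(), the sketch also rejects a sequence that overshoots the expected total.

```python
# Hypothetical helper for illustration; decode_trees() does this inline.
def expand_code_lengths(tokens, total):
    """Expand code length alphabet symbols (0-18) into a flat bit length list.

    `tokens` is a list of (symbol, extra) pairs, where `extra` is the value of
    the extra bits read after symbol 16, 17 or 18 (ignored for 0-15).
    `total` is HLIT + HDIST, i.e. the number of lengths expected.
    """
    bl = []
    for sym, extra in tokens:
        if 0 <= sym <= 15:
            bl.append(sym)                     # a literal code length
        elif sym == 16:
            if not bl:
                raise ValueError('symbol 16 with no previous code length')
            bl.extend([bl[-1]] * (3 + extra))  # 2 extra bits: 3..6 copies
        elif sym == 17:
            bl.extend([0] * (3 + extra))       # 3 extra bits: 3..10 zeros
        elif sym == 18:
            bl.extend([0] * (11 + extra))      # 7 extra bits: 11..138 zeros
        else:
            raise ValueError(f'invalid code length symbol {sym}')
    if len(bl) != total:
        raise ValueError(f'expected {total} code lengths, got {len(bl)}')
    return bl


# HLIT = 257 and HDIST = 2 (as counts), so 259 lengths in total.
# 138 + 117 zeros cover literals 0-254, then literal 255 gets length 1,
# and the (16, 0) repeat copies that 1 three times: once for symbol 256
# and twice into the distance lengths, crossing the boundary.
tokens = [(18, 127), (18, 106), (1, 0), (16, 0)]
bl = expand_code_lengths(tokens, 257 + 2)
ll_bl, dist_bl = bl[:257], bl[257:]
print(ll_bl[255:], dist_bl)  # [1, 1] [1, 1]
```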

Putting everything together: handling compressed blocks

We are now ready to write the implementations of inflate_block_dynamic() (decompress block compressed with dynamic huffman codes) and inflate_block_fixed() (decompress block with fixed huffman codes)!

With the decode_trees() and inflate_block_data() functions we have painstakingly implemented over multiple sections, it is now trivial to implement inflate_block_dynamic(): we simply first read in and decode the literal/length and distance huffman trees with decode_trees(), and then call inflate_block_data() to inflate the compressed data that follows:

def inflate_block_dynamic(r, o):
    # decompress block with dynamic huffman codes
    literal_length_tree, distance_tree = decode_trees(r)
    inflate_block_data(r, literal_length_tree, distance_tree, o)

For blocks handled by inflate_block_fixed(), the literal/length and distance huffman codes are not stored in the compressed file; they are fixed by the DEFLATE spec and known to the decompressor beforehand.

The DEFLATE spec specifies the literal/length huffman code in terms of bitlengths, which we can then use to construct a canonical huffman tree with bl_list_to_tree() that we implemented earlier:

Literal/length symbol    Bit length of codeword
0 - 143                  8
144 - 255                9
256 - 279                7
280 - 287                8

Eagle-eyed readers will notice that literal/length symbols 286-287 are mentioned in the table but are not actually part of the literal/length alphabet. They never occur in compressed data, but they do participate in the code construction in bl_list_to_tree(), particularly in the computation of bl_count.

Likewise, the DEFLATE spec specifies the distance huffman code in terms of bitlengths:

Distance symbol    Bit length of codeword
0 - 29             5
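To check that the canonical construction reproduces the fixed codes, the hypothetical canonical_codes() below reimplements the same algorithm as bl_list_to_tree() (RFC 1951, section 3.2.2) but returns a plain dict of symbol to (codeword, bit length) instead of a tree. Its results can be compared against the codewords RFC 1951 lists for the fixed code, such as 00110000 for literal 0 and 110010000 for literal 144. It also shows why symbols 286-287 matter: leaving them out of bl_count shifts every 9-bit codeword.

```python
# Hypothetical helper for illustration; mirrors bl_list_to_tree() without a tree.
def canonical_codes(bl):
    """Assign canonical huffman codewords from a bit length list (RFC 1951, 3.2.2).

    Returns {symbol: (codeword, bit_length)} for every symbol with a nonzero length.
    """
    max_bits = max(bl)
    # Step 1: count how many codes there are of each bit length (0 = unused)
    bl_count = [0] * (max_bits + 1)
    for length in bl:
        if length:
            bl_count[length] += 1
    # Step 2: smallest codeword for each bit length
    next_code = [0] * (max_bits + 1)
    code = 0
    for bits in range(1, max_bits + 1):
        code = (code + bl_count[bits - 1]) << 1
        next_code[bits] = code
    # Step 3: hand out consecutive codewords in symbol order
    codes = {}
    for symbol, length in enumerate(bl):
        if length:
            codes[symbol] = (next_code[length], length)
            next_code[length] += 1
    return codes


# Fixed literal/length bit lengths for all 288 symbols, including unused 286-287
FIXED_LL_BL = [8] * 144 + [9] * 112 + [7] * 24 + [8] * 8
ll = canonical_codes(FIXED_LL_BL)
print(format(ll[0][0], '08b'), format(ll[144][0], '09b'), format(ll[256][0], '07b'))
# 00110000 110010000 0000000

# Fixed distance code: 30 symbols of length 5, so each codeword equals its symbol
dist = canonical_codes([5] * 30)
print(all(dist[d] == (d, 5) for d in range(30)))  # True

# Counting only 286 symbols changes every 9-bit codeword
print(format(canonical_codes(FIXED_LL_BL[:286])[144][0], '09b'))  # 110001100
```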

For reference, here is what the fixed literal/length huffman tree looks like:

And here is what the fixed distance huffman tree looks like:

We can implement inflate_block_fixed() as follows:

def inflate_block_fixed(r, o):
    bl = ([8 for _ in range(144)] + [9 for _ in range(144, 256)] +
        [7 for _ in range(256, 280)] + [8 for _ in range(280, 288)])
    literal_length_tree = bl_list_to_tree(bl, range(286))

    bl = [5 for _ in range(30)]
    distance_tree = bl_list_to_tree(bl, range(30))

    inflate_block_data(r, literal_length_tree, distance_tree, o)

Full source code

And we are done! We can test out our newly-minted decompress() function with:

import zlib
x = zlib.compress(b'Hello World!')
print(decompress(x)) # b'Hello World!'
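The b'Hello World!' input is short enough that zlib will most likely encode it as a single fixed-Huffman block, so on its own this test probably exercises only inflate_block_fixed(). One way to get inputs for other block types is to ask Python's zlib module for them and check which BTYPE the first block actually has. The sketch below does that with a hypothetical helper, first_block_header(), which assumes no preset dictionary. Level 0 yields stored blocks, and the Z_FIXED strategy asks zlib for fixed huffman codes (zlib can still fall back to a stored block if that would be smaller). zlib offers no strategy that forces dynamic blocks; larger, less repetitive inputs at the default level usually get them, but that is zlib's decision.

```python
import zlib


# Hypothetical helper for illustration; not part of the article's decompressor.
def first_block_header(zdata):
    """Return (BFINAL, BTYPE) of the first DEFLATE block in a zlib stream.

    Assumes FDICT is not set, so the DEFLATE data starts right after the
    2-byte CMF/FLG header.
    """
    if zdata[1] & 0x20:
        raise ValueError('FDICT is set; a 4-byte DICTID precedes the DEFLATE data')
    first = zdata[2]
    # DEFLATE reads bits LSB-first: bit 0 is BFINAL, bits 1-2 are BTYPE
    return first & 1, (first >> 1) & 0b11


data = b'abc' * 300

# Level 0: no compression, so zlib emits stored blocks (BTYPE 0)
stored = zlib.compress(data, 0)

# Z_FIXED strategy: fixed huffman codes (BTYPE 1) unless a stored block is smaller
c = zlib.compressobj(level=9, strategy=zlib.Z_FIXED)
fixed = c.compress(data) + c.flush()

print(first_block_header(stored)[1])  # 0
print(first_block_header(fixed)[1])   # 1
```

Passing stored and fixed to decompress() should then reach inflate_block_no_compression() and inflate_block_fixed() respectively.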

Here is the full source code listing:

class BitReader:
    def __init__(self, mem):
        self.mem = mem
        self.pos = 0
        self.b = 0
        self.numbits = 0

    def read_byte(self):
        self.numbits = 0 # discard unread bits
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        # shift bit out of byte
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

    def read_bytes(self, n):
        # read bytes as an integer in little-endian
        o = 0
        for i in range(n):
            o |= self.read_byte() << (8 * i)
        return o

def decompress(input):
    r = BitReader(input)
    CMF = r.read_byte()
    CM = CMF & 15 # Compression method
    if CM != 8: # only CM=8 is supported
        raise Exception('invalid CM')
    CINFO = (CMF >> 4) & 15 # Compression info
    if CINFO > 7:
        raise Exception('invalid CINFO')
    FLG = r.read_byte()
    if (CMF * 256 + FLG) % 31 != 0:
        raise Exception('CMF+FLG checksum failed')
    FDICT = (FLG >> 5) & 1 # preset dictionary?
    if FDICT:
        raise Exception('preset dictionary not supported')
    out = inflate(r) # decompress DEFLATE data
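    # Adler-32 checksum of the uncompressed data. Unlike the DEFLATE fields, it is
    # stored big-endian, so read_bytes() (little-endian) would byte-swap it.
    # (For this exercise, we read it but do not verify it.)
    ADLER32 = int.from_bytes(bytes(r.read_byte() for _ in range(4)), 'big')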
    return out

def inflate(r):
    BFINAL = 0
    out = []
    while not BFINAL:
        BFINAL = r.read_bit()
        BTYPE = r.read_bits(2)
        if BTYPE == 0:
            inflate_block_no_compression(r, out)
        elif BTYPE == 1:
            inflate_block_fixed(r, out)
        elif BTYPE == 2:
            inflate_block_dynamic(r, out)
        else:
            raise Exception('invalid BTYPE')
    return bytes(out)

def inflate_block_no_compression(r, o):
    LEN = r.read_bytes(2)
    NLEN = r.read_bytes(2)
    o.extend(r.read_byte() for _ in range(LEN))

def code_to_bytes(code, n):
    # Encodes a code that is `n` bits long into bytes that is conformant with DEFLATE spec
    out = [0]
    numbits = 0
    for i in range(n-1, -1, -1):
        if numbits >= 8:
            out.append(0)
            numbits = 0
        out[-1] |= (1 if code & (1 << i) else 0) << numbits
        numbits += 1
    return bytes(out)

class Node:
    def __init__(self):
        self.symbol = ''
        self.left = None
        self.right = None

class HuffmanTree:
    def __init__(self):
        self.root = Node()
        self.root.symbol = ''

    def insert(self, codeword, n, symbol):
        # Insert an entry into the tree mapping `codeword` of len `n` to `symbol`
        node = self.root
        for i in range(n-1, -1, -1):
            b = codeword & (1 << i)
            if b:
                next_node = node.right
                if next_node is None:
                    node.right = Node()
                    next_node = node.right
            else:
                next_node = node.left
                if next_node is None:
                    node.left = Node()
                    next_node = node.left
            node = next_node
        node.symbol = symbol

def decode_symbol(r, t):
    "Decodes one symbol from bitstream `r` using HuffmanTree `t`"
    node = t.root
    while node.left or node.right:
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol

LengthExtraBits = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3,
        3, 4, 4, 4, 4, 5, 5, 5, 5, 0]
LengthBase = [3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43,
        51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258]
DistanceExtraBits = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
        8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13]
DistanceBase = [1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257,
        385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385,
        24577]

def inflate_block_data(r, literal_length_tree, distance_tree, out):
    while True:
        sym = decode_symbol(r, literal_length_tree)
        if sym <= 255: # Literal byte
            out.append(sym)
        elif sym == 256: # End of block
            return
        else: # <length, backward distance> pair
            sym -= 257
            length = r.read_bits(LengthExtraBits[sym]) + LengthBase[sym]
            dist_sym = decode_symbol(r, distance_tree)
            dist = r.read_bits(DistanceExtraBits[dist_sym]) + DistanceBase[dist_sym]
            for _ in range(length):
                out.append(out[-dist])

def bl_list_to_tree(bl, alphabet):
    MAX_BITS = max(bl)
    bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS+1)]
    next_code = [0, 0]
    for bits in range(2, MAX_BITS+1):
        next_code.append((next_code[bits-1] + bl_count[bits-1]) << 1)
    t = HuffmanTree()
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            next_code[bitlen] += 1
    return t

CodeLengthCodesOrder = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]

def decode_trees(r):
    # The number of literal/length codes
    HLIT = r.read_bits(5) + 257

    # The number of distance codes
    HDIST = r.read_bits(5) + 1

    # The number of code length codes
    HCLEN = r.read_bits(4) + 4

    # Read code lengths for the code length alphabet
    code_length_tree_bl = [0 for _ in range(19)]
    for i in range(HCLEN):
        code_length_tree_bl[CodeLengthCodesOrder[i]] = r.read_bits(3)

    # Construct code length tree
    code_length_tree = bl_list_to_tree(code_length_tree_bl, range(19))

    # Read literal/length + distance code length list
    bl = []
    while len(bl) < HLIT + HDIST:
        sym = decode_symbol(r, code_length_tree)
        if 0 <= sym <= 15: # literal value
            bl.append(sym)
        elif sym == 16:
            # copy the previous code length 3..6 times.
            # the next 2 bits indicate repeat length ( 0 = 3, ..., 3 = 6 )
            prev_code_length = bl[-1]
            repeat_length = r.read_bits(2) + 3
            bl.extend(prev_code_length for _ in range(repeat_length))
        elif sym == 17:
            # repeat code length 0 for 3..10 times. (3 bits of length)
            repeat_length = r.read_bits(3) + 3
            bl.extend(0 for _ in range(repeat_length))
        elif sym == 18:
            # repeat code length 0 for 11..138 times. (7 bits of length)
            repeat_length = r.read_bits(7) + 11
            bl.extend(0 for _ in range(repeat_length))
        else:
            raise Exception('invalid symbol')

    # Construct trees
    literal_length_tree = bl_list_to_tree(bl[:HLIT], range(286))
    distance_tree = bl_list_to_tree(bl[HLIT:], range(30))
    return literal_length_tree, distance_tree

def inflate_block_dynamic(r, o):
    literal_length_tree, distance_tree = decode_trees(r)
    inflate_block_data(r, literal_length_tree, distance_tree, o)

def inflate_block_fixed(r, o):
    bl = ([8 for _ in range(144)] + [9 for _ in range(144, 256)] +
        [7 for _ in range(256, 280)] + [8 for _ in range(280, 288)])
    literal_length_tree = bl_list_to_tree(bl, range(286))

    bl = [5 for _ in range(30)]
    distance_tree = bl_list_to_tree(bl, range(30))

    inflate_block_data(r, literal_length_tree, distance_tree, o)

import zlib
x = zlib.compress(b'Hello World!')
print(decompress(x)) # b'Hello World!'
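The listing reads the trailing ADLER32 field but never checks it. RFC 1950 defines Adler-32 as two running sums modulo 65521, stored most significant byte first at the end of the zlib stream; this is the opposite byte order to read_bytes(), which assembles little-endian integers. A minimal sketch of the check follows. adler32() and check_trailer() are hypothetical helpers, check_trailer() assumes the zlib stream ends exactly at the end of zdata, and the result is compared against Python's own zlib.adler32().

```python
import zlib

MOD_ADLER = 65521  # largest prime smaller than 2**16


# Hypothetical helper for illustration; not part of the article's listing.
def adler32(data):
    """Adler-32 (RFC 1950): two running sums modulo 65521, packed as (b << 16) | a."""
    a, b = 1, 0
    for byte in data:
        a = (a + byte) % MOD_ADLER
        b = (b + a) % MOD_ADLER
    return (b << 16) | a


# Hypothetical helper: assumes the zlib stream ends exactly at the end of `zdata`.
def check_trailer(zdata, decompressed):
    """Compare a zlib stream's ADLER32 trailer (big-endian) with the data."""
    trailer = int.from_bytes(zdata[-4:], 'big')
    return trailer == adler32(decompressed)


payload = b'Hello World!'
print(hex(adler32(b'Wikipedia')))                      # 0x11e60398
print(adler32(payload) == zlib.adler32(payload))       # True
print(check_trailer(zlib.compress(payload), payload))  # True
```

Taking the modulo on every byte keeps the sketch obvious; zlib's own implementation defers the modulo across blocks of bytes for speed. Inside decompress(), the same comparison would be made against the bytes returned by inflate().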