18th October 2019

# Let's implement `zlib.decompress()`

Ever wondered how lossless compression in PNG and ZIP files is achieved? In this article, we will attempt to gain some insight into it by writing a decompressor for the zlib format, which wraps the DEFLATE compression format used by both PNG and ZIP files!

The zlib/DEFLATE spec is interesting in that it usually deals with streams of bits rather than bytes. To make it easier to implement later algorithms, we shall first implement a `BitReader` class which will handle the details of extracting individual bits and bytes from a bytestring:

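``````
class BitReader:
    def __init__(self, mem):
        self.mem = mem # the bytestring to read from
        self.pos = 0 # index of the next byte to read
        self.b = 0 # current byte whose bits are being shifted out
        self.numbits = 0 # number of unread bits left in self.b

    def read_byte(self):
        self.numbits = 0 # discard any unread bits
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        # shift bit out of byte
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        # read an n-bit integer, least significant bit first
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

    def read_bytes(self, n):
        # read bytes as an integer in little-endian
        o = 0
        for i in range(n):
            o |= self.read_byte() << (8 * i)
        return o
``````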

The DEFLATE spec's rules for extracting bits from a bytestring are as follows:

• Given a byte, the first bit from DEFLATE's perspective (0th bit) is the byte's least significant bit, while the last bit from DEFLATE's perspective (7th bit) is the byte's most significant bit. For example, for the byte `0x9d`, the bits from first to last are 1, 0, 1, 1, 1, 0, 0, 1.

This is implemented in `read_bit()`, which reads a byte (with `read_byte()`) whenever it has no bits left, then shifts out and returns one bit per call. After all 8 bits have been shifted out, the next call reads the next byte. Let's try it out:

``````
r = BitReader(b'\x9d')
print([r.read_bit() for _ in range(8)]) # [1, 0, 1, 1, 1, 0, 0, 1]
print(r.read_bit()) # IndexError: index out of range (no bytes left)
``````

• The DEFLATE spec sometimes requires reading an unsigned integer `I` consisting of `n` bits from the stream. In this case, the 1st bit read from the stream is the least significant bit of `I`, while the `n`-th bit read is the most significant bit. For example, consider the integer `299`, which is `0b100101011` in binary (9 bits). Assuming that all the other bits are zero, it would be encoded as the two bytes `b'\x2b\x01'`.

This is implemented in `read_bits(self, n)`, which calls `read_bit()` `n` times and places each bit at the correct bit position. Let's try it out:

``````
r = BitReader(b'\x2b\x01')
print(r.read_bits(9)) # 299
``````

• The DEFLATE spec also sometimes requires reading an unsigned integer `I` consisting of `n` bytes from the stream. In this case, any bits of the current byte that were not consumed by `read_bits()`/`read_bit()` are discarded, and `n` bytes are read. These `n` bytes are in little-endian order, that is, least significant byte first. For example, the integer `13926` is encoded as `b'\x66\x36'`.

This is implemented in `read_bytes(self, n)`, which calls `read_byte()` `n` times and shifts each byte into its correct position. Let's try it out:

``````
r = BitReader(b'\x66\x36')
print(r.read_bytes(2)) # 13926
``````

## The zlib container

Now that we have `BitReader`, let's start from the top. Starting from the beginning of the file/bytestring, the zlib format is laid out as follows:

• 1 byte: `CMF` --- Compression Method and compression info fields
• 1 byte: `FLG` --- Compression flags fields
• Variable number of bytes: The DEFLATE data
• 4 bytes: `ADLER32` -- Adler-32 checksum over the original uncompressed data. For this exercise, we will simply ignore it.
• End of file/bytestring

As we can see, the zlib format is a very lightweight container format over the DEFLATE format, adding only 6 extra bytes on top.

The `CMF` fields are as follows:

• Bits 0 to 3: `CM` --- Compression Method. This identifies the compression method in the file. Only `CM=8` is defined by the zlib spec, which basically says that the zlib file contains data compressed with DEFLATE. Any other value of `CM` is not supported, and we should error out.
• Bits 4 to 7: `CINFO` --- Compression info. This is the base-2 logarithm of the LZ77 window size, minus eight. In other words, the window size is $2^{\text{CINFO}+8}$. The maximum window size allowed by the spec is 32768, which is $2^{15}$. In other words, $\text{CINFO} \le 7$ must hold, and any other value should be treated as an error.
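As a quick check of the `CINFO` arithmetic, here is a small sketch that pulls `CM` and `CINFO` out of a `CMF` byte and computes the window size. The helper name `window_size()` is only for illustration. The byte `0x78` is the `CMF` value zlib emits with its default 32 KiB window (`CM = 8`, `CINFO = 7`).

``````python
def window_size(cmf):
    # Hypothetical helper: decode a CMF byte and return the LZ77 window size
    cm = cmf & 15            # bits 0-3: compression method
    cinfo = (cmf >> 4) & 15  # bits 4-7: log2(window size) - 8
    if cm != 8 or cinfo > 7:
        raise ValueError('unsupported CMF byte')
    return 1 << (cinfo + 8)  # 2 ** (CINFO + 8)

print(window_size(0x78))  # 32768
print(window_size(0x08))  # 256, the smallest window
``````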

The `FLG` fields are as follows:

• Bits 0 to 4: `FCHECK` --- Used as part of the checksum. See below.
• Bit 5: `FDICT` --- If set, a `DICT` dictionary identifier is present immediately after the `FLG` byte. The dictionary is a sequence of bytes, known beforehand to both the compressor and decompressor, that can be used to achieve greater compression ratios. The zlib spec does not define any preset dictionaries and leaves it up to the implementor. The file formats that we are interested in (e.g. PNG, ZIP files) also do not specify any preset dictionaries. As such, we should not need to handle preset dictionaries at all, and we should error out if the `FDICT` bit is set.
• Bits 6 to 7: `FLEVEL` --- Compression level. This indicates whether the original data was compressed with the fastest/fast/default/max-compression compression level. It's not needed for de-compression at all, and is only there to indicate if recompression might be worthwhile. For our purposes, we can simply ignore it.

In addition, the `CMF` and `FLG` bytes also serve as a checksum: $\text{CMF} \times 256 + \text{FLG}$ must be a multiple of 31. Bits 0 to 4 of the `FLG` byte (`FCHECK`) are set in a way that ensures this is true if the `CMF` and `FLG` bytes are uncorrupted. This makes sense: any integer leaves a remainder between 0 and 30 when divided by 31, and 5 bits can store values 0..31 inclusive, so 5 bits are sufficient to hold the amount that must be added to make the integer a multiple of 31.
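Working the checksum backwards shows how a compressor can choose `FCHECK`. This sketch (the name `make_flg()` is hypothetical) picks the 5-bit value that makes $\text{CMF} \times 256 + \text{FLG}$ a multiple of 31. With `CMF = 0x78` and an `FLEVEL` of 2, it reproduces the `78 9c` header that `zlib.compress()` produces at its default level.

``````python
def make_flg(cmf, flevel, fdict=0):
    # Hypothetical helper: build an FLG byte whose FCHECK bits satisfy the zlib checksum
    flg = (flevel << 6) | (fdict << 5)           # FLEVEL in bits 6-7, FDICT in bit 5
    fcheck = (31 - (cmf * 256 + flg) % 31) % 31  # value in 0..30 that cancels the remainder
    return flg | fcheck                          # FCHECK occupies bits 0-4

flg = make_flg(0x78, 2)
print(hex(flg))                      # 0x9c
print((0x78 * 256 + flg) % 31 == 0)  # True
``````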

As such, let's implement our top-level `decompress()` function:

``````
def decompress(input):
    r = BitReader(input)
    CMF = r.read_byte()
    CM = CMF & 15 # Compression method
    if CM != 8: # only CM=8 is supported
        raise Exception('invalid CM')
    CINFO = (CMF >> 4) & 15 # Compression info
    if CINFO > 7:
        raise Exception('invalid CINFO')
    FLG = r.read_byte()
    if (CMF * 256 + FLG) % 31 != 0:
        raise Exception('CMF+FLG checksum failed')
    FDICT = (FLG >> 5) & 1 # preset dictionary?
    if FDICT:
        raise Exception('preset dictionary not supported')
    out = inflate(r) # decompress DEFLATE data
    # the trailing ADLER32 checksum is ignored in this exercise
    return out
``````

Let's say we have some zlib compressed data like this:

``````import zlib
x = zlib.compress(b'Hello World!')``````

We should then be able to decompress it with our implemented `decompress()` function as follows:

``print(decompress(x)) # Supposed to print b'Hello World!'``

Of course, it does not presently work because it calls `inflate()` to handle the DEFLATE data, but we haven't implemented `inflate()` yet. We shall implement it in the following sections.

## DEFLATE Blocks

A DEFLATE stream consists of a series of blocks. Each block begins with 3 header bits containing the following data:

| Bits | Name | Meaning |
|---|---|---|
| first bit | `BFINAL` | Set if and only if this is the last block of data. |
| next 2 bits | `BTYPE` | Specifies how the data are compressed: `00` - no compression; `01` - compressed with fixed Huffman codes; `10` - compressed with dynamic Huffman codes; `11` - reserved (error) |
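To make the bit order concrete, the sketch below pulls the 3 header bits out of the first byte of the DEFLATE data produced by `zlib.compress()`, which starts right after the 2-byte `CMF`/`FLG` header. Because bits are taken least significant first, `BFINAL` is bit 0 and `BTYPE` is bits 1 to 2. The helper `block_header()` is a hypothetical name. Which `BTYPE` zlib chooses for compressed output depends on its own heuristics, so only the `level=0` case has a predictable answer.

``````python
import zlib

def block_header(deflate_first_byte):
    # BFINAL is the first bit read (bit 0); BTYPE is the next 2 bits, least significant first
    bfinal = deflate_first_byte & 1
    btype = (deflate_first_byte >> 1) & 3
    return bfinal, btype

stored = zlib.compress(b'Hello World!', level=0)
print(block_header(stored[2]))  # (1, 0): last block, no compression

packed = zlib.compress(b'Hello World!')
print(block_header(packed[2]))  # BTYPE is 1 or 2, depending on zlib's choice
``````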

As such, let's define the main function of our library, `inflate()`, which will inflate a compressed DEFLATE bitstream.

``````
def inflate(r):
    BFINAL = 0
    out = [] # list of integers 0..255, representing decompressed bytes
    while not BFINAL:
        BFINAL = r.read_bit() # 1 if this is the last block
        BTYPE = r.read_bits(2) # how this block is compressed
        if BTYPE == 0:
            inflate_block_no_compression(r, out)
        elif BTYPE == 1:
            inflate_block_fixed(r, out)
        elif BTYPE == 2:
            inflate_block_dynamic(r, out)
        else:
            raise Exception('invalid BTYPE')
    return bytes(out) # return decompressed bytes as bytestring
``````

We can't run it yet because it refers to the missing functions `inflate_block_no_compression()`, `inflate_block_fixed()` and `inflate_block_dynamic()` (you can stub them out for now). In the rest of this article we will fill them in one by one.

## Non-compressed blocks

Let's start out with something very simple -- non-compressed blocks. After the BFINAL and BTYPE bits, any bits of input up to the next byte boundary are ignored. Then, the rest of the block consists of:

• 2 bytes: `LEN` --- the number of data bytes in the block.
• 2 bytes: `NLEN` -- the one's complement of `LEN`.
• `LEN` bytes: The literal data
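Here is that layout picked apart by hand, using the bytes that `zlib.compress()` emits at `level=0` for a short input. This assumes the whole input fits in a single stored block, which holds for a 12-byte string: byte 2 carries the 3 header bits plus 5 bits of padding, so `LEN` starts at byte 3.

``````python
import zlib

data = zlib.compress(b'Hello World!', level=0)
# data[0:2] is CMF/FLG; data[2] holds BFINAL/BTYPE plus padding up to the byte boundary
LEN = int.from_bytes(data[3:5], 'little')
NLEN = int.from_bytes(data[5:7], 'little')
print(LEN)                   # 12
print(NLEN == LEN ^ 0xFFFF)  # True: NLEN is the one's complement of LEN
print(data[7:7 + LEN])       # b'Hello World!'
``````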

As such, implementing `inflate_block_no_compression()` is trivial:

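``````
def inflate_block_no_compression(r, o):
    LEN = r.read_bytes(2) # number of data bytes (skips to the byte boundary first)
    NLEN = r.read_bytes(2) # one's complement of LEN (not checked here)
    o.extend(r.read_byte() for _ in range(LEN))
``````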

As mentioned when we were implementing `BitReader`, `read_bytes()` will discard any unread bits and move to the next byte boundary, which is what we want in this case.

At this point, our `decompress()` function will work for zlib bytestrings that have no compression! Let's try it out:

``````import zlib
x = zlib.compress(b'Hello World!', level=0) # level 0 means no compression
print(decompress(x)) # b'Hello World!'``````

Now, let's move on to implement `inflate_block_fixed()` and `inflate_block_dynamic()` so that we can handle compressed data. To do that, we must first understand how DEFLATE does compression.

## Huffman coding

The first half of DEFLATE's compression story is huffman coding. Here's an example to show how huffman coding works:

Suppose we only allow strings with the characters `A`, `B`, `C` and `D`. Formally, this set of characters is called the alphabet, $\{A, B, C, D\}$. An example of a valid string which meets our conditions is "`BBBACD`". An example of an invalid string is "`BBBACDF`", since it contains the letter "`F`", which is not in our alphabet.

Now, let's say we want to encode the string "`BBBACD`" into binary. One simple way to do it would be to map the characters of the alphabet to integers. Since there are 4 characters in the alphabet, we will map them to the range 0..3 inclusive. The maximum integer is 3, which requires at least 2 bits (`0b11` in binary) to represent, and so we encode these integers with 2 bits each.

We map `A` to 0 (binary: `0b00`), `B` to 1 (binary: `0b01`), `C` to 2 (binary: `0b10`) and `D` to 3 (binary: `0b11`). The string "`BBBACD`" encoded in our simple encoding is thus `0b010101001011`, which is $6 \times 2 = 12$ bits in total.
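The fixed-width scheme can be checked in a couple of lines of Python: each character's index in the alphabet is written out as a 2-bit binary string, and the pieces are concatenated.

``````python
alphabet = 'ABCD'
# Map each character to its index (0..3) and write that index using exactly 2 bits
bits = ''.join(format(alphabet.index(c), '02b') for c in 'BBBACD')
print(bits, len(bits))  # 010101001011 12
``````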

How could we compress the string "`BBBACD`" by encoding it with fewer bits? One way is to use a variable-length encoding which gives shorter codes to more frequent characters.

A huffman code is a variable-length code table for encoding symbols of an alphabet into strings of bits. An example huffman code table for the alphabet $\{A, B, C, D\}$ is:

| Symbol | Codeword |
|---|---|
| `A` | `0b01` |
| `B` | `0b1` |
| `C` | `0b000` |
| `D` | `0b001` |

One important constraint of a huffman code is that no codeword can be a prefix of another codeword. For example, we can't assign `B` the codeword `0b0` because `0b0` is a prefix of `A`'s codeword (`0b01`). Likewise, we can't assign `C` the codeword `0b100` because `B`'s codeword, `0b1`, is a prefix of `0b100`. This constraint is needed to prevent ambiguity during decoding. Suppose `B` had the codeword `0b0` instead. Then an ambiguity would exist when decoding `0b001`: should we decode it to "`BA`" or to "`D`"?
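The prefix constraint is easy to check mechanically. The helper below (`is_prefix_free()` is an illustrative name, not part of any library) compares every ordered pair of codewords, written as bit strings, and reports whether any of them is a prefix of another.

``````python
from itertools import permutations

def is_prefix_free(codewords):
    # True if no codeword is a prefix of a different codeword
    return not any(b.startswith(a) for a, b in permutations(codewords, 2))

print(is_prefix_free(['01', '1', '000', '001']))  # True: the table above
print(is_prefix_free(['01', '0', '000', '001']))  # False: '0' is a prefix of '01'
``````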

Now, we can see from the table that `B` can be encoded with 1 bit, `A` with 2 bits, and `C` and `D` with 3 bits. As such, the string "`BBBACD`" will be encoded as `0b11101000001`, which is 11 bits in total. This is shorter than the simple encoding that we used earlier, and so compression has been achieved.
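Encoding with the variable-length table works the same way as before: one lookup per character, then concatenation. This sketch keeps codewords as strings of `0`/`1` characters for readability rather than packing them into integers.

``````python
code = {'A': '01', 'B': '1', 'C': '000', 'D': '001'}
bits = ''.join(code[c] for c in 'BBBACD')
print(bits, len(bits))                # 11101000001 11
print(int(bits, 2) == 0b11101000001)  # True
``````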

How would we encode `0b11101000001` into bytes? The DEFLATE spec specifies that huffman-coded bits should be packed in such a way that, when reading them back with `BitReader`, the first bit that `read_bit()` returns is the most significant bit of the code (`0b1`), the next bit is the second-most significant bit (`0b1`), and so on. To help us out later, let's implement `code_to_bytes(code, n)`, which does just that:

``````
def code_to_bytes(code, n):
    # Encodes a code that is `n` bits long into bytes that are conformant with the DEFLATE spec
    out = [0]
    numbits = 0
    for i in range(n-1, -1, -1): # most significant bit first
        if numbits >= 8:
            out.append(0)
            numbits = 0
        out[-1] |= (1 if code & (1 << i) else 0) << numbits
        numbits += 1
    return bytes(out)
``````

Testing it out:

``````r = BitReader(code_to_bytes(0b11101000001, 11))
print([r.read_bit() for _ in range(11)]) # [1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1]``````

Next, given a huffman code table like the one above, how could we write a program that can decode any arbitrary valid bitstream? One way would be to first convert the table into a huffman tree. A huffman tree is a binary tree whose leaves are the symbols of the alphabet. For a given internal node, the left edge represents the bit 0 while the right edge represents the bit 1.

In the equivalent huffman tree for the huffman code in the table above, the root's right child (bit `1`) is the leaf `B`. The root's left child (bit `0`) is an internal node whose right child is the leaf `A`, and whose left child is another internal node with the leaf `C` on its left and the leaf `D` on its right.

From the huffman tree, we can derive the codeword of a symbol by starting from the root and walking down the edges to the leaf corresponding to the symbol, while noting down whether we took the left or right edge. For example, to get to `D`, we start from the root node, take the left edge (`0`), then the left edge (`0`), then finally the right edge (`1`), which gives us the code `0b001`: exactly `D`'s codeword in the table above!

Likewise, to decode a symbol from a bitstream, we can:

1. Start from the root node.
2. Read a bit from the bitstream. If it is a `0`, take the left edge, otherwise (if it is a `1`), take the right edge.
3. Repeat Step 2 until we reach a leaf node. That is the decoded symbol.

As we can see, having a huffman tree makes it easy to write an algorithm that can decode symbols from a bitstream in linear time.

Now let's implement this. We start with the implementation of a `HuffmanTree` class that has an `insert()` operation, which allows us to insert `codeword -> symbol` mappings into the tree (with the assumption that they obey the "no codeword can be a prefix of another codeword" constraint).

``````
class Node:
    def __init__(self):
        self.symbol = ''
        self.left = None
        self.right = None

class HuffmanTree:
    def __init__(self):
        self.root = Node()

    def insert(self, codeword, n, symbol):
        # Insert an entry into the tree mapping `codeword` of len `n` to `symbol`
        node = self.root
        for i in range(n-1, -1, -1):
            b = codeword & (1 << i)
            if b:
                next_node = node.right
                if next_node is None:
                    node.right = Node()
                    next_node = node.right
            else:
                next_node = node.left
                if next_node is None:
                    node.left = Node()
                    next_node = node.left
            node = next_node
        node.symbol = symbol
``````

We can now construct our example huffman tree as follows:

``````t = HuffmanTree()
t.insert(0b01, 2, 'A')
t.insert(0b1, 1, 'B')
t.insert(0b000, 3, 'C')
t.insert(0b001, 3, 'D')``````

Then, we can write a `decode_symbol()` function that decodes a symbol from a bitstream using a specified huffman tree. This is implemented using the algorithm described above:

``````
def decode_symbol(r, t):
    "Decodes one symbol from bitstream `r` using HuffmanTree `t`"
    node = t.root
    while node.left or node.right: # keep walking until we reach a leaf
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol
``````

Now, let's try decoding `0b11101000001` back into "`BBBACD`":

``````r = BitReader(code_to_bytes(0b11101000001, 11))
print(decode_symbol(r, t)) # 'B'
print(decode_symbol(r, t)) # 'B'
print(decode_symbol(r, t)) # 'B'
print(decode_symbol(r, t)) # 'A'
print(decode_symbol(r, t)) # 'C'
print(decode_symbol(r, t)) # 'D'``````

## LZ77 and the Literal/Length and Distance Alphabet

One might imagine that the DEFLATE spec simply uses Huffman Coding as-is, with an alphabet consisting of the values 0..255 inclusive, which are the values that a byte can take. That way, bytes which occur frequently in a file will be given shorter codewords, thus compressing the file. This is only half-correct --- the DEFLATE spec has an additional trick up its sleeve which achieves better compression ratios on files that contain repeated data: LZ77 compression.

Suppose we have a file that consists of the pattern "`BBA`" repeated over and over again, i.e. "`BBABBABBABBA`...". Even if we encode `B` and `A` with one bit each, the size of the compressed file will still be linear in the size of the original file.

We can compress this file even further by employing LZ77 compression. With LZ77 compression, we first write the literal values "`BBA`" to the compressed output. Next, rather than writing "`BBA`" again, we write a <length, backward distance> marker which specifies that the (de-compressed) string of the given length, starting the given distance back, should be duplicated. For example, the string "`BBABBA`" (length 6) can be encoded as "`BBA <3,3>`", the string "`BBABBABBA`" (length 9) can be encoded as "`BBA <6,3>`", and the string "`BBABBABBABBABBABBA`" (length 18) can be encoded as "`BBA <15,3>`"! Note that the length may exceed the distance: the copy proceeds one byte at a time, so it can re-read bytes that it has only just produced. This means that however many times the pattern repeats, its encoded size stays roughly constant, up to the maximum length that a single marker can express.
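The markers can be expanded with a short loop. This sketch uses a made-up token format for illustration: plain strings are literal text and `(length, distance)` tuples are markers; `lz77_expand()` is likewise a hypothetical name. Because it copies one character at a time, markers such as `(15, 3)` work even though the length is larger than the distance.

``````python
def lz77_expand(tokens):
    # Hypothetical token format: str = literal text, (length, distance) = back-reference
    out = []
    for tok in tokens:
        if isinstance(tok, tuple):
            length, distance = tok
            for _ in range(length):
                out.append(out[-distance])  # may read a character appended earlier in this loop
        else:
            out.extend(tok)
    return ''.join(out)

print(lz77_expand(['BBA', (3, 3)]))   # BBABBA
print(lz77_expand(['BBA', (15, 3)]))  # BBABBABBABBABBABBA
``````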

Now, using LZ77 compression as well means that a DEFLATE stream would contain both literal byte values and <length, backward distance> pairs. So, how does DEFLATE encode them and distinguish between them at de-compression time? It does this by defining two alphabets --- the Literal/Length alphabet and the Distance alphabet:

• The Literal/Length alphabet consists of the values 0..285 inclusive. Values 0..255 are the corresponding literal byte values. Value 256 is the end-of-block marker. Values 257..285 are used to encode the "length" portion of <length, backward distance> pairs, although it does not encode the length literally. See below.
• The Distance alphabet consists of the values 0..29 inclusive. They are used to encode the "backward distance" portion of <length, backward distance> pairs, although they do not encode them literally. See below.

These two alphabets each have their own huffman trees, which we shall call the Literal/Length tree and the Distance tree.

Values 257..285 of the Literal/Length alphabet are used to represent length values, sometimes in conjunction with extra bits that come after the symbol:

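| Symbol | Extra bits | Length(s) |
|---|---|---|
| 257 | 0 | 3 |
| 258 | 0 | 4 |
| 259 | 0 | 5 |
| 260 | 0 | 6 |
| 261 | 0 | 7 |
| 262 | 0 | 8 |
| 263 | 0 | 9 |
| 264 | 0 | 10 |
| 265 | 1 | 11, 12 |
| 266 | 1 | 13, 14 |
| 267 | 1 | 15, 16 |
| 268 | 1 | 17, 18 |
| 269 | 2 | 19-22 |
| 270 | 2 | 23-26 |
| 271 | 2 | 27-30 |
| 272 | 2 | 31-34 |
| 273 | 3 | 35-42 |
| 274 | 3 | 43-50 |
| 275 | 3 | 51-58 |
| 276 | 3 | 59-66 |
| 277 | 4 | 67-82 |
| 278 | 4 | 83-98 |
| 279 | 4 | 99-114 |
| 280 | 4 | 115-130 |
| 281 | 5 | 131-162 |
| 282 | 5 | 163-194 |
| 283 | 5 | 195-226 |
| 284 | 5 | 227-257 |
| 285 | 0 | 258 |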

Example 1: Let's say we call `decode_symbol()` with the literal/length huffman tree and it returns `42`. Since it is within the range 0..255, it is a literal byte and so we should write byte `42` to the output stream.

Example 2: Let's say we call `decode_symbol()` with the literal/length huffman tree and it returns `256`. Since 256 is the end-of-block marker, we should terminate parsing of the block.

Example 3(a): Let's say we call `decode_symbol()` with the literal/length huffman tree and it returns `257`. According to the table, symbol `257` has 0 extra bits. Therefore we know immediately that we have reached a <length, backward distance> pair with length 3. We then expect the backward distance of the pair to follow and we should next call `decode_symbol()` with the distance tree to decode the backward distance. (See Example 3(b) for continuation)

Example 4(a): Let's say we call `decode_symbol()` with the literal/length huffman tree and it returns `275`. According to the table, symbol `275` has 3 extra bits. We thus call `read_bits(3)` to read those 3 extra bits. Let's say that `read_bits(3)` returns 2. Therefore, we know that we have reached a <length, backward distance> pair with length `51 + 2 = 53`. We then expect the backward distance of the pair to follow and we should next call `decode_symbol()` with the distance tree to decode the backward distance. (See Example 4(b) for continuation)

Similarly to the literal/length alphabet, the distance alphabet is used to represent backward distance values, sometimes in conjunction with extra bits that come after the symbol:

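| Symbol | Extra bits | Distance(s) |
|---|---|---|
| 0 | 0 | 1 |
| 1 | 0 | 2 |
| 2 | 0 | 3 |
| 3 | 0 | 4 |
| 4 | 1 | 5, 6 |
| 5 | 1 | 7, 8 |
| 6 | 2 | 9-12 |
| 7 | 2 | 13-16 |
| 8 | 3 | 17-24 |
| 9 | 3 | 25-32 |
| 10 | 4 | 33-48 |
| 11 | 4 | 49-64 |
| 12 | 5 | 65-96 |
| 13 | 5 | 97-128 |
| 14 | 6 | 129-192 |
| 15 | 6 | 193-256 |
| 16 | 7 | 257-384 |
| 17 | 7 | 385-512 |
| 18 | 8 | 513-768 |
| 19 | 8 | 769-1024 |
| 20 | 9 | 1025-1536 |
| 21 | 9 | 1537-2048 |
| 22 | 10 | 2049-3072 |
| 23 | 10 | 3073-4096 |
| 24 | 11 | 4097-6144 |
| 25 | 11 | 6145-8192 |
| 26 | 12 | 8193-12288 |
| 27 | 12 | 12289-16384 |
| 28 | 13 | 16385-24576 |
| 29 | 13 | 24577-32768 |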

Example 3(b): After decoding the length of the <length, backward distance> pair in Example 3(a) and getting a length of `3`, we then call `decode_symbol()` with the distance tree to decode the backward distance. Let's say `decode_symbol()` returns `2`. According to the table, symbol `2` has 0 extra bits. Therefore we know immediately that we have a backward distance of 3, and so our <length, backward distance> pair is <3, 3>. As such, we should copy 3 bytes, starting 3 bytes back from the current end of the output stream, onto the end of the output stream.

Example 4(b): After decoding the length of the <length, backward distance> pair in Example 4(a) and getting a length of `53`, we then call `decode_symbol()` with the distance tree to decode the backward distance. Let's say `decode_symbol()` returns `12`. According to the table, symbol `12` has 5 extra bits. We thus call `read_bits(5)` to read those extra bits. Let's say `read_bits(5)` returns `30`. Therefore, the backward distance is `65 + 30 = 95`. As such, we should copy 53 bytes, starting 95 bytes back from the current end of the output stream, onto the end of the output stream.

Now let's implement it.

Firstly, we have the tables:

``````LengthExtraBits = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0]
LengthBase = [3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258]
DistanceExtraBits = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13]
DistanceBase = [1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577]``````

`LengthExtraBits` is the "Extra bits" column of the Literal/Length table, and `LengthBase` is the corresponding base integer of the "Length(s)" column of the Literal/Length table. Likewise, `DistanceExtraBits` is the "Extra bits" column of the Distance table, and `DistanceBase` is the corresponding base integer of the "Distance(s)" column of the Distance table.
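As a sanity check on the tables and on Examples 4(a) and 4(b), the sketch below turns a symbol plus its already-read extra-bits value into a length or distance. The helpers `decode_length()` and `decode_distance()` are illustrative names; `inflate_block_data()` below does the same lookups inline. The final asserts check that each base plus the number of values its extra bits can express lands exactly on the next base, so the ranges tile without gaps (symbol 284 is left out because its listed range stops at 257, with length 258 given its own symbol, 285).

``````python
LengthExtraBits = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0]
LengthBase = [3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258]
DistanceExtraBits = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13]
DistanceBase = [1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577]

def decode_length(sym, extra):
    # sym: literal/length symbol 257..285; extra: integer read with read_bits(LengthExtraBits[sym - 257])
    i = sym - 257
    if not 0 <= extra < (1 << LengthExtraBits[i]):
        raise ValueError('extra bits out of range')
    return LengthBase[i] + extra

def decode_distance(sym, extra):
    # sym: distance symbol 0..29; extra: integer read with read_bits(DistanceExtraBits[sym])
    if not 0 <= extra < (1 << DistanceExtraBits[sym]):
        raise ValueError('extra bits out of range')
    return DistanceBase[sym] + extra

print(decode_length(275, 2))    # 53, as in Example 4(a)
print(decode_distance(12, 30))  # 95, as in Example 4(b)

# Each range ends exactly where the next one begins
assert all(LengthBase[i] + (1 << LengthExtraBits[i]) == LengthBase[i + 1] for i in range(27))
assert all(DistanceBase[i] + (1 << DistanceExtraBits[i]) == DistanceBase[i + 1] for i in range(29))
``````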

And next, we have the decode function:

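``````
def inflate_block_data(r, literal_length_tree, distance_tree, out):
    while True:
        sym = decode_symbol(r, literal_length_tree)
        if sym <= 255: # Literal byte
            out.append(sym)
        elif sym == 256: # End of block
            return
        else: # <length, backward distance> pair
            sym -= 257
            length = LengthBase[sym] + r.read_bits(LengthExtraBits[sym])
            dist_sym = decode_symbol(r, distance_tree)
            dist = DistanceBase[dist_sym] + r.read_bits(DistanceExtraBits[dist_sym])
            # copy byte by byte, so that length > dist repeats freshly written bytes
            for _ in range(length):
                out.append(out[-dist])
``````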

## Encoding the literal/length and distance huffman trees

So, given the literal/length and distance huffman trees, we can de-compress (or inflate) a block of data using our newly-implemented `inflate_block_data()` function. However, how are these trees even stored in the first place?

One way would be to just store the `codeword -> symbol` mappings as-is. However, the DEFLATE spec uses a different method, which allows huffman trees to be stored more compactly in the file by placing additional restrictions on them. In particular, the DEFLATE spec requires the huffman trees to be canonical huffman trees.

To recap, a huffman code is required to be a prefix code -- no codeword can be a prefix of another codeword.

On top of that, canonical huffman codes add the following requirements:

1. The first symbol whose codeword has bit length 1 must have the codeword of `0b0`.
2. All codes of a given bit length have lexicographically consecutive values, in the same order as the symbols they represent.
3. Shorter codes lexicographically precede longer codes.

For example, near the beginning of this article we had this example huffman code defined over the alphabet $\{A, B, C, D\}$:

| Symbol | Codeword |
|---|---|
| `A` | `0b01` |
| `B` | `0b1` |
| `C` | `0b000` |
| `D` | `0b001` |

This is not a valid canonical huffman code, and thus cannot be encoded in the DEFLATE file. This is because:

• The first symbol whose codeword has bit length 1, `B`, does not have a codeword of `0b0`, violating rule 1.
• `B` has a shorter codeword than `A`, but `B`'s codeword (`0b1`) is lexicographically greater than `A`'s (`0b01`), violating rule 3.

A huffman code with the same codeword lengths that meets all the conditions of being a canonical huffman code is:

| Symbol | Codeword |
|---|---|
| `A` | `0b10` |
| `B` | `0b0` |
| `C` | `0b110` |
| `D` | `0b111` |

Why require the huffman code to be canonical? Because any canonical huffman code can be represented uniquely and compactly by (1) a list of the symbols of its alphabet (in the alphabet's order), and (2) the bit length of the codeword for each symbol of the alphabet. Furthermore, if the alphabet is already known by the decoder, then the alphabet list is not required at all, and only the bit length list (which we shall call `bl`) needs to be provided.

For example, let's try to derive the canonical huffman code in the table above with just its alphabet and its bit length list `bl`.

To start off, the corresponding alphabet and bit length list `bl` is:

``````alphabet = 'ABCD'
bl = [2, 1, 3, 3]``````

We also compute `MAX_BITS`, which is the bit length of the longest codeword:

``MAX_BITS = max(bl)``

To derive the huffman tree, we first compute the number of codes for each code length.

``````bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS+1)]
print(bl_count) # [0, 1, 1, 2]``````

Next, we compute `next_code` such that `next_code[n]` is the smallest codeword with code length `n`.

``````
next_code = [0, 0]
for bits in range(2, MAX_BITS+1):
    next_code.append((next_code[bits-1] + bl_count[bits-1]) << 1)
print(next_code) # [0, 0, 2, 6]
``````

Sanity check: `next_code[1]` is `0` (binary: `0b0`), which is correct since `B`'s codeword is `0b0`. `next_code[2]` is `2` (binary: `0b10`), which is correct since `A`'s codeword is `0b10`. `next_code[3]` is `6` (binary: `0b110`), which is correct since `C`'s codeword is `0b110`. So everything seems fine.

Finally, we can compute the code for each symbol of our alphabet, and enter it into our huffman tree, taking care to ignore symbols of our alphabet whose codeword has a bit length of 0 (meaning the symbol is not used in the compressed data):

``````
t = HuffmanTree()
for c, bitlen in zip(alphabet, bl):
    if bitlen != 0: # ignore if bit length is 0
        t.insert(next_code[bitlen], bitlen, c)
        next_code[bitlen] += 1
``````

Now, let's try decoding `0b00010110111`, which is "`BBBACD`":

``````r = BitReader(code_to_bytes(0b00010110111, 11))
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # A
print(decode_symbol(r, t)) # C
print(decode_symbol(r, t)) # D``````

So everything seems correct!

What happens if the alphabet is as usual, but `C` and `D` are not used at all? In that case, the DEFLATE spec specifies that `C` and `D` should be specified as having bit length zero. In other words,

``````alphabet = 'ABCD'
bl = [2, 1, 0, 0]``````

Furthermore, the DEFLATE spec allows trailing zeroes in the `bl` list to be omitted. In other words, this is valid:

``````alphabet = 'ABCD'
bl = [2, 1]``````

To wrap up, here is a reusable `bl_list_to_tree()` function which constructs a huffman tree from a bit length list, implemented using all the algorithms we have covered in this section:

``````
def bl_list_to_tree(bl, alphabet):
    MAX_BITS = max(bl)
    bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS+1)]
    next_code = [0, 0]
    for bits in range(2, MAX_BITS+1):
        next_code.append((next_code[bits-1] + bl_count[bits-1]) << 1)
    t = HuffmanTree()
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            next_code[bitlen] += 1
    return t
``````

Usage:

``````
r = BitReader(code_to_bytes(0b00010110111, 11))
t = bl_list_to_tree([2, 1, 3, 3], 'ABCD')
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # B
print(decode_symbol(r, t)) # A
print(decode_symbol(r, t)) # C
print(decode_symbol(r, t)) # D
``````
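If you want to see the codewords themselves rather than a tree, the same steps can emit a dictionary instead. This is a sketch with an illustrative name, `bl_list_to_codes()`, that mirrors `bl_list_to_tree()` step for step; for the `ABCD` example it reproduces the canonical table given earlier in this section.

``````python
def bl_list_to_codes(bl, alphabet):
    # Same canonical-code construction as bl_list_to_tree(), but returns {symbol: bit string}
    MAX_BITS = max(bl)
    bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS + 1)]
    next_code = [0, 0]
    for bits in range(2, MAX_BITS + 1):
        next_code.append((next_code[bits - 1] + bl_count[bits - 1]) << 1)
    codes = {}
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            # format the codeword as exactly `bitlen` binary digits
            codes[c] = format(next_code[bitlen], '0{}b'.format(bitlen))
            next_code[bitlen] += 1
    return codes

print(bl_list_to_codes([2, 1, 3, 3], 'ABCD'))  # {'A': '10', 'B': '0', 'C': '110', 'D': '111'}
``````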

However, the DEFLATE spec adds yet another twist. Rather than just specifying the code lengths directly, for "even greater compactness", the code length sequences themselves are compressed using a Huffman code! The alphabet for the code lengths is as follows:

| Symbol | Meaning |
|---|---|
| 0 - 15 | Represent code lengths of 0 - 15 |
| 16 | Copy the previous code length 3 - 6 times. The next 2 bits indicate repeat length (0 = 3, ..., 3 = 6) |
| 17 | Repeat a code length of 0 for 3 - 10 times. (3 bits of length) |
| 18 | Repeat a code length of 0 for 11 - 138 times. (7 bits of length) |
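To see how symbols 16 to 18 expand, here is a sketch that works on already-decoded `(symbol, extra)` pairs, where `extra` stands for the integer that would be read with `read_bits()` right after the symbol. The pair format and the name `expand_code_lengths()` are made up for this illustration. The offsets 3, 3 and 11 come from the table: the extra bits count upwards from the smallest repeat length.

``````python
def expand_code_lengths(pairs):
    # Hypothetical input: (symbol, extra) pairs; extra is ignored for symbols 0..15
    bl = []
    for sym, extra in pairs:
        if 0 <= sym <= 15:
            bl.append(sym)                     # a literal code length
        elif sym == 16:
            bl.extend([bl[-1]] * (3 + extra))  # previous length, 3..6 times (2 extra bits)
        elif sym == 17:
            bl.extend([0] * (3 + extra))       # 3..10 zeros (3 extra bits)
        elif sym == 18:
            bl.extend([0] * (11 + extra))      # 11..138 zeros (7 extra bits)
        else:
            raise ValueError('invalid code length symbol')
    return bl

print(expand_code_lengths([(8, 0), (16, 1), (17, 0), (5, 0)]))
# [8, 8, 8, 8, 8, 0, 0, 0, 5]
``````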

To make this clear: we actually now have three huffman codes in play here --- the code length huffman code, the literal/length huffman code and the distance huffman code. For clarity let's give them mathematical symbols as follows:

• Let $A_{CL}$ be the code length alphabet (the table directly above).
• Let $T_{CL}$ be the code length huffman tree.
• Let $BL_{CL}$ be the code length list of $T_{CL}$.
• Let $T_{LL}$ be the literal/length huffman tree.
• Let $BL_{LL}$ be the code length list of $T_{LL}$.
• Let $T_{D}$ be the distance huffman tree.
• Let $BL_{D}$ be the code length list of $T_{D}$.
• Let $E$ be $BL_{LL}$ followed by $BL_{D}$, encoded using the code length alphabet $A_{CL}$ and then compressed using the code length huffman tree $T_{CL}$.

The code length, literal/length and distance huffman trees are stored as bits in the DEFLATE file as follows:

• 5 bits: `HLIT`, the number of literal/length codes ($|BL_{LL}|$) minus 257. In other words, $|BL_{LL}| = \text{HLIT} + 257$.
• 5 bits: `HDIST`: the number of distance codes ($|BL_{D}|$) minus 1. In other words, $|BL_{D}| = \text{HDIST} + 1$.
• 4 bits: `HCLEN`: the number of code length codes ($|BL_{CL}|$) minus 4. In other words, $|BL_{CL}| = \text{HCLEN} + 4$.
• (`HCLEN` + 4) x 3 bits: code lengths for the code length alphabet given just above, in the order: 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15. In other words, these are the elements of $BL_{CL}$, but in a different order, so we will need to rearrange them back (any code length symbol that is not listed gets a code length of zero).
• Variable number of bits follows, containing the code length list for the literal/length alphabet ($BL_{LL}$) followed by the code length list for the distance alphabet ($BL_{D}$), encoded using the code length alphabet and huffman code. In other words, this is $E$.

To wrap everything up, let's implement a function `decode_trees()` which will read the relevant bits from the input `BitReader` stream and decode them into the literal/length and distance huffman trees:

``````
CodeLengthCodesOrder = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]

def decode_trees(r):
    # The number of literal/length codes
    HLIT = r.read_bits(5) + 257

    # The number of distance codes
    HDIST = r.read_bits(5) + 1

    # The number of code length codes
    HCLEN = r.read_bits(4) + 4

    # Read code lengths for the code length alphabet
    code_length_tree_bl = [0 for _ in range(19)]
    for i in range(HCLEN):
        code_length_tree_bl[CodeLengthCodesOrder[i]] = r.read_bits(3)

    # Construct code length tree
    code_length_tree = bl_list_to_tree(code_length_tree_bl, range(19))

    # Read literal/length + distance code length list
    bl = []
    while len(bl) < HLIT + HDIST:
        sym = decode_symbol(r, code_length_tree)
        if 0 <= sym <= 15: # literal value
            bl.append(sym)
        elif sym == 16:
            # copy the previous code length 3..6 times.
            # the next 2 bits indicate repeat length ( 0 = 3, ..., 3 = 6 )
            prev_code_length = bl[-1]
            repeat_length = r.read_bits(2) + 3
            bl.extend(prev_code_length for _ in range(repeat_length))
        elif sym == 17:
            # repeat code length 0 for 3..10 times. (3 bits of length)
            repeat_length = r.read_bits(3) + 3
            bl.extend(0 for _ in range(repeat_length))
        elif sym == 18:
            # repeat code length 0 for 11..138 times. (7 bits of length)
            repeat_length = r.read_bits(7) + 11
            bl.extend(0 for _ in range(repeat_length))
        else:
            raise Exception('invalid symbol')

    # Construct trees
    literal_length_tree = bl_list_to_tree(bl[:HLIT], range(286))
    distance_tree = bl_list_to_tree(bl[HLIT:], range(30))
    return literal_length_tree, distance_tree
``````

## Putting everything together: handling compressed blocks

We are now ready to write the implementations of `inflate_block_dynamic()` (decompress block compressed with dynamic huffman codes) and `inflate_block_fixed()` (decompress block with fixed huffman codes)!

With the `decode_trees()` and `inflate_block_data()` functions we have painstakingly implemented over multiple sections, it is now trivial to implement `inflate_block_dynamic()`: we first read in and decode the literal/length and distance huffman trees with `decode_trees()`, and then call `inflate_block_data()` to inflate the compressed data that follows:

``````
def inflate_block_dynamic(r, o):
    # decompress block with dynamic huffman codes
    literal_length_tree, distance_tree = decode_trees(r)
    inflate_block_data(r, literal_length_tree, distance_tree, o)
``````
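Before `decode_trees()` can build anything, it reads three small header fields: HLIT (5 bits), HDIST (5 bits) and HCLEN (4 bits). Each field is stored least significant bit first and offset by 257, 1 and 4 respectively. The sketch below is self-contained. It uses a hypothetical `read_dynamic_header()` helper with its own minimal LSB-first bit reader, in place of the article's `BitReader`, and parses the 3-bit block header followed by these three fields from a hand-built bytestring:

```python
def read_dynamic_header(data):
    """Hypothetical helper: parse BFINAL, BTYPE, HLIT, HDIST and HCLEN from `data`."""
    pos = 0  # index of the next bit to read

    def read_bits(n):
        # DEFLATE packs fields starting at the least significant bit of each byte,
        # and the first bit read becomes the least significant bit of the value
        nonlocal pos
        value = 0
        for i in range(n):
            bit = (data[pos // 8] >> (pos % 8)) & 1
            value |= bit << i
            pos += 1
        return value

    bfinal = read_bits(1)
    btype = read_bits(2)
    hlit = read_bits(5) + 257   # number of literal/length codes
    hdist = read_bits(5) + 1    # number of distance codes
    hclen = read_bits(4) + 4    # number of code length codes
    return bfinal, btype, hlit, hdist, hclen

# Pack BFINAL=1, BTYPE=2, HLIT-257=29, HDIST-1=29, HCLEN-4=15 into one integer,
# then serialise it little-endian so bit 0 lands in the first byte's LSB.
packed = 1 | (2 << 1) | (29 << 3) | (29 << 8) | (15 << 13)
print(read_dynamic_header(packed.to_bytes(3, 'little')))  # (1, 2, 286, 30, 19)
```

Packing the fields into an integer and serialising it little-endian is a convenient way to build LSB-first test input by hand, because bit `i` of the integer ends up as bit `i % 8` of byte `i // 8`.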

With `inflate_block_fixed()`, the literal/length and distance Huffman codes are not specified in the compressed data. Instead, they are fixed by the DEFLATE spec and known to the decompressor beforehand.

The DEFLATE spec specifies the literal/length Huffman code in terms of bitlengths, which we can then use to construct a canonical Huffman tree with `bl_list_to_tree()` that we implemented earlier:

| Literal/length symbol | Bit length of codeword |
|---|---|
| 0 - 143 | 8 |
| 144 - 255 | 9 |
| 256 - 279 | 7 |
| 280 - 287 | 8 |

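Eagle-eyed readers will notice that literal/length symbols 286-287 appear in the table but are not actually part of the literal/length alphabet. They never occur in compressed data, but they still participate in the code construction in `bl_list_to_tree()`, specifically in the computation of `bl_count`. This is why `inflate_block_fixed()` passes all 288 bit lengths but only `range(286)` as the alphabet: `zip()` stops after symbol 285, so symbols 286-287 get no tree entries, while their bit lengths are still counted in `bl_count`.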
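To see why symbols 286-287 matter, the sketch below recomputes the canonical codes for the fixed literal/length table using the same `bl_count`/`next_code` procedure as `bl_list_to_tree()`, but returns plain integers instead of building a tree. `canonical_codes()` is a hypothetical helper, not part of the decompressor. With all 288 bit lengths, it reproduces the code ranges listed in RFC 1951 (for example, symbol 144 gets the 9-bit code `110010000`). If the last two bit lengths are dropped, every 9-bit code shifts:

```python
def canonical_codes(bl):
    """Return the canonical Huffman code of each symbol (None for bit length 0)."""
    max_bits = max(bl)
    # bl_count[n] = number of symbols whose code is n bits long (length 0 ignored)
    bl_count = [0] * (max_bits + 1)
    for length in bl:
        if length:
            bl_count[length] += 1
    # next_code[n] = smallest code of length n
    next_code = [0] * (max_bits + 1)
    code = 0
    for bits in range(1, max_bits + 1):
        code = (code + bl_count[bits - 1]) << 1
        next_code[bits] = code
    # hand out consecutive codes in symbol order
    codes = []
    for length in bl:
        if length:
            codes.append(next_code[length])
            next_code[length] += 1
        else:
            codes.append(None)
    return codes

FIXED_BL = [8] * 144 + [9] * 112 + [7] * 24 + [8] * 8  # symbols 0..287

full = canonical_codes(FIXED_BL)
print(format(full[0], '08b'), format(full[144], '09b'), format(full[256], '07b'))
# 00110000 110010000 0000000

truncated = canonical_codes(FIXED_BL[:286])  # pretend 286-287 do not exist
print(format(truncated[144], '09b'))  # 110001100, which no longer matches the spec
```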

Likewise, the DEFLATE spec specifies the distance huffman code in terms of bitlengths:

| Distance symbol | Bit length of codeword |
|---|---|
| 0 - 29 | 5 |
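Because all 30 distance symbols have the same length, the canonical construction assigns each symbol `s` the 5-bit code `s` itself. One subtlety is bit order. Huffman codes are packed starting with their most significant bit, whereas integers read with `read_bits()` start with the least significant bit. The sketch below uses two hypothetical helpers, `read_huffman_code()` and `read_lsb_first()`, that operate on a list of bits in stream order. It shows how the same five bits decode differently under the two rules:

```python
def read_huffman_code(bits):
    """Interpret bits in stream order as a Huffman code: the first bit is the MSB."""
    code = 0
    for bit in bits:
        code = (code << 1) | bit
    return code

def read_lsb_first(bits):
    """Interpret bits in stream order the way read_bits() does: the first bit is the LSB."""
    value = 0
    for i, bit in enumerate(bits):
        value |= bit << i
    return value

stream_bits = [0, 0, 1, 0, 1]  # five bits as they come out of read_bit()
print(read_huffman_code(stream_bits))  # 5, i.e. fixed distance symbol 5
print(read_lsb_first(stream_bits))     # 20, what read_bits(5) would return
```

This is why `decode_symbol()` walks the tree one `read_bit()` at a time, treating the first bit as the branch nearest the root, instead of calling `read_bits()`.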

(Figures: the fixed literal/length Huffman tree and the fixed distance Huffman tree.)

We can implement `inflate_block_fixed()` as follows:

``````
def inflate_block_fixed(r, o):
    bl = ([8 for _ in range(144)] + [9 for _ in range(144, 256)] +
          [7 for _ in range(256, 280)] + [8 for _ in range(280, 288)])
    literal_length_tree = bl_list_to_tree(bl, range(286))

    bl = [5 for _ in range(30)]
    distance_tree = bl_list_to_tree(bl, range(30))

    inflate_block_data(r, literal_length_tree, distance_tree, o)
``````
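To exercise `inflate_block_fixed()` in isolation, it can help to hand-assemble a fixed-Huffman block. The sketch below is an encoder-side illustration and not part of the decompressor. The hypothetical `BitWriter` class and the hypothetical helpers `fixed_literal_code()` and `fixed_block_of_literals()` write a single final block (BFINAL=1, BTYPE=1) that contains only literal bytes followed by the end-of-block symbol 256. Header fields are written least significant bit first, and Huffman codes are written most significant bit first. Python's own `zlib.decompress()` with `wbits=-15` (raw DEFLATE with no zlib header or trailer) serves as a cross-check:

```python
import zlib

def fixed_literal_code(sym):
    """(code, bit length) of fixed literal/length symbol `sym` (RFC 1951, section 3.2.6)."""
    if sym <= 143:
        return 0b00110000 + sym, 8
    if sym <= 255:
        return 0b110010000 + (sym - 144), 9
    if sym <= 279:
        return sym - 256, 7
    return 0b11000000 + (sym - 280), 8

class BitWriter:
    """Packs bits into bytes starting at each byte's least significant bit."""
    def __init__(self):
        self.out = bytearray()
        self.nbits = 0

    def write_bit(self, bit):
        if self.nbits % 8 == 0:
            self.out.append(0)
        self.out[-1] |= bit << (self.nbits % 8)
        self.nbits += 1

    def write_bits(self, value, n):
        # header fields and extra bits: least significant bit first
        for i in range(n):
            self.write_bit((value >> i) & 1)

    def write_code(self, code, n):
        # Huffman codes: most significant bit first
        for i in range(n - 1, -1, -1):
            self.write_bit((code >> i) & 1)

def fixed_block_of_literals(data):
    """Hypothetical helper: one final fixed-Huffman block holding only literals."""
    w = BitWriter()
    w.write_bits(1, 1)  # BFINAL = 1
    w.write_bits(1, 2)  # BTYPE = 1 (fixed Huffman codes)
    for byte in data:
        w.write_code(*fixed_literal_code(byte))
    w.write_code(*fixed_literal_code(256))  # end-of-block
    return bytes(w.out)

raw = fixed_block_of_literals(b'Hello World!')
print(zlib.decompress(raw, wbits=-15))  # b'Hello World!'
```

Because `raw` has no zlib header or trailer, it cannot be passed to this article's `decompress()` directly. Calling `inflate(BitReader(raw))` should skip the header handling.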

## Full source code

And we are done! We can test out our newly-minted `decompress()` function with:

``````
import zlib
x = zlib.compress(b'Hello World!')
print(decompress(x)) # b'Hello World!'
``````
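Our `decompress()` does not check the Adler-32 checksum stored in the last four bytes of a zlib stream. The following is a minimal sketch of how that check could be added. It assumes the whole stream is in memory, and `adler32()` and `check_trailer()` are hypothetical helpers. Note that, unlike the little-endian `LEN`/`NLEN` fields read with `read_bytes()`, the zlib trailer is stored most significant byte first:

```python
import zlib

def adler32(data):
    """Adler-32 (RFC 1950): two running sums modulo 65521, packed as (b << 16) | a."""
    a, b = 1, 0
    for byte in data:
        a = (a + byte) % 65521
        b = (b + a) % 65521
    return (b << 16) | a

def check_trailer(stream, decompressed):
    """Hypothetical helper: compare a zlib stream's trailer with the decompressed output."""
    expected = int.from_bytes(stream[-4:], 'big')  # big-endian, unlike read_bytes()
    return expected == adler32(decompressed)

payload = b'Hello World!'
stream = zlib.compress(payload)
print(adler32(payload) == zlib.adler32(payload))  # True
print(check_trailer(stream, payload))              # True
```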

Here is the full source code listing:

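``````
class BitReader:
    def __init__(self, mem):
        self.mem = mem
        self.pos = 0
        self.b = 0
        self.numbits = 0

    def read_byte(self):
        self.numbits = 0 # discard any unread bits of the current byte
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        # shift bit out of byte
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        # read n bits as an integer, least significant bit first
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

    def read_bytes(self, n):
        # read bytes as an integer in little-endian
        o = 0
        for i in range(n):
            o |= self.read_byte() << (8 * i)
        return o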

def decompress(input):
    r = BitReader(input)
    CMF = r.read_byte()
    CM = CMF & 15 # Compression method
    if CM != 8: # only CM=8 is supported
        raise Exception('invalid CM')
    CINFO = (CMF >> 4) & 15 # Compression info
    if CINFO > 7:
        raise Exception('invalid CINFO')
    FLG = r.read_byte()
    if (CMF * 256 + FLG) % 31 != 0:
        raise Exception('CMF+FLG checksum failed')
    FDICT = (FLG >> 5) & 1 # preset dictionary?
    if FDICT:
        raise Exception('preset dictionary not supported')
    out = inflate(r) # decompress DEFLATE data
    # the 4-byte Adler-32 checksum that follows is ignored in this exercise
    return out

def inflate(r):
    BFINAL = 0
    out = []
    while not BFINAL:
        BFINAL = r.read_bit() # 1 if this is the last block
        BTYPE = r.read_bits(2) # block type
        if BTYPE == 0:
            inflate_block_no_compression(r, out)
        elif BTYPE == 1:
            inflate_block_fixed(r, out)
        elif BTYPE == 2:
            inflate_block_dynamic(r, out)
        else:
            raise Exception('invalid BTYPE')
    return bytes(out)

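def inflate_block_no_compression(r, o):
    # read_byte() discards the rest of the current byte, aligning to a byte boundary
    LEN = r.read_bytes(2) # number of data bytes in the block
    NLEN = r.read_bytes(2) # one's complement of LEN (not checked here)
    o.extend(r.read_byte() for _ in range(LEN))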

def code_to_bytes(code, n):
    # Encodes a code that is `n` bits long into bytes that is conformant with DEFLATE spec
    out = [0]
    numbits = 0
    for i in range(n-1, -1, -1):
        if numbits >= 8:
            out.append(0)
            numbits = 0
        out[-1] |= (1 if code & (1 << i) else 0) << numbits
        numbits += 1
    return bytes(out)

class Node:
    def __init__(self):
        self.symbol = ''
        self.left = None
        self.right = None

class HuffmanTree:
    def __init__(self):
        self.root = Node()
        self.root.symbol = ''

    def insert(self, codeword, n, symbol):
        # Insert an entry into the tree mapping `codeword` of len `n` to `symbol`
        node = self.root
        for i in range(n-1, -1, -1):
            b = codeword & (1 << i)
            if b:
                next_node = node.right
                if next_node is None:
                    node.right = Node()
                    next_node = node.right
            else:
                next_node = node.left
                if next_node is None:
                    node.left = Node()
                    next_node = node.left
            node = next_node
        node.symbol = symbol

def decode_symbol(r, t):
    "Decodes one symbol from bitstream `r` using HuffmanTree `t`"
    node = t.root
    while node.left or node.right:
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol

LengthExtraBits = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3,
3, 4, 4, 4, 4, 5, 5, 5, 5, 0]
LengthBase = [3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43,
51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258]
DistanceExtraBits = [0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13]
DistanceBase = [1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257,
385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385,
24577]

def inflate_block_data(r, literal_length_tree, distance_tree, out):
    while True:
        sym = decode_symbol(r, literal_length_tree)
        if sym <= 255: # Literal byte
            out.append(sym)
        elif sym == 256: # End of block
            return
        else: # <length, backward distance> pair
            sym -= 257
            length = r.read_bits(LengthExtraBits[sym]) + LengthBase[sym]
            dist_sym = decode_symbol(r, distance_tree)
            dist = r.read_bits(DistanceExtraBits[dist_sym]) + DistanceBase[dist_sym]
            for _ in range(length):
                out.append(out[-dist])

def bl_list_to_tree(bl, alphabet):
    MAX_BITS = max(bl)
    bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS+1)]
    next_code = [0, 0]
    for bits in range(2, MAX_BITS+1):
        next_code.append((next_code[bits-1] + bl_count[bits-1]) << 1)
    t = HuffmanTree()
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            next_code[bitlen] += 1
    return t

CodeLengthCodesOrder = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]

def decode_trees(r):
    # The number of literal/length codes
    HLIT = r.read_bits(5) + 257

    # The number of distance codes
    HDIST = r.read_bits(5) + 1

    # The number of code length codes
    HCLEN = r.read_bits(4) + 4

    # Read code lengths for the code length alphabet
    code_length_tree_bl = [0 for _ in range(19)]
    for i in range(HCLEN):
        code_length_tree_bl[CodeLengthCodesOrder[i]] = r.read_bits(3)

    # Construct code length tree
    code_length_tree = bl_list_to_tree(code_length_tree_bl, range(19))

    # Read literal/length + distance code length list
    bl = []
    while len(bl) < HLIT + HDIST:
        sym = decode_symbol(r, code_length_tree)
        if 0 <= sym <= 15: # literal value
            bl.append(sym)
        elif sym == 16:
            # copy the previous code length 3..6 times.
            # the next 2 bits indicate repeat length ( 0 = 3, ..., 3 = 6 )
            prev_code_length = bl[-1]
            repeat_length = r.read_bits(2) + 3
            bl.extend(prev_code_length for _ in range(repeat_length))
        elif sym == 17:
            # repeat code length 0 for 3..10 times. (3 bits of length)
            repeat_length = r.read_bits(3) + 3
            bl.extend(0 for _ in range(repeat_length))
        elif sym == 18:
            # repeat code length 0 for 11..138 times. (7 bits of length)
            repeat_length = r.read_bits(7) + 11
            bl.extend(0 for _ in range(repeat_length))
        else:
            raise Exception('invalid symbol')

    # Construct trees
    literal_length_tree = bl_list_to_tree(bl[:HLIT], range(286))
    distance_tree = bl_list_to_tree(bl[HLIT:], range(30))
    return literal_length_tree, distance_tree

def inflate_block_dynamic(r, o):
    literal_length_tree, distance_tree = decode_trees(r)
    inflate_block_data(r, literal_length_tree, distance_tree, o)

def inflate_block_fixed(r, o):
    bl = ([8 for _ in range(144)] + [9 for _ in range(144, 256)] +
          [7 for _ in range(256, 280)] + [8 for _ in range(280, 288)])
    literal_length_tree = bl_list_to_tree(bl, range(286))

    bl = [5 for _ in range(30)]
    distance_tree = bl_list_to_tree(bl, range(30))

    inflate_block_data(r, literal_length_tree, distance_tree, o)

import zlib
x = zlib.compress(b'Hello World!')
print(decompress(x)) # b'Hello World!'
``````
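The header checks at the top of `decompress()` can also be tried on their own. The sketch below uses a hypothetical `parse_zlib_header()` helper to split the two header bytes of a stream produced by Python's `zlib.compress()` into the fields defined in RFC 1950. It also reports FLEVEL, a compression-level hint that `decompress()` reads past without checking:

```python
import zlib

def parse_zlib_header(stream):
    """Hypothetical helper: split the two zlib header bytes into their fields (RFC 1950)."""
    cmf, flg = stream[0], stream[1]
    return {
        'CM': cmf & 15,                            # compression method, 8 = DEFLATE
        'CINFO': (cmf >> 4) & 15,                  # log2(window size) - 8
        'FCHECK_OK': (cmf * 256 + flg) % 31 == 0,  # the check decompress() performs
        'FDICT': (flg >> 5) & 1,                   # preset dictionary present?
        'FLEVEL': (flg >> 6) & 3,                  # compression level hint, informational
    }

print(parse_zlib_header(zlib.compress(b'Hello World!')))
# {'CM': 8, 'CINFO': 7, 'FCHECK_OK': True, 'FDICT': 0, 'FLEVEL': 2}
```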