Decompression optimizations
[wimlib] / src / decompress.c
1 /*
2  * decompress.c
3  *
4  * Functions used for decompression.
5  */
6
7 /*
8  * Copyright (C) 2012 Eric Biggers
9  *
10  * This file is part of wimlib, a library for working with WIM files.
11  *
12  * wimlib is free software; you can redistribute it and/or modify it under the
13  * terms of the GNU General Public License as published by the Free
14  * Software Foundation; either version 3 of the License, or (at your option)
15  * any later version.
16  *
17  * wimlib is distributed in the hope that it will be useful, but WITHOUT ANY
18  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
19  * A PARTICULAR PURPOSE. See the GNU General Public License for more
20  * details.
21  *
22  * You should have received a copy of the GNU General Public License
23  * along with wimlib; if not, see http://www.gnu.org/licenses/.
24  */
25
26 #include "decompress.h"
27 #include <string.h>
28
29 /* Reads @n bytes from the bitstream @stream into the location pointed to by @dest.
30  * The bitstream must be 16-bit aligned. */
31 int bitstream_read_bytes(struct input_bitstream *stream, size_t n, void *dest)
32 {
33         /* Precondition:  The bitstream is 16-byte aligned. */
34         wimlib_assert2(stream->bitsleft % 16 == 0);
35
36         u8 *p = dest;
37
38         /* Get the bytes currently in the buffer variable. */
39         while (stream->bitsleft != 0) {
40                 if (n-- == 0)
41                         return 0;
42                 *p++ = bitstream_peek_bits(stream, 8);
43                 bitstream_remove_bits(stream, 8);
44         }
45
46         /* Get the rest directly from the pointer to the data.  Of course, it's
47          * necessary to check there are really n bytes available. */
48         if (n > stream->data_bytes_left) {
49                 ERROR("Unexpected end of input when reading %zu bytes from "
50                       "bitstream (only have %u bytes left)",
51                       n, stream->data_bytes_left);
52                 return 1;
53         }
54         memcpy(p, stream->data, n);
55         stream->data += n;
56         stream->data_bytes_left -= n;
57
58         /* It's possible to copy an odd number of bytes and leave the stream in
59          * an inconsistent state. Fix it by reading the next byte, if it is
60          * there. */
61         if ((n & 1) && stream->data_bytes_left != 0) {
62                 stream->bitsleft = 8;
63                 stream->data_bytes_left--;
64                 stream->bitbuf |= (input_bitbuf_t)(*stream->data) <<
65                                         (sizeof(input_bitbuf_t) * 8 - 8);
66                 stream->data++;
67         }
68         return 0;
69 }
70
71 /*
72  * Builds a fast huffman decoding table from a canonical huffman code lengths
73  * table.  Based on code written by David Tritscher.
74  *
75  * @decode_table:       The array in which to create the fast huffman decoding
76  *                              table.  It must have a length of at least
77  *                              (2**num_bits) + 2 * num_syms to guarantee
78  *                              that there is enough space.
79  *
80  * @num_syms:   Total number of symbols in the Huffman tree.
81  *
82  * @num_bits:   Any symbols with a code length of num_bits or less can be
83  *                      decoded in one lookup of the table.  2**num_bits
84  *                      must be greater than or equal to @num_syms if there are
85  *                      any Huffman codes longer than @num_bits.
86  *
87  * @lens:       An array of length @num_syms, indexable by symbol, that
88  *                      gives the length of that symbol.  Because the Huffman
89  *                      tree is in canonical form, it can be reconstructed by
90  *                      only knowing the length of the code for each symbol.
91  *
92  * @make_codeword_len:  An integer that gives the longest possible codeword
93  *                      length.
94  *
95  * Returns 0 on success; returns 1 if the length values do not correspond to a
96  * valid Huffman tree, or if there are codes of length greater than @num_bits
97  * but 2**num_bits < num_syms.
98  *
99  * What exactly is the format of the fast Huffman decoding table?  The first
100  * (1 << num_bits) entries of the table are indexed by chunks of the input of
101  * size @num_bits.  If the next Huffman code in the input happens to have a
102  * length of exactly @num_bits, the symbol is simply read directly from the
103  * decoding table.  Alternatively, if the next Huffman code has length _less
104  * than_ @num_bits, the symbol is also read directly from the decode table; this
105  * is possible because every entry in the table that is indexed by an integer
106  * that has the shorter code as a binary prefix is filled in with the
107  * appropriate symbol.  If a code has length n <= num_bits, it will have
108  * 2**(num_bits - n) possible suffixes, and thus that many entries in the
109  * decoding table.
110  *
111  * It's a bit more complicated if the next Huffman code has length of more than
112  * @num_bits.  The table entry indexed by the first @num_bits of that code
113  * cannot give the appropriate symbol directly, because that entry is guaranteed
114  * to be referenced by the Huffman codes for multiple symbols.  And while the
115  * LZX compression format does not allow codes longer than 16 bits, a table of
116  * size (2 ** 16) = 65536 entries would be too slow to create.
117  *
118  * There are several different ways to make it possible to look up the symbols
119  * for codes longer than @num_bits.  A common way is to make the entries for the
120  * prefixes of length @num_bits of those entries be pointers to additional
121  * decoding tables that are indexed by some number of additional bits of the
122  * code symbol.  The technique used here is a bit simpler, however.  We just
123  * store the needed subtrees of the Huffman tree in the decoding table after the
124  * lookup entries, beginning at index (2**num_bits).  Real pointers are
125  * replaced by indices into the decoding table, and we distinguish symbol
126  * entries from pointers by the fact that values less than @num_syms must be
127  * symbol values.
128  */
129 int make_huffman_decode_table(u16 decode_table[],  unsigned num_syms,
130                               unsigned num_bits, const u8 lens[],
131                               unsigned max_code_len)
132 {
133         /* Number of entries in the decode table. */
134         u32 table_num_entries = 1 << num_bits;
135
136         /* Current position in the decode table. */
137         u32 decode_table_pos = 0;
138
139         /* Fill entries for codes short enough for a direct mapping.  Here we
140          * are taking advantage of the ordering of the codes, since they are for
141          * a canonical Huffman tree.  It must be the case that all the codes of
142          * some length @code_length, zero-extended or one-extended, numerically
143          * precede all the codes of length @code_length + 1.  Furthermore, if we
144          * have 2 symbols A and B, such that A is listed before B in the lens
145          * array, and both symbols have the same code length, then we know that
146          * the code for A numerically precedes the code for B.
147          * */
148         for (unsigned code_len = 1; code_len <= num_bits; code_len++) {
149
150                 /* Number of entries that a code of length @code_length would
151                  * need.  */
152                 u32 code_num_entries = 1 << (num_bits - code_len);
153
154
155                 /* For each symbol of length @code_len, fill in its entries in
156                  * the decode table. */
157                 for (unsigned sym = 0; sym < num_syms; sym++) {
158
159                         if (lens[sym] != code_len)
160                                 continue;
161
162
163                         /* Check for table overrun.  This can only happen if the
164                          * given lengths do not correspond to a valid Huffman
165                          * tree.  */
166                         if (decode_table_pos >= table_num_entries) {
167                                 ERROR("Huffman decoding table overrun: "
168                                       "pos = %u, num_entries = %u",
169                                       decode_table_pos, table_num_entries);
170                                 return 1;
171                         }
172
173                         /* Fill all possible lookups of this symbol with
174                          * the symbol itself. */
175                         for (unsigned i = 0; i < code_num_entries; i++)
176                                 decode_table[decode_table_pos + i] = sym;
177
178                         /* Increment the position in the decode table by
179                          * the number of entries that were just filled
180                          * in. */
181                         decode_table_pos += code_num_entries;
182                 }
183         }
184
185         /* If all entries of the decode table have been filled in, there are no
186          * codes longer than num_bits, so we are done filling in the decode
187          * table. */
188         if (decode_table_pos == table_num_entries)
189                 return 0;
190
191         /* Otherwise, fill in the remaining entries, which correspond to codes longer
192          * than @num_bits. */
193
194
195         /* First, zero out the rest of the entries; this is necessary so
196          * that the entries appear as "unallocated" in the next part.  */
197         for (unsigned i = decode_table_pos; i < table_num_entries; i++)
198                 decode_table[i] = 0;
199
200         /* Assert that 2**num_bits is at least num_syms.  If this wasn't the
201          * case, we wouldn't be able to distinguish pointer entries from symbol
202          * entries. */
203         wimlib_assert((1 << num_bits) >= num_syms);
204
205
206         /* The current Huffman code.  */
207         unsigned current_code = decode_table_pos;
208
209         /* The tree nodes are allocated starting at
210          * decode_table[table_num_entries].  Remember that the full size of the
211          * table, including the extra space for the tree nodes, is actually
212          * 2**num_bits + 2 * num_syms slots, while table_num_entries is only
213          * 2**num_bits. */
214         unsigned next_free_tree_slot = table_num_entries;
215
216         /* Go through every codeword of length greater than @num_bits.  Note:
217          * the LZX format guarantees that the codeword length can be at most 16
218          * bits. */
219         for (unsigned code_len = num_bits + 1; code_len <= max_code_len;
220                                                         code_len++)
221         {
222                 current_code <<= 1;
223                 for (unsigned sym = 0; sym < num_syms; sym++) {
224                         if (lens[sym] != code_len)
225                                 continue;
226
227
228                         /* i is the index of the current node; find it from the
229                          * prefix of the current Huffman code. */
230                         unsigned i = current_code >> (code_len - num_bits);
231
232                         if (i >= (1 << num_bits)) {
233                                 ERROR("Invalid canonical Huffman code");
234                                 return 1;
235                         }
236
237                         /* Go through each bit of the current Huffman code
238                          * beyond the prefix of length num_bits and walk the
239                          * tree, "allocating" slots that have not yet been
240                          * allocated. */
241                         for (int bit_num = num_bits + 1; bit_num <= code_len; bit_num++) {
242
243                                 /* If the current tree node points to nowhere
244                                  * but we need to follow it, allocate a new node
245                                  * for it to point to. */
246                                 if (decode_table[i] == 0) {
247                                         decode_table[i] = next_free_tree_slot;
248                                         decode_table[next_free_tree_slot++] = 0;
249                                         decode_table[next_free_tree_slot++] = 0;
250                                 }
251
252                                 i = decode_table[i];
253
254                                 /* Is the next bit 0 or 1? If 0, go left;
255                                  * otherwise, go right (by incrementing i by 1) */
256                                 int bit_pos = code_len - bit_num;
257
258                                 int bit = (current_code & (1 << bit_pos)) >>
259                                                                 bit_pos;
260                                 i += bit;
261                         }
262
263                         /* i is now the index of the leaf entry into which the
264                          * actual symbol will go. */
265                         decode_table[i] = sym;
266
267                         /* Increment decode_table_pos only if the prefix of the
268                          * Huffman code changes. */
269                         if (current_code >> (code_len - num_bits) !=
270                                         (current_code + 1) >> (code_len - num_bits))
271                                 decode_table_pos++;
272
273                         /* current_code is always incremented because this is
274                          * how canonical Huffman codes are generated (add 1 for
275                          * each code, then left shift whenever the code length
276                          * increases) */
277                         current_code++;
278                 }
279         }
280
281
282         /* If the lengths really represented a valid Huffman tree, all
283          * @table_num_entries in the table will have been filled.  However, it
284          * is also possible that the tree is completely empty (as noted
285          * earlier) with all 0 lengths, and this is expected to succeed. */
286
287         if (decode_table_pos != table_num_entries) {
288
289                 for (unsigned i = 0; i < num_syms; i++) {
290                         if (lens[i] != 0) {
291                                 ERROR("Lengths do not form a valid canonical "
292                                       "Huffman tree (only filled %u of %u "
293                                       "decode table slots)",
294                                       decode_table_pos, table_num_entries);
295                                 return 1;
296                         }
297                 }
298         }
299         return 0;
300 }
301
302 /* Reads a Huffman-encoded symbol when it is known there are less than
303  * MAX_CODE_LEN bits remaining in the bitstream. */
304 int read_huffsym_near_end_of_input(struct input_bitstream *istream,
305                                    const u16 decode_table[],
306                                    const u8 lens[],
307                                    unsigned num_syms,
308                                    unsigned table_bits,
309                                    unsigned *n)
310 {
311         unsigned bitsleft = istream->bitsleft;
312         unsigned key_size;
313         u16 sym;
314         u16 key_bits;
315
316         if (table_bits > bitsleft) {
317                 key_size = bitsleft;
318                 bitsleft = 0;
319                 key_bits = bitstream_peek_bits(istream, key_size) <<
320                                                 (table_bits - key_size);
321         } else {
322                 key_size = table_bits;
323                 bitsleft -= table_bits;
324                 key_bits = bitstream_peek_bits(istream, table_bits);
325         }
326
327         sym = decode_table[key_bits];
328         if (sym >= num_syms) {
329                 bitstream_remove_bits(istream, key_size);
330                 do {
331                         if (bitsleft == 0) {
332                                 ERROR("Input stream exhausted");
333                                 return 1;
334                         }
335                         key_bits = sym + bitstream_peek_bits(istream, 1);
336                         bitstream_remove_bits(istream, 1);
337                         bitsleft--;
338                 } while ((sym = decode_table[key_bits]) >= num_syms);
339         } else {
340                 bitstream_remove_bits(istream, lens[sym]);
341         }
342         *n = sym;
343         return 0;
344 }