77fdd251903732eb4c03b2aba20debbb0b06332d
[wimlib] / src / decompress_common.c
1 /*
2  * decompress_common.c
3  *
4  * Code for decompression shared among multiple compression formats.
5  */
6
7 /*
8  * Copyright (C) 2012, 2013, 2014 Eric Biggers
9  *
10  * This file is part of wimlib, a library for working with WIM files.
11  *
12  * wimlib is free software; you can redistribute it and/or modify it under the
13  * terms of the GNU General Public License as published by the Free
14  * Software Foundation; either version 3 of the License, or (at your option)
15  * any later version.
16  *
17  * wimlib is distributed in the hope that it will be useful, but WITHOUT ANY
18  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
19  * A PARTICULAR PURPOSE. See the GNU General Public License for more
20  * details.
21  *
22  * You should have received a copy of the GNU General Public License
23  * along with wimlib; if not, see http://www.gnu.org/licenses/.
24  */
25
26 #ifdef HAVE_CONFIG_H
27 #  include "config.h"
28 #endif
29
30 #include "wimlib/decompress_common.h"
31 #include "wimlib/error.h"
32 #include "wimlib/util.h" /* for BUILD_BUG_ON()  */
33
34 #include <string.h>
35
36 #ifdef __GNUC__
37 #  ifdef __SSE2__
38 #    define USE_SSE2_FILL
39 #    include <emmintrin.h>
40 #  else
41 #    define USE_LONG_FILL
42 #  endif
43 #endif
44
45 /* Construct a direct mapping entry in the lookup table.  */
46 #define MAKE_DIRECT_ENTRY(symbol, length) ((symbol) | ((length) << 11))
47
48 /*
49  * make_huffman_decode_table() -
50  *
51  * Build a decoding table for a canonical prefix code, or "Huffman code".
52  *
53  * This takes as input the length of the codeword for each symbol in the
54  * alphabet and produces as output a table that can be used for fast
55  * decoding of prefix-encoded symbols using read_huffsym().
56  *
57  * Strictly speaking, a canonical prefix code might not be a Huffman
58  * code.  But this algorithm will work either way; and in fact, since
59  * Huffman codes are defined in terms of symbol frequencies, there is no
60  * way for the decompressor to know whether the code is a true Huffman
61  * code or not until all symbols have been decoded.
62  *
63  * Because the prefix code is assumed to be "canonical", it can be
64  * reconstructed directly from the codeword lengths.  A prefix code is
65  * canonical if and only if a longer codeword never lexicographically
66  * precedes a shorter codeword, and the lexicographic ordering of
67  * codewords of the same length is the same as the lexicographic ordering
68  * of the corresponding symbols.  Consequently, we can sort the symbols
69  * primarily by codeword length and secondarily by symbol value, then
70  * reconstruct the prefix code by generating codewords lexicographically
71  * in that order.
72  *
73  * This function does not, however, generate the prefix code explicitly.
74  * Instead, it directly builds a table for decoding symbols using the
75  * code.  The basic idea is this: given the next 'max_codeword_len' bits
76  * in the input, we can look up the decoded symbol by indexing a table
77  * containing 2**max_codeword_len entries.  A codeword with length
78  * 'max_codeword_len' will have exactly one entry in this table, whereas
79  * a codeword shorter than 'max_codeword_len' will have multiple entries
80  * in this table.  Precisely, a codeword of length n will be represented
81  * by 2**(max_codeword_len - n) entries in this table.  The 0-based index
82  * of each such entry will contain the corresponding codeword as a prefix
83  * when zero-padded on the left to 'max_codeword_len' binary digits.
84  *
85  * That's the basic idea, but we implement two optimizations regarding
86  * the format of the decode table itself:
87  *
88  * - For many compression formats, the maximum codeword length is too
89  *   long for it to be efficient to build the full decoding table
90  *   whenever a new prefix code is used.  Instead, we can build the table
91  *   using only 2**table_bits entries, where 'table_bits' is some number
92  *   less than or equal to 'max_codeword_len'.  Then, only codewords of
93  *   length 'table_bits' and shorter can be directly looked up.  For
94  *   longer codewords, the direct lookup instead produces the root of a
95  *   binary tree.  Using this tree, the decoder can do traditional
96  *   bit-by-bit decoding of the remainder of the codeword.  Child nodes
97  *   are allocated in extra entries at the end of the table; leaf nodes
98  *   contain symbols.  Note that the long-codeword case is, in general,
99  *   not performance critical, since in Huffman codes the most frequently
100  *   used symbols are assigned the shortest codeword lengths.
101  *
102  * - When we decode a symbol using a direct lookup of the table, we still
103  *   need to know its length so that the bitstream can be advanced by the
104  *   appropriate number of bits.  The simple solution is to simply retain
105  *   the 'lens' array and use the decoded symbol as an index into it.
106  *   However, this requires two separate array accesses in the fast path.
107  *   The optimization is to store the length directly in the decode
108  *   table.  We use the bottom 11 bits for the symbol and the top 5 bits
109  *   for the length.  In addition, to combine this optimization with the
110  *   previous one, we introduce a special case where the top 2 bits of
111  *   the length are both set if the entry is actually the root of a
112  *   binary tree.
113  *
114  * @decode_table:
115  *      The array in which to create the decoding table.
116  *      This must be 16-byte aligned and must have a length of at least
117  *      ((2**table_bits) + 2 * num_syms) entries.
118  *
119  * @num_syms:
120  *      The number of symbols in the alphabet; also, the length of the
121  *      'lens' array.  Must be less than or equal to
122  *      DECODE_TABLE_MAX_SYMBOLS.
123  *
124  * @table_bits:
125  *      The order of the decode table size, as explained above.  Must be
126  *      less than or equal to DECODE_TABLE_MAX_TABLE_BITS.
127  *
128  * @lens:
129  *      An array of length @num_syms, indexable by symbol, that gives the
130  *      length of the codeword, in bits, for that symbol.  The length can
131  *      be 0, which means that the symbol does not have a codeword
132  *      assigned.
133  *
134  * @max_codeword_len:
135  *      The longest codeword length allowed in the compression format.
136  *      All entries in 'lens' must be less than or equal to this value.
137  *      This must be less than or equal to DECODE_TABLE_MAX_CODEWORD_LEN.
138  *
139  * Returns 0 on success, or -1 if the lengths do not form a valid prefix
140  * code.
141  */
142 int
143 make_huffman_decode_table(u16 decode_table[const restrict],
144                           const unsigned num_syms,
145                           const unsigned table_bits,
146                           const u8 lens[const restrict],
147                           const unsigned max_codeword_len)
148 {
149         const unsigned table_num_entries = 1 << table_bits;
150         unsigned len_counts[max_codeword_len + 1];
151         u16 sorted_syms[num_syms];
152         int left;
153         void *decode_table_ptr;
154         unsigned sym_idx;
155         unsigned codeword_len;
156         unsigned stores_per_loop;
157         unsigned decode_table_pos;
158
159 #ifdef USE_LONG_FILL
160         const unsigned entries_per_long = sizeof(unsigned long) / sizeof(decode_table[0]);
161 #endif
162
163 #ifdef USE_SSE2_FILL
164         const unsigned entries_per_xmm = sizeof(__m128i) / sizeof(decode_table[0]);
165 #endif
166
167         /* Check parameters if assertions are enabled.  */
168         wimlib_assert2((uintptr_t)decode_table % DECODE_TABLE_ALIGNMENT == 0);
169         wimlib_assert2(num_syms <= DECODE_TABLE_MAX_SYMBOLS);
170         wimlib_assert2(table_bits <= DECODE_TABLE_MAX_TABLE_BITS);
171         wimlib_assert2(max_codeword_len <= DECODE_TABLE_MAX_CODEWORD_LEN);
172         for (unsigned sym = 0; sym < num_syms; sym++)
173                 wimlib_assert2(lens[sym] <= max_codeword_len);
174
175         /* Count how many symbols have each possible codeword length.
176          * Note that a length of 0 indicates the corresponding symbol is not
177          * used in the code and therefore does not have a codeword.  */
178         for (unsigned len = 0; len <= max_codeword_len; len++)
179                 len_counts[len] = 0;
180         for (unsigned sym = 0; sym < num_syms; sym++)
181                 len_counts[lens[sym]]++;
182
183         /* We can assume all lengths are <= max_codeword_len, but we
184          * cannot assume they form a valid prefix code.  A codeword of
185          * length n should require a proportion of the codespace equaling
186          * (1/2)^n.  The code is valid if and only if the codespace is
187          * exactly filled by the lengths, by this measure.  */
188         left = 1;
189         for (unsigned len = 1; len <= max_codeword_len; len++) {
190                 left <<= 1;
191                 left -= len_counts[len];
192                 if (unlikely(left < 0)) {
193                         /* The lengths overflow the codespace; that is, the code
194                          * is over-subscribed.  */
195                         DEBUG("Invalid prefix code (over-subscribed)");
196                         return -1;
197                 }
198         }
199
200         if (unlikely(left != 0)) {
201                 /* The lengths do not fill the codespace; that is, they form an
202                  * incomplete set.  */
203                 if (left == (1 << max_codeword_len)) {
204                         /* The code is completely empty.  This is arguably
205                          * invalid, but in fact it is valid in LZX and XPRESS,
206                          * so we must allow it.  By definition, no symbols can
207                          * be decoded with an empty code.  Consequently, we
208                          * technically don't even need to fill in the decode
209                          * table.  However, to avoid accessing uninitialized
210                          * memory if the algorithm nevertheless attempts to
211                          * decode symbols using such a code, we zero out the
212                          * decode table.  */
213                         memset(decode_table, 0,
214                                table_num_entries * sizeof(decode_table[0]));
215                         return 0;
216                 }
217                 DEBUG("Invalid prefix code (incomplete set)");
218                 return -1;
219         }
220
221         /* Sort the symbols primarily by length and secondarily by symbol order.
222          */
223         {
224                 unsigned offsets[max_codeword_len + 1];
225
226                 /* Initialize 'offsets' so that offsets[len] for 1 <= len <=
227                  * max_codeword_len is the number of codewords shorter than
228                  * 'len' bits.  */
229                 offsets[1] = 0;
230                 for (unsigned len = 1; len < max_codeword_len; len++)
231                         offsets[len + 1] = offsets[len] + len_counts[len];
232
233                 /* Use the 'offsets' array to sort the symbols.
234                  * Note that we do not include symbols that are not used in the
235                  * code.  Consequently, fewer than 'num_syms' entries in
236                  * 'sorted_syms' may be filled.  */
237                 for (unsigned sym = 0; sym < num_syms; sym++)
238                         if (lens[sym] != 0)
239                                 sorted_syms[offsets[lens[sym]]++] = sym;
240         }
241
242         /* Fill entries for codewords with length <= table_bits
243          * --- that is, those short enough for a direct mapping.
244          *
245          * The table will start with entries for the shortest codeword(s), which
246          * have the most entries.  From there, the number of entries per
247          * codeword will decrease.  As an optimization, we may begin filling
248          * entries with SSE2 vector accesses (8 entries/store), then change to
249          * 'unsigned long' accesses (2 or 4 entries/store), then change to
250          * 16-bit accesses (1 entry/store).  */
251         decode_table_ptr = decode_table;
252         sym_idx = 0;
253         codeword_len = 1;
254 #ifdef USE_SSE2_FILL
255         /* Fill the entries one 128-bit vector at a time.
256          * This is 8 entries per store.  */
257         stores_per_loop = (1 << (table_bits - codeword_len)) / entries_per_xmm;
258         for (; stores_per_loop != 0; codeword_len++, stores_per_loop >>= 1) {
259                 unsigned end_sym_idx = sym_idx + len_counts[codeword_len];
260                 for (; sym_idx < end_sym_idx; sym_idx++) {
261                         /* Note: unlike in the 'long' version below, the __m128i
262                          * type already has __attribute__((may_alias)), so using
263                          * it to access the decode table, which is an array of
264                          * unsigned shorts, will not violate strict aliasing.
265                          */
266                         u16 entry;
267                         __m128i v;
268                         __m128i *p;
269                         unsigned n;
270
271                         entry = MAKE_DIRECT_ENTRY(sorted_syms[sym_idx], codeword_len);
272
273                         v = _mm_set1_epi16(entry);
274                         p = (__m128i*)decode_table_ptr;
275                         n = stores_per_loop;
276                         do {
277                                 *p++ = v;
278                         } while (--n);
279                         decode_table_ptr = p;
280                 }
281         }
282 #endif /* USE_SSE2_FILL */
283
284 #ifdef USE_LONG_FILL
285         /* Fill the entries one 'unsigned long' at a time.
286          * On 32-bit systems this is 2 entries per store, while on 64-bit
287          * systems this is 4 entries per store.  */
288         stores_per_loop = (1 << (table_bits - codeword_len)) / entries_per_long;
289         for (; stores_per_loop != 0; codeword_len++, stores_per_loop >>= 1) {
290                 unsigned end_sym_idx = sym_idx + len_counts[codeword_len];
291                 for (; sym_idx < end_sym_idx; sym_idx++) {
292
293                         /* Accessing the array of unsigned shorts as unsigned
294                          * longs would violate strict aliasing and would require
295                          * compiling the code with -fno-strict-aliasing to
296                          * guarantee correctness.  To work around this problem,
297                          * use the gcc 'may_alias' extension to define a special
298                          * unsigned long type that may alias any other in-memory
299                          * variable.  */
300                         typedef unsigned long __attribute__((may_alias)) aliased_long_t;
301
302                         unsigned long v;
303                         aliased_long_t *p;
304                         unsigned n;
305
306                         BUILD_BUG_ON(sizeof(unsigned long) != 4 &&
307                                      sizeof(unsigned long) != 8);
308
309                         v = MAKE_DIRECT_ENTRY(sorted_syms[sym_idx], codeword_len);
310                         v |= v << 16;
311                         if (sizeof(unsigned long) == 8) {
312                                 /* This may produce a compiler warning if an
313                                  * 'unsigned long' is 32 bits, but this won't be
314                                  * executed unless an 'unsigned long' is at
315                                  * least 64 bits anyway.  */
316                                 v |= v << 32;
317                         }
318
319                         p = (aliased_long_t *)decode_table_ptr;
320                         n = stores_per_loop;
321
322                         do {
323                                 *p++ = v;
324                         } while (--n);
325                         decode_table_ptr = p;
326                 }
327         }
328 #endif /* USE_LONG_FILL */
329
330         /* Fill the entries one 16-bit integer at a time.  */
331         stores_per_loop = (1 << (table_bits - codeword_len));
332         for (; stores_per_loop != 0; codeword_len++, stores_per_loop >>= 1) {
333                 unsigned end_sym_idx = sym_idx + len_counts[codeword_len];
334                 for (; sym_idx < end_sym_idx; sym_idx++) {
335                         u16 entry;
336                         u16 *p;
337                         unsigned n;
338
339                         entry = MAKE_DIRECT_ENTRY(sorted_syms[sym_idx], codeword_len);
340
341                         p = (u16*)decode_table_ptr;
342                         n = stores_per_loop;
343
344                         do {
345                                 *p++ = entry;
346                         } while (--n);
347
348                         decode_table_ptr = p;
349                 }
350         }
351
352         /* If we've filled in the entire table, we are done.  Otherwise,
353          * there are codewords longer than table_bits for which we must
354          * generate binary trees.  */
355
356         decode_table_pos = (u16*)decode_table_ptr - decode_table;
357         if (decode_table_pos != table_num_entries) {
358                 unsigned j;
359                 unsigned next_free_tree_slot;
360                 unsigned cur_codeword;
361
362                 /* First, zero out the remaining entries.  This is
363                  * necessary so that these entries appear as
364                  * "unallocated" in the next part.  Each of these entries
365                  * will eventually be filled with the representation of
366                  * the root node of a binary tree.  */
367                 j = decode_table_pos;
368                 do {
369                         decode_table[j] = 0;
370                 } while (++j != table_num_entries);
371
372                 /* We allocate child nodes starting at the end of the
373                  * direct lookup table.  Note that there should be
374                  * 2*num_syms extra entries for this purpose, although
375                  * fewer than this may actually be needed.  */
376                 next_free_tree_slot = table_num_entries;
377
378                 /* Iterate through each codeword with length greater than
379                  * 'table_bits', primarily in order of codeword length
380                  * and secondarily in order of symbol.  */
381                 for (cur_codeword = decode_table_pos << 1;
382                      codeword_len <= max_codeword_len;
383                      codeword_len++, cur_codeword <<= 1)
384                 {
385                         unsigned end_sym_idx = sym_idx + len_counts[codeword_len];
386                         for (; sym_idx < end_sym_idx; sym_idx++, cur_codeword++)
387                         {
388                                 /* 'sym' is the symbol represented by the
389                                  * codeword.  */
390                                 unsigned sym = sorted_syms[sym_idx];
391
392                                 unsigned extra_bits = codeword_len - table_bits;
393
394                                 unsigned node_idx = cur_codeword >> extra_bits;
395
396                                 /* Go through each bit of the current codeword
397                                  * beyond the prefix of length @table_bits and
398                                  * walk the appropriate binary tree, allocating
399                                  * any slots that have not yet been allocated.
400                                  *
401                                  * Note that the 'pointer' entry to the binary
402                                  * tree, which is stored in the direct lookup
403                                  * portion of the table, is represented
404                                  * identically to other internal (non-leaf)
405                                  * nodes of the binary tree; it can be thought
406                                  * of as simply the root of the tree.  The
407                                  * representation of these internal nodes is
408                                  * simply the index of the left child combined
409                                  * with the special bits 0xC000 to distingush
410                                  * the entry from direct mapping and leaf node
411                                  * entries.  */
412                                 do {
413
414                                         /* At least one bit remains in the
415                                          * codeword, but the current node is an
416                                          * unallocated leaf.  Change it to an
417                                          * internal node.  */
418                                         if (decode_table[node_idx] == 0) {
419                                                 decode_table[node_idx] =
420                                                         next_free_tree_slot | 0xC000;
421                                                 decode_table[next_free_tree_slot++] = 0;
422                                                 decode_table[next_free_tree_slot++] = 0;
423                                         }
424
425                                         /* Go to the left child if the next bit
426                                          * in the codeword is 0; otherwise go to
427                                          * the right child.  */
428                                         node_idx = decode_table[node_idx] & 0x3FFF;
429                                         --extra_bits;
430                                         node_idx += (cur_codeword >> extra_bits) & 1;
431                                 } while (extra_bits != 0);
432
433                                 /* We've traversed the tree using the entire
434                                  * codeword, and we're now at the entry where
435                                  * the actual symbol will be stored.  This is
436                                  * distinguished from internal nodes by not
437                                  * having its high two bits set.  */
438                                 decode_table[node_idx] = sym;
439                         }
440                 }
441         }
442         return 0;
443 }