decompress_common: switch to subtables for Huffman decoding
[wimlib] / include / wimlib / decompress_common.h
1 /*
2  * decompress_common.h
3  *
4  * Header for decompression code shared by multiple compression formats.
5  *
6  * The following copying information applies to this specific source code file:
7  *
8  * Written in 2012-2016 by Eric Biggers <ebiggers3@gmail.com>
9  *
10  * To the extent possible under law, the author(s) have dedicated all copyright
11  * and related and neighboring rights to this software to the public domain
12  * worldwide via the Creative Commons Zero 1.0 Universal Public Domain
13  * Dedication (the "CC0").
14  *
15  * This software is distributed in the hope that it will be useful, but WITHOUT
16  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
17  * FOR A PARTICULAR PURPOSE. See the CC0 for more details.
18  *
19  * You should have received a copy of the CC0 along with this software; if not
20  * see <http://creativecommons.org/publicdomain/zero/1.0/>.
21  */
22
23 #ifndef _WIMLIB_DECOMPRESS_COMMON_H
24 #define _WIMLIB_DECOMPRESS_COMMON_H
25
26 #include <string.h>
27
28 #include "wimlib/compiler.h"
29 #include "wimlib/types.h"
30 #include "wimlib/unaligned.h"
31
32 /******************************************************************************/
33 /*                   Input bitstream for XPRESS and LZX                       */
34 /*----------------------------------------------------------------------------*/
35
36 /* Structure that encapsulates a block of in-memory data being interpreted as a
37  * stream of bits, optionally with interwoven literal bytes.  Bits are assumed
38  * to be stored in little endian 16-bit coding units, with the bits ordered high
39  * to low.  */
40 struct input_bitstream {
41
42         /* Bits that have been read from the input buffer.  The bits are
43          * left-justified; the next bit is always bit 31.  */
44         u32 bitbuf;
45
46         /* Number of bits currently held in @bitbuf.  */
47         u32 bitsleft;
48
49         /* Pointer to the next byte to be retrieved from the input buffer.  */
50         const u8 *next;
51
52         /* Pointer past the end of the input buffer.  */
53         const u8 *end;
54 };
55
56 /* Initialize a bitstream to read from the specified input buffer.  */
57 static inline void
58 init_input_bitstream(struct input_bitstream *is, const void *buffer, u32 size)
59 {
60         is->bitbuf = 0;
61         is->bitsleft = 0;
62         is->next = buffer;
63         is->end = is->next + size;
64 }
65
66 /* Note: for performance reasons, the following methods don't return error codes
67  * to the caller if the input buffer is overrun.  Instead, they just assume that
68  * all overrun data is zeroes.  This has no effect on well-formed compressed
69  * data.  The only disadvantage is that bad compressed data may go undetected,
70  * but even this is irrelevant if higher level code checksums the uncompressed
71  * data anyway.  */
72
73 /* Ensure the bit buffer variable for the bitstream contains at least @num_bits
74  * bits.  Following this, bitstream_peek_bits() and/or bitstream_remove_bits()
75  * may be called on the bitstream to peek or remove up to @num_bits bits.  */
76 static inline void
77 bitstream_ensure_bits(struct input_bitstream *is, const unsigned num_bits)
78 {
79         /* This currently works for at most 17 bits.  */
80
81         if (is->bitsleft >= num_bits)
82                 return;
83
84         if (unlikely(is->end - is->next < 2))
85                 goto overflow;
86
87         is->bitbuf |= (u32)get_unaligned_le16(is->next) << (16 - is->bitsleft);
88         is->next += 2;
89         is->bitsleft += 16;
90
91         if (unlikely(num_bits == 17 && is->bitsleft == 16)) {
92                 if (unlikely(is->end - is->next < 2))
93                         goto overflow;
94
95                 is->bitbuf |= (u32)get_unaligned_le16(is->next);
96                 is->next += 2;
97                 is->bitsleft = 32;
98         }
99
100         return;
101
102 overflow:
103         is->bitsleft = 32;
104 }
105
106 /* Return the next @num_bits bits from the bitstream, without removing them.
107  * There must be at least @num_bits remaining in the buffer variable, from a
108  * previous call to bitstream_ensure_bits().  */
109 static inline u32
110 bitstream_peek_bits(const struct input_bitstream *is, const unsigned num_bits)
111 {
112         return (is->bitbuf >> 1) >> (sizeof(is->bitbuf) * 8 - num_bits - 1);
113 }
114
115 /* Remove @num_bits from the bitstream.  There must be at least @num_bits
116  * remaining in the buffer variable, from a previous call to
117  * bitstream_ensure_bits().  */
118 static inline void
119 bitstream_remove_bits(struct input_bitstream *is, unsigned num_bits)
120 {
121         is->bitbuf <<= num_bits;
122         is->bitsleft -= num_bits;
123 }
124
125 /* Remove and return @num_bits bits from the bitstream.  There must be at least
126  * @num_bits remaining in the buffer variable, from a previous call to
127  * bitstream_ensure_bits().  */
128 static inline u32
129 bitstream_pop_bits(struct input_bitstream *is, unsigned num_bits)
130 {
131         u32 bits = bitstream_peek_bits(is, num_bits);
132         bitstream_remove_bits(is, num_bits);
133         return bits;
134 }
135
136 /* Read and return the next @num_bits bits from the bitstream.  */
137 static inline u32
138 bitstream_read_bits(struct input_bitstream *is, unsigned num_bits)
139 {
140         bitstream_ensure_bits(is, num_bits);
141         return bitstream_pop_bits(is, num_bits);
142 }
143
144 /* Read and return the next literal byte embedded in the bitstream.  */
145 static inline u8
146 bitstream_read_byte(struct input_bitstream *is)
147 {
148         if (unlikely(is->end == is->next))
149                 return 0;
150         return *is->next++;
151 }
152
153 /* Read and return the next 16-bit integer embedded in the bitstream.  */
154 static inline u16
155 bitstream_read_u16(struct input_bitstream *is)
156 {
157         u16 v;
158
159         if (unlikely(is->end - is->next < 2))
160                 return 0;
161         v = get_unaligned_le16(is->next);
162         is->next += 2;
163         return v;
164 }
165
166 /* Read and return the next 32-bit integer embedded in the bitstream.  */
167 static inline u32
168 bitstream_read_u32(struct input_bitstream *is)
169 {
170         u32 v;
171
172         if (unlikely(is->end - is->next < 4))
173                 return 0;
174         v = get_unaligned_le32(is->next);
175         is->next += 4;
176         return v;
177 }
178
179 /* Read into @dst_buffer an array of literal bytes embedded in the bitstream.
180  * Return 0 if there were enough bytes remaining in the input, otherwise -1. */
181 static inline int
182 bitstream_read_bytes(struct input_bitstream *is, void *dst_buffer, size_t count)
183 {
184         if (unlikely(is->end - is->next < count))
185                 return -1;
186         memcpy(dst_buffer, is->next, count);
187         is->next += count;
188         return 0;
189 }
190
191 /* Align the input bitstream on a coding-unit boundary.  */
192 static inline void
193 bitstream_align(struct input_bitstream *is)
194 {
195         is->bitsleft = 0;
196         is->bitbuf = 0;
197 }
198
199 /******************************************************************************/
200 /*                             Huffman decoding                               */
201 /*----------------------------------------------------------------------------*/
202
203 /*
204  * Required alignment for the Huffman decode tables.  We require this alignment
205  * so that we can fill the entries with vector or word instructions and not have
206  * to deal with misaligned buffers.
207  */
208 #define DECODE_TABLE_ALIGNMENT 16
209
210 /*
211  * Each decode table entry is 16 bits divided into two fields: 'symbol' (high 12
212  * bits) and 'length' (low 4 bits).  The precise meaning of these fields depends
213  * on the type of entry:
214  *
215  * Root table entries which are *not* subtable pointers:
216  *      symbol: symbol to decode
217  *      length: codeword length in bits
218  *
219  * Root table entries which are subtable pointers:
220  *      symbol: index of start of subtable
221  *      length: number of bits with which the subtable is indexed
222  *
223  * Subtable entries:
224  *      symbol: symbol to decode
225  *      length: codeword length in bits, minus the number of bits with which the
226  *              root table is indexed
227  */
228 #define DECODE_TABLE_SYMBOL_SHIFT  4
229 #define DECODE_TABLE_MAX_SYMBOL    ((1 << (16 - DECODE_TABLE_SYMBOL_SHIFT)) - 1)
230 #define DECODE_TABLE_MAX_LENGTH    ((1 << DECODE_TABLE_SYMBOL_SHIFT) - 1)
231 #define DECODE_TABLE_LENGTH_MASK   DECODE_TABLE_MAX_LENGTH
232 #define MAKE_DECODE_TABLE_ENTRY(symbol, length) \
233         (((symbol) << DECODE_TABLE_SYMBOL_SHIFT) | (length))
234
235 /*
236  * Read and return the next Huffman-encoded symbol from the given bitstream
237  * using the given decode table.
238  *
239  * If the input data is exhausted, then the Huffman symbol will be decoded as if
240  * the missing bits were all zeroes.
241  *
242  * XXX: This is mostly duplicated in lzms_decode_huffman_symbol() in
243  * lzms_decompress.c; keep them in sync!
244  */
245 static inline unsigned
246 read_huffsym(struct input_bitstream *is, const u16 decode_table[],
247              unsigned table_bits, unsigned max_codeword_len)
248 {
249         unsigned entry;
250         unsigned symbol;
251         unsigned length;
252
253         /* Preload the bitbuffer with 'max_codeword_len' bits so that we're
254          * guaranteed to be able to fully decode a codeword. */
255         bitstream_ensure_bits(is, max_codeword_len);
256
257         /* Index the root table by the next 'table_bits' bits of input. */
258         entry = decode_table[bitstream_peek_bits(is, table_bits)];
259
260         /* Extract the "symbol" and "length" from the entry. */
261         symbol = entry >> DECODE_TABLE_SYMBOL_SHIFT;
262         length = entry & DECODE_TABLE_LENGTH_MASK;
263
264         /* If the root table is indexed by the full 'max_codeword_len' bits,
265          * then there cannot be any subtables, and this will be known at compile
266          * time.  Otherwise, we must check whether the decoded symbol is really
267          * a subtable pointer.  If so, we must discard the bits with which the
268          * root table was indexed, then index the subtable by the next 'length'
269          * bits of input to get the real entry. */
270         if (max_codeword_len > table_bits &&
271             entry >= (1U << (table_bits + DECODE_TABLE_SYMBOL_SHIFT)))
272         {
273                 /* Subtable required */
274                 bitstream_remove_bits(is, table_bits);
275                 entry = decode_table[symbol + bitstream_peek_bits(is, length)];
276                 symbol = entry >> DECODE_TABLE_SYMBOL_SHIFT;
277                 length = entry & DECODE_TABLE_LENGTH_MASK;
278         }
279
280         /* Discard the bits (or the remaining bits, if a subtable was required)
281          * of the codeword. */
282         bitstream_remove_bits(is, length);
283
284         /* Return the decoded symbol. */
285         return symbol;
286 }
287
288 /*
289  * The DECODE_TABLE_ENOUGH() macro evaluates to the maximum number of decode
290  * table entries, including all subtable entries, that may be required for
291  * decoding a given Huffman code.  This depends on three parameters:
292  *
293  *      num_syms: the maximum number of symbols in the code
294  *      table_bits: the number of bits with which the root table will be indexed
295  *      max_codeword_len: the maximum allowed codeword length in the code
296  *
297  * Given these parameters, the utility program 'enough' from zlib, when passed
298  * the three arguments 'num_syms', 'table_bits', and 'max_codeword_len', will
299  * compute the maximum number of entries required.  This has already been done
300  * for the combinations we need and incorporated into the macro below so that
301  * the mapping can be done at compilation time.  If an unknown combination is
302  * used, then a compilation error will result.  To fix this, use 'enough' to
303  * find the missing value and add it below.  If that still doesn't fix the
304  * compilation error, then most likely a constraint would be violated by the
305  * requested parameters, so they cannot be used, at least without other changes
306  * to the decode table --- see DECODE_TABLE_SIZE().
307  */
308 #define DECODE_TABLE_ENOUGH(num_syms, table_bits, max_codeword_len) ( \
309         ((num_syms) == 8 && (table_bits) == 7 && (max_codeword_len) == 15) ? 128 : \
310         ((num_syms) == 8 && (table_bits) == 5 && (max_codeword_len) == 7) ? 36 : \
311         ((num_syms) == 8 && (table_bits) == 6 && (max_codeword_len) == 7) ? 66 : \
312         ((num_syms) == 8 && (table_bits) == 7 && (max_codeword_len) == 7) ? 128 : \
313         ((num_syms) == 20 && (table_bits) == 5 && (max_codeword_len) == 15) ? 1062 : \
314         ((num_syms) == 20 && (table_bits) == 6 && (max_codeword_len) == 15) ? 582 : \
315         ((num_syms) == 20 && (table_bits) == 7 && (max_codeword_len) == 15) ? 390 : \
316         ((num_syms) == 54 && (table_bits) == 9 && (max_codeword_len) == 15) ? 618 : \
317         ((num_syms) == 54 && (table_bits) == 10 && (max_codeword_len) == 15) ? 1098 : \
318         ((num_syms) == 249 && (table_bits) == 9 && (max_codeword_len) == 16) ? 878 : \
319         ((num_syms) == 249 && (table_bits) == 10 && (max_codeword_len) == 16) ? 1326 : \
320         ((num_syms) == 249 && (table_bits) == 11 && (max_codeword_len) == 16) ? 2318 : \
321         ((num_syms) == 256 && (table_bits) == 9 && (max_codeword_len) == 15) ? 822 : \
322         ((num_syms) == 256 && (table_bits) == 10 && (max_codeword_len) == 15) ? 1302 : \
323         ((num_syms) == 256 && (table_bits) == 11 && (max_codeword_len) == 15) ? 2310 : \
324         ((num_syms) == 512 && (table_bits) == 10 && (max_codeword_len) == 15) ? 1558 : \
325         ((num_syms) == 512 && (table_bits) == 11 && (max_codeword_len) == 15) ? 2566 : \
326         ((num_syms) == 512 && (table_bits) == 12 && (max_codeword_len) == 15) ? 4606 : \
327         ((num_syms) == 656 && (table_bits) == 10 && (max_codeword_len) == 16) ? 1734 : \
328         ((num_syms) == 656 && (table_bits) == 11 && (max_codeword_len) == 16) ? 2726 : \
329         ((num_syms) == 656 && (table_bits) == 12 && (max_codeword_len) == 16) ? 4758 : \
330         ((num_syms) == 799 && (table_bits) == 9 && (max_codeword_len) == 15) ? 1366 : \
331         ((num_syms) == 799 && (table_bits) == 10 && (max_codeword_len) == 15) ? 1846 : \
332         ((num_syms) == 799 && (table_bits) == 11 && (max_codeword_len) == 15) ? 2854 : \
333         -1)
334
335 /* Wrapper around DECODE_TABLE_ENOUGH() that does additional compile-time
336  * validation. */
337 #define DECODE_TABLE_SIZE(num_syms, table_bits, max_codeword_len) (     \
338                                                                         \
339         /* All values must be positive. */                              \
340         STATIC_ASSERT_ZERO((num_syms) > 0) +                            \
341         STATIC_ASSERT_ZERO((table_bits) > 0) +                          \
342         STATIC_ASSERT_ZERO((max_codeword_len) > 0) +                    \
343                                                                         \
344         /* There cannot be more symbols than possible codewords. */     \
345         STATIC_ASSERT_ZERO((num_syms) <= 1U << (max_codeword_len)) +    \
346                                                                         \
347         /* There is no reason for the root table to be indexed with
348          * more bits than the maximum codeword length. */               \
349         STATIC_ASSERT_ZERO((table_bits) <= (max_codeword_len)) +        \
350                                                                         \
351         /* The maximum symbol value must fit in the 'symbol' field. */  \
352         STATIC_ASSERT_ZERO((num_syms) - 1 <= DECODE_TABLE_MAX_SYMBOL) + \
353                                                                         \
354         /* The maximum codeword length in the root table must fit in
355          * the 'length' field. */                                       \
356         STATIC_ASSERT_ZERO((table_bits) <= DECODE_TABLE_MAX_LENGTH) +   \
357                                                                         \
358         /* The maximum codeword length in a subtable must fit in the
359          * 'length' field. */                                           \
360         STATIC_ASSERT_ZERO((max_codeword_len) - (table_bits) <=         \
361                            DECODE_TABLE_MAX_LENGTH) +                   \
362                                                                         \
363         /* The minimum subtable index must be greater than the maximum
364          * symbol value.  If this were not the case, then there would
365          * be no way to tell whether a given root table entry is a
366          * "subtable pointer" or not.  (An alternate solution would be
367          * to reserve a flag bit specifically for this purpose.) */     \
368         STATIC_ASSERT_ZERO((1U << table_bits) > (num_syms) - 1) +       \
369                                                                         \
370         /* The needed 'enough' value must have been defined. */         \
371         STATIC_ASSERT_ZERO(DECODE_TABLE_ENOUGH(                         \
372                                 (num_syms), (table_bits),               \
373                                 (max_codeword_len)) > 0) +              \
374                                                                         \
375         /* The maximum subtable index must fit in the 'symbol' field. */\
376         STATIC_ASSERT_ZERO(DECODE_TABLE_ENOUGH(                         \
377                                 (num_syms), (table_bits),               \
378                                 (max_codeword_len)) - 1 <=              \
379                                         DECODE_TABLE_MAX_SYMBOL) +      \
380                                                                         \
381         /* Finally, make the macro evaluate to the needed maximum
382          * number of decode table entries. */                           \
383         DECODE_TABLE_ENOUGH((num_syms), (table_bits),                   \
384                             (max_codeword_len))                         \
385 )
386
387 /*
388  * Declare the decode table for a Huffman code, given several compile-time
389  * constants that describe the code.  See DECODE_TABLE_ENOUGH() for details.
390  *
391  * Decode tables must be aligned to a DECODE_TABLE_ALIGNMENT-byte boundary.
392  * This implies that if a decode table is nested inside a dynamically allocated
393  * structure, then the outer structure must be allocated on a
394  * DECODE_TABLE_ALIGNMENT-byte aligned boundary as well.
395  */
396 #define DECODE_TABLE(name, num_syms, table_bits, max_codeword_len) \
397         u16 name[DECODE_TABLE_SIZE((num_syms), (table_bits), \
398                                    (max_codeword_len))] \
399                 _aligned_attribute(DECODE_TABLE_ALIGNMENT)
400
401 extern int
402 make_huffman_decode_table(u16 decode_table[], unsigned num_syms,
403                           unsigned table_bits, const u8 lens[],
404                           unsigned max_codeword_len);
405
406 /******************************************************************************/
407 /*                             LZ match copying                               */
408 /*----------------------------------------------------------------------------*/
409
410 static inline void
411 copy_word_unaligned(const void *src, void *dst)
412 {
413         store_word_unaligned(load_word_unaligned(src), dst);
414 }
415
416 static inline machine_word_t
417 repeat_u16(u16 b)
418 {
419         machine_word_t v = b;
420
421         STATIC_ASSERT(WORDBITS == 32 || WORDBITS == 64);
422         v |= v << 16;
423         v |= v << ((WORDBITS == 64) ? 32 : 0);
424         return v;
425 }
426
427 static inline machine_word_t
428 repeat_byte(u8 b)
429 {
430         return repeat_u16(((u16)b << 8) | b);
431 }
432
433 /*
434  * Copy an LZ77 match at (dst - offset) to dst.
435  *
436  * The length and offset must be already validated --- that is, (dst - offset)
437  * can't underrun the output buffer, and (dst + length) can't overrun the output
438  * buffer.  Also, the length cannot be 0.
439  *
440  * @winend points to the byte past the end of the output buffer.
441  * This function won't write any data beyond this position.
442  */
443 static inline void
444 lz_copy(u8 *dst, u32 length, u32 offset, const u8 *winend, u32 min_length)
445 {
446         const u8 *src = dst - offset;
447         const u8 * const end = dst + length;
448
449         /*
450          * Try to copy one machine word at a time.  On i386 and x86_64 this is
451          * faster than copying one byte at a time, unless the data is
452          * near-random and all the matches have very short lengths.  Note that
453          * since this requires unaligned memory accesses, it won't necessarily
454          * be faster on every architecture.
455          *
456          * Also note that we might copy more than the length of the match.  For
457          * example, if a word is 8 bytes and the match is of length 5, then
458          * we'll simply copy 8 bytes.  This is okay as long as we don't write
459          * beyond the end of the output buffer, hence the check for (winend -
460          * end >= WORDBYTES - 1).
461          */
462         if (UNALIGNED_ACCESS_IS_FAST && likely(winend - end >= WORDBYTES - 1)) {
463
464                 if (offset >= WORDBYTES) {
465                         /* The source and destination words don't overlap.  */
466
467                         /* To improve branch prediction, one iteration of this
468                          * loop is unrolled.  Most matches are short and will
469                          * fail the first check.  But if that check passes, then
470                          * it becomes increasing likely that the match is long
471                          * and we'll need to continue copying.  */
472
473                         copy_word_unaligned(src, dst);
474                         src += WORDBYTES;
475                         dst += WORDBYTES;
476
477                         if (dst < end) {
478                                 do {
479                                         copy_word_unaligned(src, dst);
480                                         src += WORDBYTES;
481                                         dst += WORDBYTES;
482                                 } while (dst < end);
483                         }
484                         return;
485                 } else if (offset == 1) {
486
487                         /* Offset 1 matches are equivalent to run-length
488                          * encoding of the previous byte.  This case is common
489                          * if the data contains many repeated bytes.  */
490
491                         machine_word_t v = repeat_byte(*(dst - 1));
492                         do {
493                                 store_word_unaligned(v, dst);
494                                 src += WORDBYTES;
495                                 dst += WORDBYTES;
496                         } while (dst < end);
497                         return;
498                 }
499                 /*
500                  * We don't bother with special cases for other 'offset <
501                  * WORDBYTES', which are usually rarer than 'offset == 1'.
502                  * Extra checks will just slow things down.  Actually, it's
503                  * possible to handle all the 'offset < WORDBYTES' cases using
504                  * the same code, but it still becomes more complicated doesn't
505                  * seem any faster overall; it definitely slows down the more
506                  * common 'offset == 1' case.
507                  */
508         }
509
510         /* Fall back to a bytewise copy.  */
511
512         if (min_length >= 2) {
513                 *dst++ = *src++;
514                 length--;
515         }
516         if (min_length >= 3) {
517                 *dst++ = *src++;
518                 length--;
519         }
520         if (min_length >= 4) {
521                 *dst++ = *src++;
522                 length--;
523         }
524         do {
525                 *dst++ = *src++;
526         } while (--length);
527 }
528
529 #endif /* _WIMLIB_DECOMPRESS_COMMON_H */