LZX, XPRESS decompression: Return 0 bits on overrun
authorEric Biggers <ebiggers3@gmail.com>
Tue, 27 May 2014 16:03:07 +0000 (11:03 -0500)
committerEric Biggers <ebiggers3@gmail.com>
Tue, 27 May 2014 16:19:15 +0000 (11:19 -0500)
If the compressed data is invalid such that the compressed data buffer is
overrun, it's simpler to just return 0 bits instead of explicitly
checking the return value at every call site of bitstream_read_bits() and
read_huffsym().

This doesn't necessarily mean that invalid data will go undetected.  Just
for LZX decompression, chances are there will be another problem if all
0's start being returned (e.g. invalid match or invalid Huffman tree).
For WIM operations like extraction, the uncompressed data is checked with
SHA-1 message digests anyway, so it's virtually impossible for corruption
to go undetected.

Also, the LZMS decompressor already does this.

include/wimlib/decompress_common.h
src/decompress_common.c
src/lzx-decompress.c
src/xpress-decompress.c

index 4b5a7e3..3b20f4e 100644 (file)
@@ -51,22 +51,22 @@ init_input_bitstream(struct input_bitstream *istream,
        istream->data_bytes_left = num_data_bytes;
 }
 
-/* Ensures that the bit buffer variable for the bitstream contains @num_bits
- * bits.
+/* Ensures the bit buffer variable for the bitstream contains at least @num_bits
+ * bits.  Following this, bitstream_peek_bits() and/or bitstream_remove_bits()
+ * may be called on the bitstream to peek or remove up to @num_bits bits.
  *
- * If there are at least @num_bits bits remaining in the bitstream, 0 is
- * returned.  Otherwise, -1 is returned.  */
-static inline int
+ * If the input data is exhausted, any further bits are assumed to be 0.  */
+static inline void
 bitstream_ensure_bits(struct input_bitstream *istream, unsigned num_bits)
 {
        for (int nbits = num_bits; (int)istream->bitsleft < nbits; nbits -= 16) {
                u16 nextword;
                unsigned shift;
 
-               if (unlikely(istream->data_bytes_left < 2))
-                       return -1;
-
-               wimlib_assert2(istream->bitsleft <= sizeof(istream->bitbuf) * 8 - 16);
+               if (unlikely(istream->data_bytes_left < 2)) {
+                       istream->bitsleft = num_bits;
+                       return;
+               }
 
                nextword = le16_to_cpu(*(const le16*)istream->data);
                shift = sizeof(istream->bitbuf) * 8 - 16 - istream->bitsleft;
@@ -74,37 +74,33 @@ bitstream_ensure_bits(struct input_bitstream *istream, unsigned num_bits)
                istream->data += 2;
                istream->bitsleft += 16;
                istream->data_bytes_left -= 2;
-
        }
-       return 0;
 }
 
-/* Returns the next @num_bits bits in the buffer variable, which must contain at
- * least @num_bits bits, for the bitstream.  */
+/* Returns the next @num_bits bits from the bitstream, without removing them.
+ * There must be at least @num_bits remaining in the buffer variable, from a
+ * previous call to bitstream_ensure_bits().  */
 static inline u32
 bitstream_peek_bits(const struct input_bitstream *istream, unsigned num_bits)
 {
-       wimlib_assert2(istream->bitsleft >= num_bits);
-
        if (unlikely(num_bits == 0))
                return 0;
-
        return istream->bitbuf >> (sizeof(istream->bitbuf) * 8 - num_bits);
 }
 
-/* Removes @num_bits bits from the buffer variable, which must contain at least
- * @num_bits bits, for the bitstream.  */
+/* Removes @num_bits from the bitstream.  There must be at least @num_bits
+ * remaining in the buffer variable, from a previous call to
+ * bitstream_ensure_bits().  */
 static inline void
 bitstream_remove_bits(struct input_bitstream *istream, unsigned num_bits)
 {
-       wimlib_assert2(istream->bitsleft >= num_bits);
-
        istream->bitbuf <<= num_bits;
        istream->bitsleft -= num_bits;
 }
 
-/* Gets and removes @num_bits bits from the buffer variable, which must contain
- * at least @num_bits bits, for the bitstream.  */
+/* Removes and returns @num_bits bits from the bitstream.  There must be at
+ * least @num_bits remaining in the buffer variable, from a previous call to
+ * bitstream_ensure_bits().  */
 static inline u32
 bitstream_pop_bits(struct input_bitstream *istream, unsigned num_bits)
 {
@@ -113,66 +109,39 @@ bitstream_pop_bits(struct input_bitstream *istream, unsigned num_bits)
        return n;
 }
 
-/* Reads @num_bits bits from the input bitstream.  On success, returns 0 and
- * returns the requested bits in @n.  If there are fewer than @num_bits
- * remaining in the bitstream, -1 is returned. */
-static inline int
-bitstream_read_bits(struct input_bitstream *istream, unsigned num_bits, u32 *n)
+/* Reads and returns the next @num_bits bits from the bitstream.
+ * If the input data is exhausted, the bits are assumed to be 0.  */
+static inline u32
+bitstream_read_bits(struct input_bitstream *istream, unsigned num_bits)
 {
-       if (unlikely(bitstream_ensure_bits(istream, num_bits)))
-               return -1;
-
-       *n = bitstream_pop_bits(istream, num_bits);
-       return 0;
+       bitstream_ensure_bits(istream, num_bits);
+       return bitstream_pop_bits(istream, num_bits);
 }
 
-/* Return the next literal byte embedded in the bitstream, or -1 if the input
- * was exhausted.  */
-static inline int
+/* Reads and returns the next literal byte embedded in the bitstream.
+ * If the input data is exhausted, the byte is assumed to be 0.  */
+static inline u8
 bitstream_read_byte(struct input_bitstream *istream)
 {
-       if (unlikely(istream->data_bytes_left < 1))
-               return -1;
-
+       if (unlikely(istream->data_bytes_left == 0))
+               return 0;
        istream->data_bytes_left--;
        return *istream->data++;
 }
 
-/* Reads @num_bits bits from the buffer variable for a bistream without checking
- * to see if that many bits are in the buffer or not.  */
-static inline u32
-bitstream_read_bits_nocheck(struct input_bitstream *istream, unsigned num_bits)
-{
-       u32 n = bitstream_peek_bits(istream, num_bits);
-       bitstream_remove_bits(istream, num_bits);
-       return n;
-}
-
-extern int
-read_huffsym_near_end_of_input(struct input_bitstream *istream,
-                              const u16 decode_table[],
-                              const u8 lens[],
-                              unsigned num_syms,
-                              unsigned table_bits,
-                              unsigned *n);
-
-/* Read a Huffman-encoded symbol from a bitstream.  */
-static inline int
+/* Reads and returns the next Huffman-encoded symbol from a bitstream.  If the
+ * input data is exhausted, the Huffman symbol is decoded as if the missing bits
+ * are all zeroes.  */
+static inline u16
 read_huffsym(struct input_bitstream * restrict istream,
             const u16 decode_table[restrict],
             const u8 lens[restrict],
             unsigned num_syms,
             unsigned table_bits,
-            unsigned *restrict n,
             unsigned max_codeword_len)
 {
-       /* If there are fewer bits remaining in the input than the maximum
-        * codeword length, use the slow path that has extra checks.  */
-       if (unlikely(bitstream_ensure_bits(istream, max_codeword_len))) {
-               return read_huffsym_near_end_of_input(istream, decode_table,
-                                                     lens, num_syms,
-                                                     table_bits, n);
-       }
+
+       bitstream_ensure_bits(istream, max_codeword_len);
 
        /* Use the next table_bits of the input as an index into the
         * decode_table.  */
@@ -189,12 +158,10 @@ read_huffsym(struct input_bitstream * restrict istream,
                 * end of the decode table.  */
                bitstream_remove_bits(istream, table_bits);
                do {
-                       key_bits = sym + bitstream_peek_bits(istream, 1);
-                       bitstream_remove_bits(istream, 1);
+                       key_bits = sym + bitstream_pop_bits(istream, 1);
                } while ((sym = decode_table[key_bits]) >= num_syms);
        }
-       *n = sym;
-       return 0;
+       return sym;
 }
 
 extern int
index 767527d..ee85bc9 100644 (file)
@@ -387,46 +387,3 @@ make_huffman_decode_table(u16 *decode_table,  unsigned num_syms,
        }
        return 0;
 }
-
-/* Reads a Huffman-encoded symbol from the bistream when the number of remaining
- * bits is less than the maximum codeword length. */
-int
-read_huffsym_near_end_of_input(struct input_bitstream *istream,
-                              const u16 decode_table[],
-                              const u8 lens[],
-                              unsigned num_syms,
-                              unsigned table_bits,
-                              unsigned *n)
-{
-       unsigned bitsleft = istream->bitsleft;
-       unsigned key_size;
-       u16 sym;
-       u16 key_bits;
-
-       if (table_bits > bitsleft) {
-               key_size = bitsleft;
-               bitsleft = 0;
-               key_bits = bitstream_peek_bits(istream, key_size) <<
-                                               (table_bits - key_size);
-       } else {
-               key_size = table_bits;
-               bitsleft -= table_bits;
-               key_bits = bitstream_peek_bits(istream, table_bits);
-       }
-
-       sym = decode_table[key_bits];
-       if (sym >= num_syms) {
-               bitstream_remove_bits(istream, key_size);
-               do {
-                       if (bitsleft == 0)
-                               return -1;
-                       key_bits = sym + bitstream_peek_bits(istream, 1);
-                       bitstream_remove_bits(istream, 1);
-                       bitsleft--;
-               } while ((sym = decode_table[key_bits]) >= num_syms);
-       } else {
-               bitstream_remove_bits(istream, lens[sym]);
-       }
-       *n = sym;
-       return 0;
-}
index d3ba46a..f7f6661 100644 (file)
@@ -149,49 +149,46 @@ struct lzx_decompressor {
 /*
  * Reads a Huffman-encoded symbol using the pre-tree.
  */
-static inline int
+static inline u16
 read_huffsym_using_pretree(struct input_bitstream *istream,
                           const u16 pretree_decode_table[],
-                          const u8 pretree_lens[], unsigned *n)
+                          const u8 pretree_lens[])
 {
        return read_huffsym(istream, pretree_decode_table, pretree_lens,
-                           LZX_PRECODE_NUM_SYMBOLS, LZX_PRECODE_TABLEBITS, n,
+                           LZX_PRECODE_NUM_SYMBOLS, LZX_PRECODE_TABLEBITS,
                            LZX_MAX_PRE_CODEWORD_LEN);
 }
 
 /* Reads a Huffman-encoded symbol using the main tree. */
-static inline int
+static inline u16
 read_huffsym_using_maintree(struct input_bitstream *istream,
                            const struct lzx_tables *tables,
-                           unsigned *n,
                            unsigned num_main_syms)
 {
        return read_huffsym(istream, tables->maintree_decode_table,
                            tables->maintree_lens, num_main_syms,
-                           LZX_MAINCODE_TABLEBITS, n, LZX_MAX_MAIN_CODEWORD_LEN);
+                           LZX_MAINCODE_TABLEBITS, LZX_MAX_MAIN_CODEWORD_LEN);
 }
 
 /* Reads a Huffman-encoded symbol using the length tree. */
-static inline int
+static inline u16
 read_huffsym_using_lentree(struct input_bitstream *istream,
-                          const struct lzx_tables *tables,
-                          unsigned *n)
+                          const struct lzx_tables *tables)
 {
        return read_huffsym(istream, tables->lentree_decode_table,
                            tables->lentree_lens, LZX_LENCODE_NUM_SYMBOLS,
-                           LZX_LENCODE_TABLEBITS, n, LZX_MAX_LEN_CODEWORD_LEN);
+                           LZX_LENCODE_TABLEBITS, LZX_MAX_LEN_CODEWORD_LEN);
 }
 
 /* Reads a Huffman-encoded symbol using the aligned offset tree. */
-static inline int
+static inline u16
 read_huffsym_using_alignedtree(struct input_bitstream *istream,
-                              const struct lzx_tables *tables,
-                              unsigned *n)
+                              const struct lzx_tables *tables)
 {
        return read_huffsym(istream, tables->alignedtree_decode_table,
                            tables->alignedtree_lens,
                            LZX_ALIGNEDCODE_NUM_SYMBOLS,
-                           LZX_ALIGNEDCODE_TABLEBITS, n,
+                           LZX_ALIGNEDCODE_TABLEBITS,
                            LZX_MAX_ALIGNED_CODEWORD_LEN);
 }
 
@@ -217,17 +214,13 @@ lzx_read_code_lens(struct input_bitstream *istream, u8 lens[],
                                        _aligned_attribute(DECODE_TABLE_ALIGNMENT);
        u8 pretree_lens[LZX_PRECODE_NUM_SYMBOLS];
        unsigned i;
-       u32 len;
        int ret;
 
        /* Read the code lengths of the pretree codes.  There are 20 lengths of
         * 4 bits each. */
        for (i = 0; i < LZX_PRECODE_NUM_SYMBOLS; i++) {
-               ret = bitstream_read_bits(istream, LZX_PRECODE_ELEMENT_SIZE,
-                                         &len);
-               if (ret)
-                       return ret;
-               pretree_lens[i] = len;
+               pretree_lens[i] = bitstream_read_bits(istream,
+                                                     LZX_PRECODE_ELEMENT_SIZE);
        }
 
        /* Make the decoding table for the pretree. */
@@ -256,15 +249,12 @@ lzx_read_code_lens(struct input_bitstream *istream, u8 lens[],
                u32 num_same;
                signed char value;
 
-               ret = read_huffsym_using_pretree(istream, pretree_decode_table,
-                                                pretree_lens, &tree_code);
-               if (ret)
-                       return ret;
+               tree_code = read_huffsym_using_pretree(istream,
+                                                      pretree_decode_table,
+                                                      pretree_lens);
                switch (tree_code) {
                case 17: /* Run of 0's */
-                       ret = bitstream_read_bits(istream, 4, &num_zeroes);
-                       if (ret)
-                               return ret;
+                       num_zeroes = bitstream_read_bits(istream, 4);
                        num_zeroes += 4;
                        while (num_zeroes--) {
                                *lens = 0;
@@ -273,9 +263,7 @@ lzx_read_code_lens(struct input_bitstream *istream, u8 lens[],
                        }
                        break;
                case 18: /* Longer run of 0's */
-                       ret = bitstream_read_bits(istream, 5, &num_zeroes);
-                       if (ret)
-                               return ret;
+                       num_zeroes = bitstream_read_bits(istream, 5);
                        num_zeroes += 20;
                        while (num_zeroes--) {
                                *lens = 0;
@@ -284,16 +272,11 @@ lzx_read_code_lens(struct input_bitstream *istream, u8 lens[],
                        }
                        break;
                case 19: /* Run of identical lengths */
-                       ret = bitstream_read_bits(istream, 1, &num_same);
-                       if (ret)
-                               return ret;
+                       num_same = bitstream_read_bits(istream, 1);
                        num_same += 4;
-                       ret = read_huffsym_using_pretree(istream,
-                                                        pretree_decode_table,
-                                                        pretree_lens,
-                                                        &code);
-                       if (ret)
-                               return ret;
+                       code = read_huffsym_using_pretree(istream,
+                                                         pretree_decode_table,
+                                                         pretree_lens);
                        value = (signed char)*lens - (signed char)code;
                        if (value < 0)
                                value += 17;
@@ -343,38 +326,29 @@ lzx_read_block_header(struct input_bitstream *istream,
        unsigned block_type;
        unsigned block_size;
 
-       ret = bitstream_ensure_bits(istream, 4);
-       if (ret)
-               return ret;
+       bitstream_ensure_bits(istream, 4);
 
        /* The first three bits tell us what kind of block it is, and are one
         * of the LZX_BLOCKTYPE_* values.  */
-       block_type = bitstream_read_bits_nocheck(istream, 3);
+       block_type = bitstream_pop_bits(istream, 3);
 
        /* Read the block size.  This mirrors the behavior
         * lzx_write_compressed_block() in lzx-compress.c; see that for more
         * details.  */
-       if (bitstream_read_bits_nocheck(istream, 1)) {
+       if (bitstream_pop_bits(istream, 1)) {
                block_size = LZX_DEFAULT_BLOCK_SIZE;
        } else {
                u32 tmp;
                block_size = 0;
 
-               ret = bitstream_read_bits(istream, 8, &tmp);
-               if (ret)
-                       return ret;
+               tmp = bitstream_read_bits(istream, 8);
                block_size |= tmp;
-
-               ret = bitstream_read_bits(istream, 8, &tmp);
-               if (ret)
-                       return ret;
+               tmp = bitstream_read_bits(istream, 8);
                block_size <<= 8;
                block_size |= tmp;
 
                if (max_window_size >= 65536) {
-                       ret = bitstream_read_bits(istream, 8, &tmp);
-                       if (ret)
-                               return ret;
+                       tmp = bitstream_read_bits(istream, 8);
                        block_size <<= 8;
                        block_size |= tmp;
                }
@@ -386,14 +360,9 @@ lzx_read_block_header(struct input_bitstream *istream,
                 * then build it. */
 
                for (unsigned i = 0; i < LZX_ALIGNEDCODE_NUM_SYMBOLS; i++) {
-                       u32 len;
-
-                       ret = bitstream_read_bits(istream,
-                                                 LZX_ALIGNEDCODE_ELEMENT_SIZE,
-                                                 &len);
-                       if (ret)
-                               return ret;
-                       tables->alignedtree_lens[i] = len;
+                       tables->alignedtree_lens[i] =
+                               bitstream_read_bits(istream,
+                                                   LZX_ALIGNEDCODE_ELEMENT_SIZE);
                }
 
                LZX_DEBUG("Building the aligned tree.");
@@ -565,12 +534,10 @@ lzx_decode_match(unsigned main_element, int block_type,
        unsigned position_slot;
        unsigned match_len;
        unsigned match_offset;
-       unsigned additional_len;
        unsigned num_extra_bits;
        u32 verbatim_bits;
        u32 aligned_bits;
        unsigned i;
-       int ret;
        u8 *match_dest;
        u8 *match_src;
 
@@ -589,14 +556,8 @@ lzx_decode_match(unsigned main_element, int block_type,
         * the length tree, offset by 9 (LZX_MIN_MATCH_LEN +
         * LZX_NUM_PRIMARY_LENS) */
        match_len = LZX_MIN_MATCH_LEN + length_header;
-       if (length_header == LZX_NUM_PRIMARY_LENS) {
-               ret = read_huffsym_using_lentree(istream, tables,
-                                                &additional_len);
-               if (ret)
-                       return ret;
-               match_len += additional_len;
-       }
-
+       if (length_header == LZX_NUM_PRIMARY_LENS)
+               match_len += read_huffsym_using_lentree(istream, tables);
 
        /* If the position_slot is 0, 1, or 2, the match offset is retrieved
         * from the LRU queue.  Otherwise, the match offset is not in the LRU
@@ -639,27 +600,17 @@ lzx_decode_match(unsigned main_element, int block_type,
                         * equal to 3.  (Note that in the case with
                         * num_extra_bits == 3, the assignment to verbatim_bits
                         * will just set it to 0. ) */
-                       ret = bitstream_read_bits(istream, num_extra_bits - 3,
-                                                 &verbatim_bits);
-                       if (ret)
-                               return ret;
-
+                       verbatim_bits = bitstream_read_bits(istream,
+                                                           num_extra_bits - 3);
                        verbatim_bits <<= 3;
-
-                       ret = read_huffsym_using_alignedtree(istream, tables,
-                                                            &aligned_bits);
-                       if (ret)
-                               return ret;
+                       aligned_bits = read_huffsym_using_alignedtree(istream,
+                                                                     tables);
                } else {
                        /* For non-aligned blocks, or for aligned blocks with
                         * less than 3 extra bits, the extra bits are added
                         * directly to the match offset, and the correction for
                         * the alignment is taken to be 0. */
-                       ret = bitstream_read_bits(istream, num_extra_bits,
-                                                 &verbatim_bits);
-                       if (ret)
-                               return ret;
-
+                       verbatim_bits = bitstream_read_bits(istream, num_extra_bits);
                        aligned_bits = 0;
                }
 
@@ -678,13 +629,13 @@ lzx_decode_match(unsigned main_element, int block_type,
         * currently in use, then copy the source of the match to the current
         * position. */
 
-       if (match_len > bytes_remaining) {
+       if (unlikely(match_len > bytes_remaining)) {
                LZX_DEBUG("Match of length %u bytes overflows "
                          "uncompressed block size", match_len);
                return -1;
        }
 
-       if (match_offset > window_pos) {
+       if (unlikely(match_offset > window_pos)) {
                LZX_DEBUG("Match of length %u bytes references "
                          "data before window (match_offset = %u, "
                          "window_pos = %u)",
@@ -879,17 +830,12 @@ lzx_decompress_block(int block_type, unsigned block_size,
 {
        unsigned main_element;
        unsigned end;
-       int ret;
        int match_len;
 
        end = window_pos + block_size;
        while (window_pos < end) {
-               ret = read_huffsym_using_maintree(istream, tables,
-                                                 &main_element,
-                                                 num_main_syms);
-               if (ret)
-                       return ret;
-
+               main_element = read_huffsym_using_maintree(istream, tables,
+                                                          num_main_syms);
                if (main_element < LZX_NUM_CHARS) {
                        /* literal: 0 to LZX_NUM_CHARS - 1 */
                        window[window_pos++] = main_element;
@@ -903,7 +849,7 @@ lzx_decompress_block(int block_type, unsigned block_size,
                                                     tables,
                                                     queue,
                                                     istream);
-                       if (match_len < 0)
+                       if (unlikely(match_len < 0))
                                return match_len;
                        window_pos += match_len;
                }
index 233e935..98ad03d 100644 (file)
@@ -92,7 +92,6 @@ xpress_decode_match(unsigned sym, input_idx_t window_pos,
 
        u8 len_hdr;
        u8 offset_bsr;
-       int ret;
        u8 *match_dest;
        u8 *match_src;
        unsigned i;
@@ -103,27 +102,15 @@ xpress_decode_match(unsigned sym, input_idx_t window_pos,
        len_hdr = sym & 0xf;
        offset_bsr = sym >> 4;
 
-       if (bitstream_ensure_bits(istream, 16))
-               return -1;
+       bitstream_ensure_bits(istream, 16);
 
        match_offset = (1U << offset_bsr) | bitstream_pop_bits(istream, offset_bsr);
 
        if (len_hdr == 0xf) {
-               ret = bitstream_read_byte(istream);
-               if (ret < 0)
-                       return ret;
-               match_len = ret;
+               match_len = bitstream_read_byte(istream);
                if (unlikely(match_len == 0xff)) {
-                       ret = bitstream_read_byte(istream);
-                       if (ret < 0)
-                               return ret;
-                       match_len = ret;
-
-                       ret = bitstream_read_byte(istream);
-                       if (ret < 0)
-                               return ret;
-
-                       match_len |= (ret << 8);
+                       match_len = bitstream_read_byte(istream);
+                       match_len |= (unsigned)bitstream_read_byte(istream) << 8;
                } else {
                        match_len += 0xf;
                }
@@ -167,14 +154,11 @@ xpress_lz_decode(struct input_bitstream * restrict istream,
                unsigned sym;
                int ret;
 
-               if (unlikely(bitstream_ensure_bits(istream, 16)))
-                       return -1;
-
-               if (unlikely(read_huffsym(istream, decode_table, lens,
-                                         XPRESS_NUM_SYMBOLS, XPRESS_TABLEBITS,
-                                         &sym, XPRESS_MAX_CODEWORD_LEN)))
-                       return -1;
+               bitstream_ensure_bits(istream, 16);
 
+               sym = read_huffsym(istream, decode_table, lens,
+                                  XPRESS_NUM_SYMBOLS, XPRESS_TABLEBITS,
+                                  XPRESS_MAX_CODEWORD_LEN);
                if (sym < XPRESS_NUM_CHARS) {
                        /* Literal  */
                        uncompressed_data[curpos] = sym;