From f308322e2f405932116152bc01754bfca7635003 Mon Sep 17 00:00:00 2001 From: Eric Biggers Date: Fri, 26 Oct 2012 12:57:14 -0500 Subject: [PATCH] xpress_compress(): Initialize all data written - Make sure the flushing of the bitstream and the finishing of the compressed data is done properly. - Fix indentation and whitespace. - Zero out chunk table before writing it. --- src/comp.c | 30 ++++++++------- src/comp.h | 22 +++++------ src/decomp.c | 62 +++++++++++++++--------------- src/decomp.h | 30 +++++++-------- src/resource.c | 10 ++--- src/xpress-comp.c | 91 +++++++++++++++++++++++++++------------------ src/xpress-decomp.c | 47 ++++++++++++----------- 7 files changed, 159 insertions(+), 133 deletions(-) diff --git a/src/comp.c b/src/comp.c index 7837a461..559abbe4 100644 --- a/src/comp.c +++ b/src/comp.c @@ -38,7 +38,7 @@ static inline void flush_bits(struct output_bitstream *ostream) /* Writes @num_bits bits, given by the @num_bits least significant bits of * @bits, to the output @ostream. */ -int bitstream_put_bits(struct output_bitstream *ostream, output_bitbuf_t bits, +int bitstream_put_bits(struct output_bitstream *ostream, output_bitbuf_t bits, uint num_bits) { uint rem_bits; @@ -49,7 +49,7 @@ int bitstream_put_bits(struct output_bitstream *ostream, output_bitbuf_t bits, ostream->free_bits -= num_bits; } else { - if (ostream->num_bytes_remaining + (ostream->output - + if (ostream->num_bytes_remaining + (ostream->output - ostream->bit_output) < 2) return 1; @@ -75,7 +75,7 @@ int bitstream_put_bits(struct output_bitstream *ostream, output_bitbuf_t bits, /* Flushes any remaining bits in the output buffer to the output byte stream. */ int flush_output_bitstream(struct output_bitstream *ostream) { - if (ostream->num_bytes_remaining + (ostream->output - + if (ostream->num_bytes_remaining + (ostream->output - ostream->bit_output) < 2) return 1; if (ostream->free_bits != 16) { @@ -87,9 +87,11 @@ int flush_output_bitstream(struct output_bitstream *ostream) /* Initializes an output bit buffer to write its output to the memory location * pointer to by @data. */ -void init_output_bitstream(struct output_bitstream *ostream, void *data, +void init_output_bitstream(struct output_bitstream *ostream, void *data, uint num_bytes) { + wimlib_assert(num_bytes >= 4); + ostream->bitbuf = 0; ostream->free_bits = 16; ostream->bit_output = (u8*)data; @@ -166,7 +168,7 @@ static void huffman_tree_compute_path_lengths(HuffmanNode *node, u16 cur_len) } } -/* Creates a canonical Huffman code from an array of symbol frequencies. +/* Creates a canonical Huffman code from an array of symbol frequencies. * * The algorithm used is similar to the well-known algorithm that builds a * Huffman tree using a minheap. In that algorithm, the leaf nodes are @@ -217,12 +219,12 @@ static void huffman_tree_compute_path_lengths(HuffmanNode *node, u16 cur_len) * codewords for each symbol will be written. * * @codewords: An array of @num_syms short integers into which the - * codewords for each symbol will be written. The first - * lens[i] bits of codewords[i] will contain the codeword + * codewords for each symbol will be written. The first + * lens[i] bits of codewords[i] will contain the codeword * for symbol i. */ -void make_canonical_huffman_code(uint num_syms, uint max_codeword_len, - const u32 freq_tab[], u8 lens[], +void make_canonical_huffman_code(uint num_syms, uint max_codeword_len, + const u32 freq_tab[], u8 lens[], u16 codewords[]) { /* We require at least 2 possible symbols in the alphabet to produce a @@ -301,7 +303,7 @@ void make_canonical_huffman_code(uint num_syms, uint max_codeword_len, * most num_used_symbols - 1 intermediate nodes when creating a Huffman * code. This is because if there were at least num_used_symbols nodes, * the code would be suboptimal because there would be at least one - * unnecessary intermediate node. + * unnecessary intermediate node. * * The worst case (greatest number of intermediate nodes) would be if * all the intermediate nodes were chained together. This results in @@ -348,7 +350,7 @@ try_building_tree_again: while (1) { /* Lowest frequency node. */ - HuffmanNode *f1 = NULL; + HuffmanNode *f1 = NULL; /* Second lowest frequency node. */ HuffmanNode *f2 = NULL; @@ -357,14 +359,14 @@ try_building_tree_again: * the remaining leaves or from the intermediate nodes. * */ - if (cur_leaf != end_leaf && (cur_inode == next_inode || + if (cur_leaf != end_leaf && (cur_inode == next_inode || cur_leaf->freq <= cur_inode->freq)) { f1 = (HuffmanNode*)cur_leaf++; } else if (cur_inode != next_inode) { f1 = cur_inode++; } - if (cur_leaf != end_leaf && (cur_inode == next_inode || + if (cur_leaf != end_leaf && (cur_inode == next_inode || cur_leaf->freq <= cur_inode->freq)) { f2 = (HuffmanNode*)cur_leaf++; } else if (cur_inode != next_inode) { @@ -403,7 +405,7 @@ try_building_tree_again: if (leaves[i].freq > 1) leaves[i].freq >>= 1; goto try_building_tree_again; - } + } next_inode++; } diff --git a/src/comp.h b/src/comp.h index 7b5f6c6b..a40dbca0 100644 --- a/src/comp.h +++ b/src/comp.h @@ -35,9 +35,9 @@ struct output_bitstream { static inline int bitstream_put_byte(struct output_bitstream *ostream, - u8 n) + u8 n) { - if (ostream->num_bytes_remaining == 0) + if (ostream->num_bytes_remaining < 1) return 1; *ostream->output = n; ostream->output++; @@ -66,30 +66,30 @@ struct lz_params { uint max_lazy_match; uint too_far; }; - + typedef uint (*lz_record_match_t)(uint, uint, void *, void *); typedef uint (*lz_record_literal_t)(u8, void *); -extern uint lz_analyze_block(const u8 uncompressed_data[], +extern uint lz_analyze_block(const u8 uncompressed_data[], uint uncompressed_len, - u32 match_tab[], + u32 match_tab[], lz_record_match_t record_match, - lz_record_literal_t record_literal, + lz_record_literal_t record_literal, void *record_match_arg1, - void *record_match_arg2, + void *record_match_arg2, void *record_literal_arg, const struct lz_params *params); -extern int bitstream_put_bits(struct output_bitstream *ostream, +extern int bitstream_put_bits(struct output_bitstream *ostream, output_bitbuf_t bits, unsigned num_bits); extern void init_output_bitstream(struct output_bitstream *ostream, - void *data, unsigned num_bytes); + void *data, unsigned num_bytes); extern int flush_output_bitstream(struct output_bitstream *ostream); -extern void make_canonical_huffman_code(uint num_syms, uint max_codeword_len, - const u32 freq_tab[], u8 lens[], +extern void make_canonical_huffman_code(uint num_syms, uint max_codeword_len, + const u32 freq_tab[], u8 lens[], u16 codewords[]); #endif /* _WIMLIB_COMP_H */ diff --git a/src/decomp.c b/src/decomp.c index 3440971b..b1a7043e 100644 --- a/src/decomp.c +++ b/src/decomp.c @@ -61,7 +61,7 @@ int bitstream_read_bytes(struct input_bitstream *stream, size_t n, void *dest) if ((n & 1) && stream->data_bytes_left != 0) { stream->bitsleft = 8; stream->data_bytes_left--; - stream->bitbuf |= (input_bitbuf_t)(*stream->data) << + stream->bitbuf |= (input_bitbuf_t)(*stream->data) << (sizeof(input_bitbuf_t) * 8 - 8); stream->data++; } @@ -76,7 +76,7 @@ int bitstream_read_bytes(struct input_bitstream *stream, size_t n, void *dest) * and length fields of an uncompressed block, however; it does not apply when * realigning the stream after the end of the uncompressed block. */ -int align_input_bitstream(struct input_bitstream *stream, +int align_input_bitstream(struct input_bitstream *stream, bool skip_word_if_aligned) { int ret; @@ -96,7 +96,7 @@ int align_input_bitstream(struct input_bitstream *stream, return 0; } -/* +/* * Builds a fast huffman decoding table from a canonical huffman code lengths * table. Based on code written by David Tritscher. * @@ -109,7 +109,7 @@ int align_input_bitstream(struct input_bitstream *stream, * * @num_bits: Any symbols with a code length of num_bits or less can be * decoded in one lookup of the table. 2**num_bits - * must be greater than or equal to @num_syms if there are + * must be greater than or equal to @num_syms if there are * any Huffman codes longer than @num_bits. * * @lens: An array of length @num_syms, indexable by symbol, that @@ -124,7 +124,7 @@ int align_input_bitstream(struct input_bitstream *stream, * valid Huffman tree, or if there are codes of length greater than @num_bits * but 2**num_bits < num_syms. * - * What exactly is the format of the fast Huffman decoding table? The first + * What exactly is the format of the fast Huffman decoding table? The first * (1 << num_bits) entries of the table are indexed by chunks of the input of * size @num_bits. If the next Huffman code in the input happens to have a * length of exactly @num_bits, the symbol is simply read directly from the @@ -154,8 +154,8 @@ int align_input_bitstream(struct input_bitstream *stream, * entries from pointers by the fact that values less than @num_syms must be * symbol values. */ -int make_huffman_decode_table(u16 decode_table[], uint num_syms, - uint num_bits, const u8 lens[], +int make_huffman_decode_table(u16 decode_table[], uint num_syms, + uint num_bits, const u8 lens[], uint max_code_len) { /* Number of entries in the decode table. */ @@ -244,8 +244,8 @@ int make_huffman_decode_table(u16 decode_table[], uint num_syms, /* Go through every codeword of length greater than @num_bits. Note: * the LZX format guarantees that the codeword length can be at most 16 * bits. */ - for (uint code_len = num_bits + 1; code_len <= max_code_len; - code_len++) + for (uint code_len = num_bits + 1; code_len <= max_code_len; + code_len++) { current_code <<= 1; for (uint sym = 0; sym < num_syms; sym++) { @@ -283,7 +283,7 @@ int make_huffman_decode_table(u16 decode_table[], uint num_syms, * otherwise, go right (by incrementing i by 1) */ int bit_pos = code_len - bit_num; - int bit = (current_code & (1 << bit_pos)) >> + int bit = (current_code & (1 << bit_pos)) >> bit_pos; i += bit; } @@ -294,7 +294,7 @@ int make_huffman_decode_table(u16 decode_table[], uint num_syms, /* Increment decode_table_pos only if the prefix of the * Huffman code changes. */ - if (current_code >> (code_len - num_bits) != + if (current_code >> (code_len - num_bits) != (current_code + 1) >> (code_len - num_bits)) decode_table_pos++; @@ -329,11 +329,11 @@ int make_huffman_decode_table(u16 decode_table[], uint num_syms, /* Reads a Huffman-encoded symbol when it is known there are less than * MAX_CODE_LEN bits remaining in the bitstream. */ -static int read_huffsym_near_end_of_input(struct input_bitstream *istream, - const u16 decode_table[], - const u8 lens[], - uint num_syms, - uint table_bits, +static int read_huffsym_near_end_of_input(struct input_bitstream *istream, + const u16 decode_table[], + const u8 lens[], + uint num_syms, + uint table_bits, uint *n) { uint bitsleft = istream->bitsleft; @@ -344,7 +344,7 @@ static int read_huffsym_near_end_of_input(struct input_bitstream *istream, if (table_bits > bitsleft) { key_size = bitsleft; bitsleft = 0; - key_bits = bitstream_peek_bits(istream, key_size) << + key_bits = bitstream_peek_bits(istream, key_size) << (table_bits - key_size); } else { key_size = table_bits; @@ -371,7 +371,7 @@ static int read_huffsym_near_end_of_input(struct input_bitstream *istream, return 0; } -/* +/* * Reads a Huffman-encoded symbol from a bitstream. * * This function may be called hundreds of millions of times when extracting a @@ -385,17 +385,17 @@ static int read_huffsym_near_end_of_input(struct input_bitstream *istream, * @lengths: The table that gives the length of the code for each * symbol. * @num_symbols: The number of symbols in the Huffman code. - * @table_bits: Huffman codes this length or less can be looked up + * @table_bits: Huffman codes this length or less can be looked up * directory in the decode_table, as the * decode_table contains 2**table_bits entries. */ -int read_huffsym(struct input_bitstream *stream, - const u16 decode_table[], - const u8 lengths[], - unsigned num_symbols, - unsigned table_bits, - uint *n, - unsigned max_codeword_len) +int read_huffsym(struct input_bitstream *stream, + const u16 decode_table[], + const u8 lengths[], + unsigned num_symbols, + unsigned table_bits, + uint *n, + unsigned max_codeword_len) { /* In the most common case, there are at least max_codeword_len bits * remaining in the stream. */ @@ -415,7 +415,7 @@ int read_huffsym(struct input_bitstream *stream, key_bits = sym + bitstream_peek_bits(stream, 1); bitstream_remove_bits(stream, 1); - wimlib_assert(key_bits < num_symbols * 2 + + wimlib_assert(key_bits < num_symbols * 2 + (1 << table_bits)); } while ((sym = decode_table[key_bits]) >= num_symbols); } else { @@ -426,9 +426,9 @@ int read_huffsym(struct input_bitstream *stream, return 0; } else { /* Otherwise, we must be careful to use only the bits that are - * actually remaining. Don't inline this part since it is very - * rarely used. */ - return read_huffsym_near_end_of_input(stream, decode_table, lengths, - num_symbols, table_bits, n); + * actually remaining. */ + return read_huffsym_near_end_of_input(stream, decode_table, + lengths, num_symbols, + table_bits, n); } } diff --git a/src/decomp.h b/src/decomp.h index 72cf1657..3ecf9d2f 100644 --- a/src/decomp.h +++ b/src/decomp.h @@ -32,7 +32,7 @@ struct input_bitstream { }; /* Initializes a bitstream to receive its input from @data. */ -static inline void init_input_bitstream(struct input_bitstream *istream, +static inline void init_input_bitstream(struct input_bitstream *istream, const void *data, uint num_data_bytes) { istream->bitbuf = 0; @@ -42,8 +42,8 @@ static inline void init_input_bitstream(struct input_bitstream *istream, } /* Ensures that the bit buffer contains @num_bits bits. */ -static inline int bitstream_ensure_bits(struct input_bitstream *istream, - uint num_bits) +static inline int bitstream_ensure_bits(struct input_bitstream *istream, + uint num_bits) { wimlib_assert(num_bits <= 16); @@ -62,7 +62,7 @@ static inline int bitstream_ensure_bits(struct input_bitstream *istream, if (istream->data_bytes_left < 2) return 1; - uint shift = sizeof(input_bitbuf_t) * 8 - 16 - + uint shift = sizeof(input_bitbuf_t) * 8 - 16 - istream->bitsleft; istream->bitbuf |= (input_bitbuf_t)le16_to_cpu( *(u16*)istream->data) << shift; @@ -75,7 +75,7 @@ static inline int bitstream_ensure_bits(struct input_bitstream *istream, /* Returns the next @num_bits bits in the bit buffer. It must contain at least * @num_bits bits to call this function. */ -static inline uint bitstream_peek_bits(const struct input_bitstream *istream, +static inline uint bitstream_peek_bits(const struct input_bitstream *istream, uint num_bits) { if (num_bits == 0) @@ -85,7 +85,7 @@ static inline uint bitstream_peek_bits(const struct input_bitstream *istream, /* Removes @num_bits bits from the bit buffer. It must contain at least * @num_bits bits to call this function. */ -static inline void bitstream_remove_bits(struct input_bitstream *istream, +static inline void bitstream_remove_bits(struct input_bitstream *istream, uint num_bits) { istream->bitbuf <<= num_bits; @@ -93,8 +93,8 @@ static inline void bitstream_remove_bits(struct input_bitstream *istream, } /* Reads and returns @num_bits bits from the input bitstream. */ -static inline int bitstream_read_bits(struct input_bitstream *istream, - uint num_bits, uint *n) +static inline int bitstream_read_bits(struct input_bitstream *istream, + uint num_bits, uint *n) { int ret; ret = bitstream_ensure_bits(istream, num_bits); @@ -111,7 +111,7 @@ static inline int bitstream_read_bits(struct input_bitstream *istream, * compressed bitstream. These bytes are basically separate from the bitstream, * as they come AFTER the bits that are currently in the buffer variable (based * on reading 16 bits at a time), even though the buffer variable may not be - * empty. + * empty. * * This function returns the next such literal length byte in the input * bitstream. Returns -1 if we are at the end of the bitstream. */ @@ -129,7 +129,7 @@ static inline int bitstream_read_byte(struct input_bitstream *istream) /* Reads @num_bits bits from the bit buffer without checking to see if that many * bits are in the buffer or not. */ -static inline uint bitstream_read_bits_nocheck(struct input_bitstream *istream, +static inline uint bitstream_read_bits_nocheck(struct input_bitstream *istream, uint num_bits) { uint n = bitstream_peek_bits(istream, num_bits); @@ -142,16 +142,16 @@ static inline void flush_input_bitstream(struct input_bitstream *istream) { bitstream_remove_bits(istream, istream->bitsleft); istream->bitsleft = 0; - istream->bitbuf = 0; + istream->bitbuf = 0; } -extern int bitstream_read_bytes(struct input_bitstream *istream, size_t n, +extern int bitstream_read_bytes(struct input_bitstream *istream, size_t n, void *dest); -extern int align_input_bitstream(struct input_bitstream *istream, +extern int align_input_bitstream(struct input_bitstream *istream, bool skip_word_if_aligned); -extern int read_huffsym(struct input_bitstream *stream, +extern int read_huffsym(struct input_bitstream *stream, const u16 decode_table[], const u8 lengths[], unsigned num_symbols, @@ -159,7 +159,7 @@ extern int read_huffsym(struct input_bitstream *stream, uint *n, unsigned max_codeword_len); -extern int make_huffman_decode_table(u16 decode_table[], uint num_syms, +extern int make_huffman_decode_table(u16 decode_table[], uint num_syms, uint num_bits, const u8 lengths[], uint max_codeword_len); diff --git a/src/resource.c b/src/resource.c index cd098f92..d35d936a 100644 --- a/src/resource.c +++ b/src/resource.c @@ -554,11 +554,10 @@ begin_wim_resource_chunk_tab(const struct lookup_table_entry *lte, { u64 size = wim_resource_size(lte); u64 num_chunks = (size + WIM_CHUNK_SIZE - 1) / WIM_CHUNK_SIZE; - struct chunk_table *chunk_tab = MALLOC(sizeof(struct chunk_table) + - num_chunks * sizeof(u64)); - int ret = 0; + size_t alloc_size = sizeof(struct chunk_table) + num_chunks * sizeof(u64); + struct chunk_table *chunk_tab = CALLOC(1, alloc_size); + int ret; - wimlib_assert(size != 0); if (!chunk_tab) { ERROR("Failed to allocate chunk table for %"PRIu64" byte " @@ -583,8 +582,9 @@ begin_wim_resource_chunk_tab(const struct lookup_table_entry *lte, goto out; } - *chunk_tab_ret = chunk_tab; + ret = 0; out: + *chunk_tab_ret = chunk_tab; return ret; } diff --git a/src/xpress-comp.c b/src/xpress-comp.c index 057e016f..aae2a4a0 100644 --- a/src/xpress-comp.c +++ b/src/xpress-comp.c @@ -46,13 +46,13 @@ static inline u32 bsr32(u32 n) } -/* +/* * Writes @match, which is a match given in the intermediate representation for * XPRESS matches, to the output stream @ostream. * * @codewords and @lens provide the Huffman code that is being used. */ -static int xpress_write_match(struct output_bitstream *ostream, u32 match, +static int xpress_write_match(struct output_bitstream *ostream, u32 match, const u16 codewords[], const u8 lens[]) { uint main_sym; @@ -87,14 +87,14 @@ static int xpress_write_match(struct output_bitstream *ostream, u32 match, return ret; } } - return bitstream_put_bits(ostream, match_offset ^ (1 << offset_bsr), + return bitstream_put_bits(ostream, match_offset ^ (1 << offset_bsr), offset_bsr); } -static int xpress_write_compressed_literals(struct output_bitstream *ostream, - const u32 match_tab[], +static int xpress_write_compressed_literals(struct output_bitstream *ostream, + const u32 match_tab[], uint num_matches, - const u16 codewords[], + const u16 codewords[], const u8 lens[]) { uint i; @@ -103,19 +103,14 @@ static int xpress_write_compressed_literals(struct output_bitstream *ostream, for (i = 0; i < num_matches; i++) { match = match_tab[i]; - if (match >= XPRESS_NUM_CHARS) { - /* match */ - ret = xpress_write_match(ostream, match, codewords, + if (match >= XPRESS_NUM_CHARS) /* match */ + ret = xpress_write_match(ostream, match, codewords, lens); - if (ret != 0) - return ret; - } else { - /* literal byte */ - ret = bitstream_put_bits(ostream, codewords[match], + else /* literal byte */ + ret = bitstream_put_bits(ostream, codewords[match], lens[match]); - if (ret != 0) - return ret; - } + if (ret != 0) + return ret; } return bitstream_put_bits(ostream, codewords[256], lens[256]); } @@ -127,16 +122,16 @@ static u32 xpress_record_literal(u8 literal, void *__freq_tab) return literal; } -static u32 xpress_record_match(uint match_offset, uint match_len, - void *__freq_tab, void *ignore) +static u32 xpress_record_match(uint match_offset, uint match_len, + void *__freq_tab, void *ignore) { u32 *freq_tab = __freq_tab; u32 len_hdr; u32 offset_bsr; u32 match; - wimlib_assert(match_len >= XPRESS_MIN_MATCH && - match_len <= XPRESS_MAX_MATCH); + wimlib_assert(match_len >= XPRESS_MIN_MATCH && + match_len <= XPRESS_MAX_MATCH); wimlib_assert(match_offset > 0); len_hdr = min(match_len - XPRESS_MIN_MATCH, 15); @@ -158,7 +153,7 @@ static const struct lz_params xpress_lz_params = { .too_far = 4096, }; -/* +/* * Performs XPRESS compression on a block of data. * * @__uncompressed_data: Pointer to the data to be compressed. @@ -173,7 +168,7 @@ static const struct lz_params xpress_lz_params = { * @compressed_data and @compressed_len_ret will contain the compressed data and * its length. A return value of nonzero means that compressing the data did * not reduce its size, and @compressed_data will not contain the full - * compressed data. + * compressed data. */ int xpress_compress(const void *__uncompressed_data, uint uncompressed_len, void *__compressed_data, uint *compressed_len_ret) @@ -197,18 +192,18 @@ int xpress_compress(const void *__uncompressed_data, uint uncompressed_len, ZERO_ARRAY(freq_tab); - num_matches = lz_analyze_block(uncompressed_data, uncompressed_len, - match_tab, xpress_record_match, - xpress_record_literal, freq_tab, - NULL, freq_tab, - &xpress_lz_params); + num_matches = lz_analyze_block(uncompressed_data, uncompressed_len, + match_tab, xpress_record_match, + xpress_record_literal, freq_tab, + NULL, freq_tab, + &xpress_lz_params); XPRESS_DEBUG("using %u matches", num_matches); freq_tab[256]++; make_canonical_huffman_code(XPRESS_NUM_SYMBOLS, XPRESS_MAX_CODEWORD_LEN, - freq_tab, lens, codewords); + freq_tab, lens, codewords); /* IMPORTANT NOTE: * @@ -217,7 +212,7 @@ int xpress_compress(const void *__uncompressed_data, uint uncompressed_len, * bitstream_put_bits() will output 2 bytes at a time in little-endian * order, which is the order that is needed for the compressed literals. * However, the bytes in the lengths table are in order, so they need to - * be written one at a time without using bitstream_put_bits(). + * be written one at a time without using bitstream_put_bits(). * * Because of this, init_output_bitstream() is not called until after * the lengths table is output. @@ -225,20 +220,44 @@ int xpress_compress(const void *__uncompressed_data, uint uncompressed_len, for (i = 0; i < XPRESS_NUM_SYMBOLS; i += 2) *compressed_data++ = (lens[i] & 0xf) | (lens[i + 1] << 4); - init_output_bitstream(&ostream, compressed_data, uncompressed_len - - XPRESS_NUM_SYMBOLS / 2 - 1); + init_output_bitstream(&ostream, compressed_data, + uncompressed_len - XPRESS_NUM_SYMBOLS / 2 - 1); - ret = xpress_write_compressed_literals(&ostream, match_tab, num_matches, - codewords, lens); + ret = xpress_write_compressed_literals(&ostream, match_tab, + num_matches, codewords, lens); if (ret != 0) return ret; + /* Flush any bits that are buffered. */ ret = flush_output_bitstream(&ostream); if (ret != 0) return ret; + /* Assert that there are no output bytes between the ostream.output + * pointer and the ostream.next_bit_output pointer. This can only + * happen if bytes had been written at the ostream.output pointer before + * the last bit word was written to the stream. But, this does not + * occur since xpress_write_match() always finishes by writing some bits + * (a Huffman symbol), and the bitstream was just flushed. */ + wimlib_assert(ostream.output - ostream.next_bit_output == 2); + + /* + * The length of the compressed data is supposed to be the value of the + * ostream.output pointer before flushing, which is now the + * output.next_bit_output pointer after flushing. + * + * There will be an extra 2 bytes at the ostream.bit_output pointer, + * which is zeroed out. (These 2 bytes may be either the last bytes in + * the compressed data, in which case they are actually unnecessary, or + * they may precede a number of bytes embedded into the bitstream.) + */ + if (ostream.bit_output > + (const u8*)__compressed_data + uncompressed_len - 3) + return 1; + *(u16*)ostream.bit_output = cpu_to_le16(0); + compressed_len = ostream.next_bit_output - (const u8*)__compressed_data; - compressed_len = ostream.output - (u8*)__compressed_data; + wimlib_assert(compressed_len <= uncompressed_len - 1); XPRESS_DEBUG("Compressed %u => %u bytes", uncompressed_len, compressed_len); @@ -249,7 +268,7 @@ int xpress_compress(const void *__uncompressed_data, uint uncompressed_len, /* Verify that we really get the same thing back when decompressing. */ XPRESS_DEBUG("Verifying the compressed data."); u8 buf[uncompressed_len]; - ret = xpress_decompress(__compressed_data, compressed_len, buf, + ret = xpress_decompress(__compressed_data, compressed_len, buf, uncompressed_len); if (ret != 0) { ERROR("xpress_compress(): Failed to decompress data we " diff --git a/src/xpress-decomp.c b/src/xpress-decomp.c index 89e006d3..45b52350 100644 --- a/src/xpress-decomp.c +++ b/src/xpress-decomp.c @@ -83,7 +83,7 @@ /* Decodes @huffsym, a value >= XPRESS_NUM_CHARS, that is the header of a match. * */ -static int xpress_decode_match(int huffsym, uint window_pos, uint window_len, +static int xpress_decode_match(int huffsym, uint window_pos, uint window_len, u8 window[], struct input_bitstream *istream) { uint match_len; @@ -157,41 +157,45 @@ static int xpress_decode_match(int huffsym, uint window_pos, uint window_len, /* Decodes the Huffman-encoded matches and literal bytes in a block of * XPRESS-encoded data. */ -static int xpress_decompress_literals(struct input_bitstream *istream, - u8 uncompressed_data[], - uint uncompressed_len, - const u8 lens[], +static int xpress_decompress_literals(struct input_bitstream *istream, + u8 uncompressed_data[], + uint uncompressed_len, + const u8 lens[], const u16 decode_table[]) { uint curpos = 0; uint huffsym; int match_len; - int ret; + int ret = 0; while (curpos < uncompressed_len) { - ret = read_huffsym(istream, decode_table, lens, - XPRESS_NUM_SYMBOLS, XPRESS_TABLEBITS, &huffsym, - XPRESS_MAX_CODEWORD_LEN); + ret = read_huffsym(istream, decode_table, lens, + XPRESS_NUM_SYMBOLS, XPRESS_TABLEBITS, + &huffsym, XPRESS_MAX_CODEWORD_LEN); if (ret != 0) - return ret; + break; if (huffsym < XPRESS_NUM_CHARS) { uncompressed_data[curpos++] = huffsym; } else { - match_len = xpress_decode_match(huffsym, curpos, - uncompressed_len, - uncompressed_data, istream); - if (match_len == -1) - return 1; + match_len = xpress_decode_match(huffsym, + curpos, + uncompressed_len, + uncompressed_data, + istream); + if (match_len == -1) { + ret = 1; + break; + } curpos += match_len; } } - return 0; + return ret; } -int xpress_decompress(const void *__compressed_data, uint compressed_len, - void *uncompressed_data, uint uncompressed_len) +int xpress_decompress(const void *__compressed_data, uint compressed_len, + void *uncompressed_data, uint uncompressed_len) { u8 lens[XPRESS_NUM_SYMBOLS]; u16 decode_table[(1 << XPRESS_TABLEBITS) + 2 * XPRESS_NUM_SYMBOLS]; @@ -225,9 +229,10 @@ int xpress_decompress(const void *__compressed_data, uint compressed_len, if (ret != 0) return ret; - init_input_bitstream(&istream, compressed_data + XPRESS_NUM_SYMBOLS / 2, + init_input_bitstream(&istream, compressed_data + XPRESS_NUM_SYMBOLS / 2, compressed_len - XPRESS_NUM_SYMBOLS / 2); - return xpress_decompress_literals(&istream, uncompressed_data, - uncompressed_len, lens, decode_table); + return xpress_decompress_literals(&istream, uncompressed_data, + uncompressed_len, lens, + decode_table); } -- 2.43.0