+#include "wimlib/util.h"
+
+#include <limits.h>
+
+#define LZMS_DECODE_TABLE_BITS 10
+
+/* Structure used for range decoding, reading bits forwards.  This is the first
+ * logical bitstream mentioned above.  The range is renormalized by
+ * lzms_range_decoder_raw_normalize() whenever it drops to 16 bits or less. */
+struct lzms_range_decoder_raw {
+	/* The relevant part of the current range.  Although the logical range
+	 * for range decoding is a very large integer, only a small portion
+	 * matters at any given time, and it can be normalized (shifted left)
+	 * whenever it gets too small. */
+	u32 range;
+
+	/* The current position in the range encoded by the portion of the input
+	 * read so far. */
+	u32 code;
+
+	/* Pointer to the next little-endian 16-bit integer in the compressed
+	 * input data (reading forwards). */
+	const le16 *in;
+
+	/* Number of 16-bit integers remaining in the compressed input data
+	 * (reading forwards). */
+	size_t num_le16_remaining;
+};
+
+/* Structure used for reading raw bits backwards.  This is the second logical
+ * bitstream mentioned above. */
+struct lzms_input_bitstream {
+	/* Holding variable for bits that have been read from the compressed
+	 * data.  The bits are ordered from high-order to low-order; peeks and
+	 * pops operate on the high end. */
+	/* XXX: Without special-case code to handle reading more than 17 bits
+	 * at a time, this needs to be 64 bits rather than 32 bits. */
+	u64 bitbuf;
+
+	/* Number of bits in @bitbuf that are used. */
+	unsigned num_filled_bits;
+
+	/* Pointer to one past the next little-endian 16-bit integer in the
+	 * compressed input data (reading backwards); each refill first
+	 * decrements this pointer. */
+	const le16 *in;
+
+	/* Number of 16-bit integers remaining in the compressed input data
+	 * (reading backwards). */
+	size_t num_le16_remaining;
+};
+
+/* Structure used for range decoding.  This wraps around `struct
+ * lzms_range_decoder_raw' to use and maintain probability entries. */
+struct lzms_range_decoder {
+	/* Pointer to the raw range decoder, which has no persistent knowledge
+	 * of probabilities.  Multiple lzms_range_decoder's share the same
+	 * lzms_range_decoder_raw. */
+	struct lzms_range_decoder_raw *rd;
+
+	/* Bits recently decoded by this range decoder.  These are used as an
+	 * index into @prob_entries. */
+	u32 state;
+
+	/* Bitmask for @state to prevent its value from exceeding the number of
+	 * probability entries. */
+	u32 mask;
+
+	/* Probability entries being used for this range decoder. */
+	struct lzms_probability_entry prob_entries[LZMS_MAX_NUM_STATES];
+};
+
+/* Structure used for Huffman decoding, optionally using the decoded symbols as
+ * slots into a base table to determine how many extra bits need to be read to
+ * reconstitute the full value. */
+struct lzms_huffman_decoder {
+
+	/* Bitstream to read Huffman-encoded symbols and verbatim bits from.
+	 * Multiple lzms_huffman_decoder's share the same lzms_input_bitstream.
+	 */
+	struct lzms_input_bitstream *is;
+
+	/* Pointer to the slot base table to use.  It is indexed by the decoded
+	 * Huffman symbol that specifies the slot.  The entry specifies the base
+	 * value to use, and the position of its high bit is the number of
+	 * additional bits that must be read to reconstitute the full value.
+	 *
+	 * This member need not be set if only raw Huffman symbols are being
+	 * read using this decoder. */
+	const u32 *slot_base_tab;
+
+	/* Per-slot table paired with @slot_base_tab; presumably gives the
+	 * number of extra bits for each slot — TODO(review): confirm against
+	 * the code that reads it (not visible in this chunk). */
+	const u8 *extra_bits_tab;
+
+	/* Number of symbols that have been read using this code so far.  Reset
+	 * to 0 whenever the code is rebuilt. */
+	u32 num_syms_read;
+
+	/* When @num_syms_read reaches this number, the Huffman code must be
+	 * rebuilt. */
+	u32 rebuild_freq;
+
+	/* Number of symbols in the represented Huffman code. */
+	unsigned num_syms;
+
+	/* Running totals of symbol frequencies.  These are diluted slightly
+	 * (halved, then incremented) whenever the code is rebuilt. */
+	u32 sym_freqs[LZMS_MAX_NUM_SYMS];
+
+	/* The length, in bits, of each symbol in the Huffman code. */
+	u8 lens[LZMS_MAX_NUM_SYMS];
+
+	/* The codeword of each symbol in the Huffman code. */
+	u32 codewords[LZMS_MAX_NUM_SYMS];
+
+	/* A table for quickly decoding symbols encoded using the Huffman code.
+	 */
+	u16 decode_table[(1U << LZMS_DECODE_TABLE_BITS) + 2 * LZMS_MAX_NUM_SYMS]
+		_aligned_attribute(DECODE_TABLE_ALIGNMENT);
+};
+
+/* State of the LZMS decompressor. */
+struct lzms_decompressor {
+
+	/* Pointer to the beginning of the uncompressed data buffer. */
+	u8 *out_begin;
+
+	/* Pointer to the next position in the uncompressed data buffer. */
+	u8 *out_next;
+
+	/* Pointer to one past the end of the uncompressed data buffer. */
+	u8 *out_end;
+
+	/* Range decoder, which reads bits from the beginning of the compressed
+	 * block, going forwards. */
+	struct lzms_range_decoder_raw rd;
+
+	/* Input bitstream, which reads from the end of the compressed block,
+	 * going backwards. */
+	struct lzms_input_bitstream is;
+
+	/* Range decoders, one per decision context; each shares @rd above. */
+	struct lzms_range_decoder main_range_decoder;
+	struct lzms_range_decoder match_range_decoder;
+	struct lzms_range_decoder lz_match_range_decoder;
+	struct lzms_range_decoder lz_repeat_match_range_decoders[LZMS_NUM_RECENT_OFFSETS - 1];
+	struct lzms_range_decoder delta_match_range_decoder;
+	struct lzms_range_decoder delta_repeat_match_range_decoders[LZMS_NUM_RECENT_OFFSETS - 1];
+
+	/* Huffman decoders, all sharing the backwards bitstream @is. */
+	struct lzms_huffman_decoder literal_decoder;
+	struct lzms_huffman_decoder lz_offset_decoder;
+	struct lzms_huffman_decoder length_decoder;
+	struct lzms_huffman_decoder delta_power_decoder;
+	struct lzms_huffman_decoder delta_offset_decoder;
+
+	/* LRU (least-recently-used) queues for match information. */
+	struct lzms_lru_queues lru;
+
+	/* Used for postprocessing.  Note: 65536 s32 entries = 256 KiB, so this
+	 * structure should be heap-allocated, not placed on the stack. */
+	s32 last_target_usages[65536];
+};
+
+/* Initialize the input bitstream @is to read backwards from the specified
+ * compressed data buffer @in that is @in_limit 16-bit integers long.
+ * (The read pointer starts one past the last le16 and is decremented on each
+ * refill; the original comment incorrectly said "forwards".) */
+static void
+lzms_input_bitstream_init(struct lzms_input_bitstream *is,
+			  const le16 *in, size_t in_limit)
+{
+	is->bitbuf = 0;
+	is->num_filled_bits = 0;
+	is->in = in + in_limit;
+	is->num_le16_remaining = in_limit;
+}
+
+/* Ensures that at least @num_bits bits are buffered in the input bitstream.
+ * Refills @bitbuf 16 bits at a time, reading backwards through the input.
+ * Returns 0 on success, or -1 if the input is exhausted. */
+static int
+lzms_input_bitstream_ensure_bits(struct lzms_input_bitstream *is,
+				 unsigned num_bits)
+{
+	while (is->num_filled_bits < num_bits) {
+		u64 word;
+		unsigned shift;
+
+		LZMS_ASSERT(is->num_filled_bits + 16 <= sizeof(is->bitbuf) * 8);
+
+		if (unlikely(is->num_le16_remaining == 0))
+			return -1;
+
+		/* Step the read pointer backwards and load the next word. */
+		--is->in;
+		--is->num_le16_remaining;
+		word = le16_to_cpu(*is->in);
+
+		/* Place the new 16 bits just below the bits already held. */
+		shift = sizeof(is->bitbuf) * 8 - 16 - is->num_filled_bits;
+		is->bitbuf |= word << shift;
+		is->num_filled_bits += 16;
+	}
+	return 0;
+}
+
+/* Returns, without consuming them, the next @num_bits bits buffered in the
+ * input bitstream.  The caller must have already ensured they are present. */
+static u32
+lzms_input_bitstream_peek_bits(struct lzms_input_bitstream *is,
+			       unsigned num_bits)
+{
+	unsigned discard_shift = sizeof(is->bitbuf) * 8 - num_bits;
+
+	LZMS_ASSERT(is->num_filled_bits >= num_bits);
+	return is->bitbuf >> discard_shift;
+}
+
+/* Discards the next @num_bits bits buffered in the input bitstream by
+ * shifting them out of the high end of @bitbuf. */
+static void
+lzms_input_bitstream_remove_bits(struct lzms_input_bitstream *is,
+				 unsigned num_bits)
+{
+	LZMS_ASSERT(is->num_filled_bits >= num_bits);
+	is->num_filled_bits -= num_bits;
+	is->bitbuf <<= num_bits;
+}
+
+/* Removes and returns the next @num_bits bits buffered in the input
+ * bitstream (peek followed by remove). */
+static u32
+lzms_input_bitstream_pop_bits(struct lzms_input_bitstream *is,
+			      unsigned num_bits)
+{
+	u32 ret = lzms_input_bitstream_peek_bits(is, num_bits);
+
+	lzms_input_bitstream_remove_bits(is, num_bits);
+	return ret;
+}
+
+/* Reads the next @num_bits from the input bitstream, refilling the buffer as
+ * needed.  Returns 0 if the input is exhausted (indistinguishable from a
+ * legitimate all-zero read, matching the original behavior). */
+static u32
+lzms_input_bitstream_read_bits(struct lzms_input_bitstream *is,
+			       unsigned num_bits)
+{
+	int ret = lzms_input_bitstream_ensure_bits(is, num_bits);
+
+	if (unlikely(ret != 0))
+		return 0;
+	return lzms_input_bitstream_pop_bits(is, num_bits);
+}
+
+/* Initialize the range decoder @rd to read forwards from the specified
+ * compressed data buffer @in that is @in_limit 16-bit integers long.
+ *
+ * Precondition: @in_limit >= 2.  This function unconditionally reads in[0]
+ * and in[1] and computes `in_limit - 2`, which would underflow (size_t wrap)
+ * for a shorter buffer; the caller must validate the input length. */
+static void
+lzms_range_decoder_raw_init(struct lzms_range_decoder_raw *rd,
+			    const le16 *in, size_t in_limit)
+{
+	/* Start with the full 32-bit range and seed @code with the first two
+	 * input words (most-significant word first). */
+	rd->range = 0xffffffff;
+	rd->code = ((u32)le16_to_cpu(in[0]) << 16) |
+		   ((u32)le16_to_cpu(in[1]) << 0);
+	rd->in = in + 2;
+	rd->num_le16_remaining = in_limit - 2;
+}
+
+/* Ensures the current range of the range decoder has at least 16 bits of
+ * precision, pulling in one more input word if renormalization occurs.
+ * Returns 0 on success, or -1 if the input is exhausted. */
+static int
+lzms_range_decoder_raw_normalize(struct lzms_range_decoder_raw *rd)
+{
+	if (rd->range > 0xffff)
+		return 0;
+
+	rd->range <<= 16;
+	if (unlikely(rd->num_le16_remaining == 0))
+		return -1;
+	rd->code = (rd->code << 16) | le16_to_cpu(*rd->in++);
+	rd->num_le16_remaining--;
+	return 0;
+}
+
+/* Decode and return the next bit from the range decoder (raw version).
+ *
+ * @prob is the chance out of LZMS_PROBABILITY_MAX that the next bit is 0.
+ *
+ * NOTE(review): the -1 failure return of lzms_range_decoder_raw_normalize()
+ * is ignored here, so on truncated input decoding silently continues with
+ * stale state — confirm whether end-of-input is caught elsewhere. */
+static int
+lzms_range_decoder_raw_decode_bit(struct lzms_range_decoder_raw *rd, u32 prob)
+{
+	u32 bound;
+
+	/* Ensure the range has at least 16 bits of precision. */
+	lzms_range_decoder_raw_normalize(rd);
+
+	/* Based on the probability, calculate the bound between the 0-bit
+	 * region and the 1-bit region of the range. */
+	bound = (rd->range >> LZMS_PROBABILITY_BITS) * prob;
+
+	if (rd->code < bound) {
+		/* Current code is in the 0-bit region of the range. */
+		rd->range = bound;
+		return 0;
+	} else {
+		/* Current code is in the 1-bit region of the range. */
+		rd->range -= bound;
+		rd->code -= bound;
+		return 1;
+	}
+}
+
+/* Decode and return the next bit from the range decoder.  This wraps around
+ * lzms_range_decoder_raw_decode_bit() to handle using and updating the
+ * appropriate probability table. */
+static int
+lzms_range_decode_bit(struct lzms_range_decoder *dec)
+{
+	struct lzms_probability_entry *prob_entry;
+	u32 prob;
+	int bit;
+
+	/* Load the probability entry corresponding to the current state. */
+	prob_entry = &dec->prob_entries[dec->state];
+
+	/* Treat the number of zero bits in the most recently decoded
+	 * LZMS_PROBABILITY_MAX bits with this probability entry as the chance,
+	 * out of LZMS_PROBABILITY_MAX, that the next bit will be a 0.  However,
+	 * don't allow 0% or 100% probabilities (clamp to [1,
+	 * LZMS_PROBABILITY_MAX - 1]); otherwise one of the two range regions
+	 * would be empty and that bit value could never be decoded. */
+	prob = prob_entry->num_recent_zero_bits;
+	if (prob == LZMS_PROBABILITY_MAX)
+		prob = LZMS_PROBABILITY_MAX - 1;
+	else if (prob == 0)
+		prob = 1;
+
+	/* Decode the next bit. */
+	bit = lzms_range_decoder_raw_decode_bit(dec->rd, prob);
+
+	/* Update the state based on the newly decoded bit: shift it in and
+	 * mask so the state stays within the probability-entry array. */
+	dec->state = (((dec->state << 1) | bit) & dec->mask);
+
+	/* Update the recent bits, including the cached count of 0's.  The bit
+	 * about to be shifted out is the one at position
+	 * LZMS_PROBABILITY_MAX - 1; the zero count changes only when the
+	 * outgoing and incoming bits differ. */
+	BUILD_BUG_ON(LZMS_PROBABILITY_MAX > sizeof(prob_entry->recent_bits) * 8);
+	if (bit == 0) {
+		if (prob_entry->recent_bits & (1ULL << (LZMS_PROBABILITY_MAX - 1))) {
+			/* Replacing 1 bit with 0 bit; increment the zero count.
+			 */
+			prob_entry->num_recent_zero_bits++;
+		}
+	} else {
+		if (!(prob_entry->recent_bits & (1ULL << (LZMS_PROBABILITY_MAX - 1)))) {
+			/* Replacing 0 bit with 1 bit; decrement the zero count.
+			 */
+			prob_entry->num_recent_zero_bits--;
+		}
+	}
+	prob_entry->recent_bits = (prob_entry->recent_bits << 1) | bit;
+
+	/* Return the decoded bit. */
+	return bit;
+}
+
+
+/* Build the decoding table for a new adaptive Huffman code using the alphabet
+ * used in the specified Huffman decoder, with the symbol frequencies
+ * dec->sym_freqs. */
+static void
+lzms_rebuild_adaptive_huffman_code(struct lzms_huffman_decoder *dec)
+{
+
+	/* XXX: This implementation makes use of code already implemented for
+	 * the XPRESS and LZX compression formats.  However, since for the
+	 * adaptive codes used in LZMS we don't actually need the explicit codes
+	 * themselves, only the decode tables, it may be possible to optimize
+	 * this by somehow directly building or updating the Huffman decode
+	 * table.  This may be a worthwhile optimization because the adaptive
+	 * codes change many times throughout a decompression run. */
+	LZMS_DEBUG("Rebuilding adaptive Huffman code (num_syms=%u)",
+		   dec->num_syms);
+	make_canonical_huffman_code(dec->num_syms, LZMS_MAX_CODEWORD_LEN,
+				    dec->sym_freqs, dec->lens, dec->codewords);
+	/* `ret' exists only in debug builds; LZMS_ASSERT presumably expands to
+	 * nothing when ENABLE_LZMS_DEBUG is undefined, so the decode-table
+	 * build result is silently ignored in release builds — TODO(review):
+	 * confirm the LZMS_ASSERT definition. */
+#if defined(ENABLE_LZMS_DEBUG)
+	int ret =
+#endif
+		make_huffman_decode_table(dec->decode_table, dec->num_syms,
+					  LZMS_DECODE_TABLE_BITS, dec->lens,
+					  LZMS_MAX_CODEWORD_LEN);
+	LZMS_ASSERT(ret == 0);
+}
+
+/* Decode and return the next Huffman-encoded symbol from the LZMS-compressed
+ * block using the specified Huffman decoder. */
+static u32
+lzms_huffman_decode_symbol(struct lzms_huffman_decoder *dec)
+{
+ const u16 *decode_table = dec->decode_table;
+ struct lzms_input_bitstream *is = dec->is;
+ u16 entry;
+ u16 key_bits;
+ u16 sym;
+
+ /* The Huffman codes used in LZMS are adaptive and must be rebuilt
+ * whenever a certain number of symbols have been read. Each such
+ * rebuild uses the current symbol frequencies, but the format also
+ * requires that the symbol frequencies be halved after each code
+ * rebuild. This diminishes the effect of old symbols on the current
+ * Huffman codes, thereby causing the Huffman codes to be more locally
+ * adaptable. */
+ if (dec->num_syms_read == dec->rebuild_freq) {
+ lzms_rebuild_adaptive_huffman_code(dec);
+ for (unsigned i = 0; i < dec->num_syms; i++) {
+ dec->sym_freqs[i] >>= 1;
+ dec->sym_freqs[i] += 1;
+ }
+ dec->num_syms_read = 0;
+ }