+
+/* Decode and return the next bit from the range decoder (raw version).
+ *
+ * @prob is the chance out of LZMS_PROBABILITY_MAX that the next bit is 0.
+ */
+static int
+lzms_range_decoder_raw_decode_bit(struct lzms_range_decoder_raw *rd, u32 prob)
+{
+	u32 zero_region_size;
+
+	/* Make sure the range has at least 16 bits of precision before
+	 * splitting it. */
+	lzms_range_decoder_raw_normalize(rd);
+
+	/* Split the range in proportion to @prob: the low
+	 * @zero_region_size values encode a 0 bit; the rest encode a 1
+	 * bit. */
+	zero_region_size = (rd->range >> LZMS_PROBABILITY_BITS) * prob;
+
+	if (rd->code >= zero_region_size) {
+		/* The code lies in the 1-bit region; slide both the code
+		 * and the range down past the 0-bit region. */
+		rd->range -= zero_region_size;
+		rd->code -= zero_region_size;
+		return 1;
+	}
+
+	/* The code lies in the 0-bit region; narrow the range to it. */
+	rd->range = zero_region_size;
+	return 0;
+}
+
+/* Decode and return the next bit from the range decoder. This wraps around
+ * lzms_range_decoder_raw_decode_bit() to handle using and updating the
+ * appropriate probability table. */
+static int
+lzms_range_decode_bit(struct lzms_range_decoder *dec)
+{
+	/* The probability entry for the current state supplies the chance,
+	 * out of LZMS_PROBABILITY_MAX, that the next bit is 0. */
+	struct lzms_probability_entry *entry = &dec->prob_entries[dec->state];
+
+	int bit = lzms_range_decoder_raw_decode_bit(dec->rd,
+					lzms_get_probability(entry));
+
+	/* Fold the decoded bit into the state (bounded by the state mask)
+	 * and adapt the probability entry for next time. */
+	dec->state = ((dec->state << 1) | bit) & dec->mask;
+	lzms_update_probability_entry(entry, bit);
+
+	return bit;
+}
+
+
+/* Build the decoding table for a new adaptive Huffman code using the alphabet
+ * used in the specified Huffman decoder, with the symbol frequencies
+ * dec->sym_freqs. */
+static void
+lzms_rebuild_adaptive_huffman_code(struct lzms_huffman_decoder *dec)
+{
+
+ /* XXX: This implementation makes use of code already implemented for
+ * the XPRESS and LZX compression formats. However, since for the
+ * adaptive codes used in LZMS we don't actually need the explicit codes
+ * themselves, only the decode tables, it may be possible to optimize
+ * this by somehow directly building or updating the Huffman decode
+ * table. This may be a worthwhile optimization because the adaptive
+ * codes change many times throughout a decompression run. */
+ LZMS_DEBUG("Rebuilding adaptive Huffman code (num_syms=%u)",
+ dec->num_syms);
+ make_canonical_huffman_code(dec->num_syms, LZMS_MAX_CODEWORD_LEN,
+ dec->sym_freqs, dec->lens, dec->codewords);
+ /* The return value is captured (and checked) only in debug builds;
+ * in non-debug builds 'ret' is never declared and the result is
+ * intentionally discarded. NOTE(review): this presumes LZMS_ASSERT()
+ * expands to nothing unless ENABLE_LZMS_DEBUG is defined -- confirm
+ * against the macro definition. */
+#if defined(ENABLE_LZMS_DEBUG)
+ int ret =
+#endif
+ make_huffman_decode_table(dec->decode_table, dec->num_syms,
+ LZMS_DECODE_TABLE_BITS, dec->lens,
+ LZMS_MAX_CODEWORD_LEN);
+ LZMS_ASSERT(ret == 0);
+}
+
+/* Decode and return the next Huffman-encoded symbol from the LZMS-compressed
+ * block using the specified Huffman decoder. */
+static u32
+lzms_huffman_decode_symbol(struct lzms_huffman_decoder *dec)
+{
+ const u16 *decode_table = dec->decode_table;
+ struct lzms_input_bitstream *is = dec->is;
+ u16 entry;
+ u16 key_bits;
+ u16 sym;
+
+ /* The Huffman codes used in LZMS are adaptive and must be rebuilt
+ * whenever a certain number of symbols have been read. Each such
+ * rebuild uses the current symbol frequencies, but the format also
+ * requires that the symbol frequencies be halved after each code
+ * rebuild. This diminishes the effect of old symbols on the current
+ * Huffman codes, thereby causing the Huffman codes to be more locally
+ * adaptable. */
+ if (dec->num_syms_read == dec->rebuild_freq) {
+ lzms_rebuild_adaptive_huffman_code(dec);
+ /* Halve each frequency, then add 1 so that no frequency ever
+ * drops to 0 and every symbol stays representable in the next
+ * rebuilt code. */
+ for (unsigned i = 0; i < dec->num_syms; i++) {
+ dec->sym_freqs[i] >>= 1;
+ dec->sym_freqs[i] += 1;
+ }
+ dec->num_syms_read = 0;
+ }
+
+ /* XXX: Copied from read_huffsym() (decompress_common.h), since this
+ * uses a different input bitstream type. Should unify the
+ * implementations. */
+ lzms_input_bitstream_ensure_bits(is, LZMS_MAX_CODEWORD_LEN);
+
+ /* Index the decode table by the next table_bits bits of the input. */
+ key_bits = lzms_input_bitstream_peek_bits(is, LZMS_DECODE_TABLE_BITS);
+ entry = decode_table[key_bits];
+ /* NOTE(review): the entry layout assumed below (direct entries below
+ * 0xC000 holding an 11-bit symbol and 5-bit codeword length; tree
+ * nodes at or above 0xC000 holding a 14-bit child index) must match
+ * the format produced by make_huffman_decode_table() -- confirm
+ * against decompress_common. */
+ if (likely(entry < 0xC000)) {
+ /* Fast case: The decode table directly provided the symbol and
+ * codeword length. The low 11 bits are the symbol, and the
+ * high 5 bits are the codeword length. */
+ lzms_input_bitstream_remove_bits(is, entry >> 11);
+ sym = entry & 0x7FF;
+ } else {
+ /* Slow case: The codeword for the symbol is longer than
+ * table_bits, so the symbol does not have an entry directly in
+ * the first (1 << table_bits) entries of the decode table.
+ * Traverse the appropriate binary tree bit-by-bit in order to
+ * decode the symbol. */
+ lzms_input_bitstream_remove_bits(is, LZMS_DECODE_TABLE_BITS);
+ do {
+ key_bits = (entry & 0x3FFF) + lzms_input_bitstream_pop_bits(is, 1);
+ } while ((entry = decode_table[key_bits]) >= 0xC000);
+ sym = entry;
+ }
+
+ /* Tally and return the decoded symbol. */
+ ++dec->sym_freqs[sym];
+ ++dec->num_syms_read;
+ return sym;
+}
+
+/* Decode a number from the LZMS bitstream, encoded as a Huffman-encoded symbol
+ * specifying a "slot" (whose corresponding value is looked up in a static
+ * table) plus the number specified by a number of extra bits depending on the
+ * slot. */
+static u32
+lzms_decode_value(struct lzms_huffman_decoder *dec)
+{
+	unsigned slot;
+	u32 extra;
+
+	LZMS_ASSERT(dec->slot_base_tab != NULL);
+	LZMS_ASSERT(dec->extra_bits_tab != NULL);
+
+	/* The slot (offset slot, length slot, etc.) is Huffman-encoded. */
+	slot = lzms_huffman_decode_symbol(dec);
+
+	/* The slot determines how many extra bits follow; read them
+	 * directly from the bitstream. */
+	extra = lzms_input_bitstream_read_bits(dec->is,
+					       dec->extra_bits_tab[slot]);
+
+	/* The decoded value is the slot's base value plus the extra bits. */
+	return dec->slot_base_tab[slot] + extra;
+}
+
+/* Copy a literal to the output buffer.  Always succeeds (returns 0); the
+ * decode loop only runs while at least one output byte remains. */
+static int
+lzms_copy_literal(struct lzms_decompressor *ctx, u8 literal)
+{
+	u8 *dst = ctx->out_next;
+
+	*dst = literal;
+	ctx->out_next = dst + 1;
+	return 0;
+}
+
+/* Validate an LZ match and copy it to the output buffer.
+ *
+ * Returns 0 on success, or -1 if the match would read before the start of
+ * the output buffer or write past its end. */
+static int
+lzms_copy_lz_match(struct lzms_decompressor *ctx, u32 length, u32 offset)
+{
+	u8 *dst = ctx->out_next;
+
+	/* The match may not extend past the end of the output buffer... */
+	if (length > ctx->out_end - dst) {
+		LZMS_DEBUG("Match overrun!");
+		return -1;
+	}
+	/* ...and its source may not begin before the start of the output
+	 * buffer. */
+	if (offset > dst - ctx->out_begin) {
+		LZMS_DEBUG("Match underrun!");
+		return -1;
+	}
+
+	lz_copy(dst, length, offset, ctx->out_end, 1);
+	ctx->out_next = dst + length;
+
+	return 0;
+}
+
+/* Validate a delta match and copy it to the output buffer.
+ *
+ * Each output byte is predicted from three earlier bytes:
+ * out[i] = out[i - offset1] + out[i - offset2] - out[i - (offset1+offset2)],
+ * where offset1 = 2**power and offset2 = raw_offset * 2**power. */
+static int
+lzms_copy_delta_match(struct lzms_decompressor *ctx, u32 length,
+ u32 power, u32 raw_offset)
+{
+ u32 offset1 = 1U << power;
+ u32 offset2 = raw_offset << power;
+ /* NOTE(review): these shifts and this sum can wrap in 32 bits for
+ * hostile inputs, and a wrapped offset may evade the underrun check
+ * below -- confirm the valid ranges of power/raw_offset upstream. */
+ u32 offset = offset1 + offset2;
+ u8 *out_next;
+ u8 *matchptr1;
+ u8 *matchptr2;
+ u8 *matchptr;
+
+ /* The match may not extend past the end of the output buffer... */
+ if (length > ctx->out_end - ctx->out_next) {
+ LZMS_DEBUG("Match overrun!");
+ return -1;
+ }
+ /* ...and the furthest-back source byte may not precede the start of
+ * the output buffer. */
+ if (offset > ctx->out_next - ctx->out_begin) {
+ LZMS_DEBUG("Match underrun!");
+ return -1;
+ }
+
+ out_next = ctx->out_next;
+ matchptr1 = out_next - offset1;
+ matchptr2 = out_next - offset2;
+ matchptr = out_next - offset;
+
+ /* Combine the three predictors byte by byte; storing into u8
+ * truncates the sum modulo 256. */
+ while (length--)
+ *out_next++ = *matchptr1++ + *matchptr2++ - *matchptr++;
+
+ ctx->out_next = out_next;
+ return 0;
+}
+
+/* Decode a (length, offset) pair from the input. */
+static int
+lzms_decode_lz_match(struct lzms_decompressor *ctx)
+{
+ int bit;
+ u32 length, offset;
+
+ /* Decode the match offset. The next range-encoded bit indicates
+ * whether it's a repeat offset or an explicit offset. */
+
+ bit = lzms_range_decode_bit(&ctx->lz_match_range_decoder);
+ if (bit == 0) {
+ /* Explicit offset. */
+ offset = lzms_decode_value(&ctx->lz_offset_decoder);
+ } else {
+ /* Repeat offset. Each subsequent range-decoded bit selects a
+ * progressively older queue entry; decoding stops at the first
+ * 0 bit, leaving i as the index of the chosen entry. */
+ int i;
+
+ for (i = 0; i < LZMS_NUM_RECENT_OFFSETS - 1; i++)
+ if (!lzms_range_decode_bit(&ctx->lz_repeat_match_range_decoders[i]))
+ break;
+
+ offset = ctx->lru.lz.recent_offsets[i];
+
+ /* Remove the used entry by shifting the tail of the queue
+ * down. NOTE(review): the i+1 access requires
+ * recent_offsets[] to have LZMS_NUM_RECENT_OFFSETS + 1
+ * elements -- confirm the array declaration. */
+ for (; i < LZMS_NUM_RECENT_OFFSETS; i++)
+ ctx->lru.lz.recent_offsets[i] = ctx->lru.lz.recent_offsets[i + 1];
+ }
+
+ /* Decode match length, which is always given explicitly (there is no
+ * LRU queue for repeat lengths). */
+ length = lzms_decode_value(&ctx->length_decoder);
+
+ /* Stage the offset; presumably folded into the recent-offsets queue
+ * by lzms_update_lru_queues() after this item (see
+ * lzms_decode_item()). */
+ ctx->lru.lz.upcoming_offset = offset;
+
+ LZMS_DEBUG("Decoded %s LZ match: length=%u, offset=%u",
+ (bit ? "repeat" : "explicit"), length, offset);
+
+ /* Validate the match and copy it to the output. */
+ return lzms_copy_lz_match(ctx, length, offset);
+}
+
+/* Decodes a "delta" match from the input. */
+static int
+lzms_decode_delta_match(struct lzms_decompressor *ctx)
+{
+ int bit;
+ u32 length, power, raw_offset;
+
+ /* Decode the match power and raw offset. The next range-encoded bit
+ * indicates whether these data are a repeat, or given explicitly. */
+
+ bit = lzms_range_decode_bit(&ctx->delta_match_range_decoder);
+ if (bit == 0) {
+ /* Explicit: the power is a plain Huffman symbol, while the
+ * raw offset uses the slot + extra-bits encoding. */
+ power = lzms_huffman_decode_symbol(&ctx->delta_power_decoder);
+ raw_offset = lzms_decode_value(&ctx->delta_offset_decoder);
+ } else {
+ /* Repeat: one range-decoded bit per progressively older LRU
+ * entry; decoding stops at the first 0 bit. */
+ int i;
+
+ for (i = 0; i < LZMS_NUM_RECENT_OFFSETS - 1; i++)
+ if (!lzms_range_decode_bit(&ctx->delta_repeat_match_range_decoders[i]))
+ break;
+
+ power = ctx->lru.delta.recent_powers[i];
+ raw_offset = ctx->lru.delta.recent_offsets[i];
+
+ /* Remove the used entry by shifting the tails of both queues
+ * down. NOTE(review): the i+1 accesses require both arrays to
+ * have LZMS_NUM_RECENT_OFFSETS + 1 elements -- confirm the
+ * declarations. */
+ for (; i < LZMS_NUM_RECENT_OFFSETS; i++) {
+ ctx->lru.delta.recent_powers[i] = ctx->lru.delta.recent_powers[i + 1];
+ ctx->lru.delta.recent_offsets[i] = ctx->lru.delta.recent_offsets[i + 1];
+ }
+ }
+
+ /* The match length is always encoded explicitly. */
+ length = lzms_decode_value(&ctx->length_decoder);
+
+ /* Stage the (power, raw_offset) pair; presumably folded into the LRU
+ * queues by lzms_update_lru_queues() after this item (see
+ * lzms_decode_item()). */
+ ctx->lru.delta.upcoming_power = power;
+ ctx->lru.delta.upcoming_offset = raw_offset;
+
+ LZMS_DEBUG("Decoded %s delta match: length=%u, power=%u, raw_offset=%u",
+ (bit ? "repeat" : "explicit"), length, power, raw_offset);
+
+ /* Validate the match and copy it to the output. */
+ return lzms_copy_delta_match(ctx, length, power, raw_offset);
+}
+
+/* Decode an LZ or delta match.  A single range-decoded bit selects the
+ * match type: 0 => LZ match, 1 => delta match. */
+static int
+lzms_decode_match(struct lzms_decompressor *ctx)
+{
+	if (lzms_range_decode_bit(&ctx->match_range_decoder))
+		return lzms_decode_delta_match(ctx);
+	return lzms_decode_lz_match(ctx);
+}
+
+/* Decode a literal byte encoded using the literal Huffman code. */
+static int
+lzms_decode_literal(struct lzms_decompressor *ctx)
+{
+	u8 byte = lzms_huffman_decode_symbol(&ctx->literal_decoder);
+
+	LZMS_DEBUG("Decoded literal: 0x%02x", byte);
+	return lzms_copy_literal(ctx, byte);
+}
+
+/* Decode the next LZMS match or literal.  Returns 0 on success, nonzero on
+ * invalid input. */
+static int
+lzms_decode_item(struct lzms_decompressor *ctx)
+{
+	int ret;
+
+	/* Clear the staged LRU values; a decoded match sets them again
+	 * before lzms_update_lru_queues() runs below. */
+	ctx->lru.lz.upcoming_offset = 0;
+	ctx->lru.delta.upcoming_power = 0;
+	ctx->lru.delta.upcoming_offset = 0;
+
+	/* One range-decoded bit selects between a match (1) and a literal
+	 * (0). */
+	ret = lzms_range_decode_bit(&ctx->main_range_decoder) ?
+			lzms_decode_match(ctx) : lzms_decode_literal(ctx);
+
+	if (ret == 0)
+		lzms_update_lru_queues(&ctx->lru);
+	return ret;
+}
+
+/* Initialize a range decoder context: bind it to the shared raw range
+ * decoder @rd and reset all @num_states probability entries. */
+static void
+lzms_init_range_decoder(struct lzms_range_decoder *dec,
+			struct lzms_range_decoder_raw *rd, u32 num_states)
+{
+	dec->rd = rd;
+	dec->state = 0;
+	/* NOTE(review): this mask is only meaningful if num_states is a
+	 * power of two -- confirm at the call sites. */
+	dec->mask = num_states - 1;
+
+	/* Give every state the default initial probability entry. */
+	for (u32 s = 0; s < num_states; s++) {
+		struct lzms_probability_entry *e = &dec->prob_entries[s];
+
+		e->num_recent_zero_bits = LZMS_INITIAL_PROBABILITY;
+		e->recent_bits = LZMS_INITIAL_RECENT_BITS;
+	}
+}
+
+/* Initialize an adaptive Huffman decoder over @num_syms symbols, reading
+ * from the bitstream @is and rebuilding its code every @rebuild_freq
+ * symbols.  @slot_base_tab / @extra_bits_tab may be NULL for alphabets that
+ * are not slot-encoded. */
+static void
+lzms_init_huffman_decoder(struct lzms_huffman_decoder *dec,
+			  struct lzms_input_bitstream *is,
+			  const u32 *slot_base_tab,
+			  const u8 *extra_bits_tab,
+			  unsigned num_syms,
+			  unsigned rebuild_freq)
+{
+	dec->is = is;
+	dec->slot_base_tab = slot_base_tab;
+	dec->extra_bits_tab = extra_bits_tab;
+	dec->num_syms = num_syms;
+	dec->rebuild_freq = rebuild_freq;
+
+	/* Setting num_syms_read equal to rebuild_freq forces the code to be
+	 * built before the very first symbol is decoded. */
+	dec->num_syms_read = rebuild_freq;
+
+	/* Every symbol starts out with frequency 1. */
+	for (unsigned sym = 0; sym < num_syms; sym++)
+		dec->sym_freqs[sym] = 1;
+}
+
+/* Prepare to decode items from an LZMS-compressed block.
+ *
+ * Sets up the output pointers, the two input decoders (forward range
+ * decoder and backward Huffman bitstream), one Huffman decoder per
+ * alphabet, one range-decoder context per bit class, and the LRU match
+ * queues. */
+static void
+lzms_init_decompressor(struct lzms_decompressor *ctx,
+ const void *cdata, unsigned clen,
+ void *ubuf, unsigned ulen)
+{
+ unsigned num_offset_slots;
+
+ LZMS_DEBUG("Initializing decompressor (clen=%u, ulen=%u)", clen, ulen);
+
+ /* Initialize output pointers. */
+ ctx->out_begin = ubuf;
+ ctx->out_next = ubuf;
+ ctx->out_end = (u8*)ubuf + ulen;
+
+ /* Initialize the raw range decoder (reading forwards). */
+ lzms_range_decoder_raw_init(&ctx->rd, cdata, clen / 2);
+
+ /* Initialize the input bitstream for Huffman symbols (reading
+ * backwards) */
+ /* NOTE(review): both decoders receive clen / 2 -- presumably a count
+ * of 16-bit units covering the whole buffer; confirm against the two
+ * init functions. */
+ lzms_input_bitstream_init(&ctx->is, cdata, clen / 2);
+
+ /* Calculate the number of offset slots needed for this compressed
+ * block. (The largest possible match offset is ulen - 1.) */
+ num_offset_slots = lzms_get_offset_slot(ulen - 1) + 1;
+
+ LZMS_DEBUG("Using %u offset slots", num_offset_slots);
+
+ /* Initialize Huffman decoders for each alphabet used in the compressed
+ * representation. */
+ lzms_init_huffman_decoder(&ctx->literal_decoder, &ctx->is,
+ NULL, NULL, LZMS_NUM_LITERAL_SYMS,
+ LZMS_LITERAL_CODE_REBUILD_FREQ);
+
+ lzms_init_huffman_decoder(&ctx->lz_offset_decoder, &ctx->is,
+ lzms_offset_slot_base,
+ lzms_extra_offset_bits,
+ num_offset_slots,
+ LZMS_LZ_OFFSET_CODE_REBUILD_FREQ);
+
+ lzms_init_huffman_decoder(&ctx->length_decoder, &ctx->is,
+ lzms_length_slot_base,
+ lzms_extra_length_bits,
+ LZMS_NUM_LEN_SYMS,
+ LZMS_LENGTH_CODE_REBUILD_FREQ);
+
+ lzms_init_huffman_decoder(&ctx->delta_offset_decoder, &ctx->is,
+ lzms_offset_slot_base,
+ lzms_extra_offset_bits,
+ num_offset_slots,
+ LZMS_DELTA_OFFSET_CODE_REBUILD_FREQ);
+
+ lzms_init_huffman_decoder(&ctx->delta_power_decoder, &ctx->is,
+ NULL, NULL, LZMS_NUM_DELTA_POWER_SYMS,
+ LZMS_DELTA_POWER_CODE_REBUILD_FREQ);
+
+
+ /* Initialize range decoders, all of which wrap around the same
+ * lzms_range_decoder_raw. */
+ lzms_init_range_decoder(&ctx->main_range_decoder,
+ &ctx->rd, LZMS_NUM_MAIN_STATES);
+
+ lzms_init_range_decoder(&ctx->match_range_decoder,
+ &ctx->rd, LZMS_NUM_MATCH_STATES);
+
+ lzms_init_range_decoder(&ctx->lz_match_range_decoder,
+ &ctx->rd, LZMS_NUM_LZ_MATCH_STATES);
+
+ for (size_t i = 0; i < ARRAY_LEN(ctx->lz_repeat_match_range_decoders); i++)
+ lzms_init_range_decoder(&ctx->lz_repeat_match_range_decoders[i],
+ &ctx->rd, LZMS_NUM_LZ_REPEAT_MATCH_STATES);
+
+ lzms_init_range_decoder(&ctx->delta_match_range_decoder,
+ &ctx->rd, LZMS_NUM_DELTA_MATCH_STATES);
+
+ for (size_t i = 0; i < ARRAY_LEN(ctx->delta_repeat_match_range_decoders); i++)
+ lzms_init_range_decoder(&ctx->delta_repeat_match_range_decoders[i],
+ &ctx->rd, LZMS_NUM_DELTA_REPEAT_MATCH_STATES);
+
+ /* Initialize LRU match information. */
+ lzms_init_lru_queues(&ctx->lru);
+
+ LZMS_DEBUG("Decompressor successfully initialized");
+}
+
+/* Decode the series of literals and matches from the LZMS-compressed data.
+ * Returns 0 on success; nonzero if the compressed data is invalid. */
+static int
+lzms_decode_items(const u8 *cdata, size_t clen, u8 *ubuf, size_t ulen,
+		  struct lzms_decompressor *ctx)
+{
+	lzms_init_decompressor(ctx, cdata, clen, ubuf, ulen);
+
+	/* Keep decoding items until the output buffer is exactly full. */
+	for (;;) {
+		if (ctx->out_next == ctx->out_end)
+			return 0;
+		LZMS_DEBUG("Position %u", ctx->out_next - ctx->out_begin);
+		if (lzms_decode_item(ctx))
+			return -1;
+	}
+}
+
+/* Top-level decompression entry point: validate the input sizes, decode the
+ * item stream, then undo the x86 preprocessing.  Returns 0 on success or
+ * -1 if the compressed data is invalid. */
+static int
+lzms_decompress(const void *compressed_data, size_t compressed_size,
+		void *uncompressed_data, size_t uncompressed_size, void *_ctx)
+{
+	struct lzms_decompressor *ctx = _ctx;
+
+	/* The range decoder needs at least 4 bytes of compressed data to
+	 * start up. */
+	if (compressed_size < 4) {
+		LZMS_DEBUG("Compressed size too small (got %zu, expected >= 4)",
+			   compressed_size);
+		return -1;
+	}
+
+	/* An LZMS-compressed block is a sequence of 16-bit integers, so its
+	 * size must be even. */
+	if (compressed_size & 1) {
+		LZMS_DEBUG("Compressed size not divisible by 2 (got %zu)",
+			   compressed_size);
+		return -1;
+	}
+
+	/* Trivial case: an empty output needs no decoding.  (Necessary
+	 * because a window of size 0 does not have a valid offset slot.) */
+	if (uncompressed_size == 0)
+		return 0;
+
+	/* Decode the stream of literals and matches. */
+	if (lzms_decode_items(compressed_data, compressed_size,
+			      uncompressed_data, uncompressed_size, ctx))
+		return -1;
+
+	/* Undo the x86 machine-code preprocessing step. */
+	lzms_x86_filter(uncompressed_data, uncompressed_size,
+			ctx->last_target_usages, true);
+
+	LZMS_DEBUG("Decompression successful.");
+	return 0;
+}
+
+/* Release a decompressor context previously returned through
+ * lzms_create_decompressor().  The context was obtained with
+ * ALIGNED_MALLOC(), so it must be released with ALIGNED_FREE(). */
+static void
+lzms_free_decompressor(void *_ctx)
+{
+ struct lzms_decompressor *ctx = _ctx;
+
+ ALIGNED_FREE(ctx);
+}
+
+/* Allocate and initialize an LZMS decompressor context for blocks of up to
+ * @max_block_size uncompressed bytes.  On success, stores the context in
+ * *@ctx_ret and returns 0; otherwise returns a WIMLIB_ERR_* code. */
+static int
+lzms_create_decompressor(size_t max_block_size, void **ctx_ret)
+{
+	struct lzms_decompressor *ctx;
+
+	/* The x86 post-processor requires that the uncompressed length fit
+	 * into a signed 32-bit integer.  Also, the offset slot table cannot
+	 * be searched for an offset of INT32_MAX or greater. */
+	if (max_block_size >= INT32_MAX)
+		return WIMLIB_ERR_INVALID_PARAM;
+
+	/* The embedded decode tables require aligned storage. */
+	ctx = ALIGNED_MALLOC(sizeof *ctx, DECODE_TABLE_ALIGNMENT);
+	if (!ctx)
+		return WIMLIB_ERR_NOMEM;
+
+	/* Initialize offset and length slot data if not done already. */
+	lzms_init_slots();
+
+	*ctx_ret = ctx;
+	return 0;
+}
+
+/* Decompressor operations vtable for the LZMS compression format. */
+const struct decompressor_ops lzms_decompressor_ops = {
+ .create_decompressor = lzms_create_decompressor,
+ .decompress = lzms_decompress,
+ .free_decompressor = lzms_free_decompressor,
+};