]> wimlib.net Git - wimlib/blob - src/lzms-compress.c
72aef31b14fda69dd530bb7c5bc8e1cfaeccd92b
[wimlib] / src / lzms-compress.c
1 /*
2  * lzms-compress.c
3  */
4
5 /*
6  * Copyright (C) 2013 Eric Biggers
7  *
8  * This file is part of wimlib, a library for working with WIM files.
9  *
10  * wimlib is free software; you can redistribute it and/or modify it under the
11  * terms of the GNU General Public License as published by the Free
12  * Software Foundation; either version 3 of the License, or (at your option)
13  * any later version.
14  *
15  * wimlib is distributed in the hope that it will be useful, but WITHOUT ANY
16  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
17  * A PARTICULAR PURPOSE. See the GNU General Public License for more
18  * details.
19  *
20  * You should have received a copy of the GNU General Public License
21  * along with wimlib; if not, see http://www.gnu.org/licenses/.
22  */
23
24 /* This a compressor for the LZMS compression format.  More details about this
25  * format can be found in lzms-decompress.c.
26  *
27  * This is currently an unsophisticated implementation that is fast but does not
28  * attain the best compression ratios allowed by the format.
29  */
30
31 #ifdef HAVE_CONFIG_H
32 #  include "config.h"
33 #endif
34
35 #include "wimlib.h"
36 #include "wimlib/compiler.h"
37 #include "wimlib/compressor_ops.h"
38 #include "wimlib/compress_common.h"
39 #include "wimlib/endianness.h"
40 #include "wimlib/error.h"
41 #include "wimlib/lz_hash.h"
42 #include "wimlib/lz_sarray.h"
43 #include "wimlib/lzms.h"
44 #include "wimlib/util.h"
45
46 #include <string.h>
47 #include <limits.h>
48 #include <pthread.h>
49
50 #define LZMS_OPTIM_ARRAY_SIZE   1024
51
52 struct lzms_compressor;
53 struct lzms_adaptive_state {
54         struct lzms_lz_lru_queues lru;
55         u8 main_state;
56         u8 match_state;
57         u8 lz_match_state;
58         u8 lz_repeat_match_state[LZMS_NUM_RECENT_OFFSETS - 1];
59 };
60 #define LZ_ADAPTIVE_STATE struct lzms_adaptive_state
61 #define LZ_COMPRESSOR     struct lzms_compressor
62 #include "wimlib/lz_optimal.h"
63
64 /* Stucture used for writing raw bits to the end of the LZMS-compressed data as
65  * a series of 16-bit little endian coding units.  */
66 struct lzms_output_bitstream {
67         /* Buffer variable containing zero or more bits that have been logically
68          * written to the bitstream but not yet written to memory.  This must be
69          * at least as large as the coding unit size.  */
70         u16 bitbuf;
71
72         /* Number of bits in @bitbuf that are valid.  */
73         unsigned num_free_bits;
74
75         /* Pointer to one past the next position in the compressed data buffer
76          * at which to output a 16-bit coding unit.  */
77         le16 *out;
78
79         /* Maximum number of 16-bit coding units that can still be output to
80          * the compressed data buffer.  */
81         size_t num_le16_remaining;
82
83         /* Set to %true if not all coding units could be output due to
84          * insufficient space.  */
85         bool overrun;
86 };
87
88 /* Stucture used for range encoding (raw version).  */
89 struct lzms_range_encoder_raw {
90
91         /* A 33-bit variable that holds the low boundary of the current range.
92          * The 33rd bit is needed to catch carries.  */
93         u64 low;
94
95         /* Size of the current range.  */
96         u32 range;
97
98         /* Next 16-bit coding unit to output.  */
99         u16 cache;
100
101         /* Number of 16-bit coding units whose output has been delayed due to
102          * possible carrying.  The first such coding unit is @cache; all
103          * subsequent such coding units are 0xffff.  */
104         u32 cache_size;
105
106         /* Pointer to the next position in the compressed data buffer at which
107          * to output a 16-bit coding unit.  */
108         le16 *out;
109
110         /* Maximum number of 16-bit coding units that can still be output to
111          * the compressed data buffer.  */
112         size_t num_le16_remaining;
113
114         /* %true when the very first coding unit has not yet been output.  */
115         bool first;
116
117         /* Set to %true if not all coding units could be output due to
118          * insufficient space.  */
119         bool overrun;
120 };
121
122 /* Structure used for range encoding.  This wraps around `struct
123  * lzms_range_encoder_raw' to use and maintain probability entries.  */
124 struct lzms_range_encoder {
125         /* Pointer to the raw range encoder, which has no persistent knowledge
126          * of probabilities.  Multiple lzms_range_encoder's share the same
127          * lzms_range_encoder_raw.  */
128         struct lzms_range_encoder_raw *rc;
129
130         /* Bits recently encoded by this range encoder.  This are used as in
131          * index into @prob_entries.  */
132         u32 state;
133
134         /* Bitmask for @state to prevent its value from exceeding the number of
135          * probability entries.  */
136         u32 mask;
137
138         /* Probability entries being used for this range encoder.  */
139         struct lzms_probability_entry prob_entries[LZMS_MAX_NUM_STATES];
140 };
141
142 /* Structure used for Huffman encoding.  */
143 struct lzms_huffman_encoder {
144
145         /* Bitstream to write Huffman-encoded symbols and verbatim bits to.
146          * Multiple lzms_huffman_encoder's share the same lzms_output_bitstream.
147          */
148         struct lzms_output_bitstream *os;
149
150         /* Number of symbols that have been written using this code far.  Reset
151          * to 0 whenever the code is rebuilt.  */
152         u32 num_syms_written;
153
154         /* When @num_syms_written reaches this number, the Huffman code must be
155          * rebuilt.  */
156         u32 rebuild_freq;
157
158         /* Number of symbols in the represented Huffman code.  */
159         unsigned num_syms;
160
161         /* Running totals of symbol frequencies.  These are diluted slightly
162          * whenever the code is rebuilt.  */
163         u32 sym_freqs[LZMS_MAX_NUM_SYMS];
164
165         /* The length, in bits, of each symbol in the Huffman code.  */
166         u8 lens[LZMS_MAX_NUM_SYMS];
167
168         /* The codeword of each symbol in the Huffman code.  */
169         u16 codewords[LZMS_MAX_NUM_SYMS];
170 };
171
172 /* State of the LZMS compressor.  */
173 struct lzms_compressor {
174         /* Pointer to a buffer holding the preprocessed data to compress.  */
175         u8 *window;
176
177         /* Current position in @buffer.  */
178         u32 cur_window_pos;
179
180         /* Size of the data in @buffer.  */
181         u32 window_size;
182
183 #if 0
184         /* Temporary array used by lz_analyze_block(); must be at least as long
185          * as the window.  */
186         u32 *prev_tab;
187 #endif
188
189         /* Suffix array match-finder.  */
190         struct lz_sarray lz_sarray;
191
192         /* Temporary space to store found matches.  */
193         struct raw_match *matches;
194
195         /* Match-chooser.  */
196         struct lz_match_chooser mc;
197
198         /* Maximum block size this compressor instantiation allows.  This is the
199          * allocated size of @window.  */
200         u32 max_block_size;
201
202         /* Raw range encoder which outputs to the beginning of the compressed
203          * data buffer, proceeding forwards.  */
204         struct lzms_range_encoder_raw rc;
205
206         /* Bitstream which outputs to the end of the compressed data buffer,
207          * proceeding backwards.  */
208         struct lzms_output_bitstream os;
209
210         /* Range encoders.  */
211         struct lzms_range_encoder main_range_encoder;
212         struct lzms_range_encoder match_range_encoder;
213         struct lzms_range_encoder lz_match_range_encoder;
214         struct lzms_range_encoder lz_repeat_match_range_encoders[LZMS_NUM_RECENT_OFFSETS - 1];
215         struct lzms_range_encoder delta_match_range_encoder;
216         struct lzms_range_encoder delta_repeat_match_range_encoders[LZMS_NUM_RECENT_OFFSETS - 1];
217
218         /* Huffman encoders.  */
219         struct lzms_huffman_encoder literal_encoder;
220         struct lzms_huffman_encoder lz_offset_encoder;
221         struct lzms_huffman_encoder length_encoder;
222         struct lzms_huffman_encoder delta_power_encoder;
223         struct lzms_huffman_encoder delta_offset_encoder;
224
225         /* LRU (least-recently-used) queues for match information.  */
226         struct lzms_lru_queues lru;
227
228         /* Used for preprocessing.  */
229         s32 last_target_usages[65536];
230 };
231
232 /* Initialize the output bitstream @os to write forwards to the specified
233  * compressed data buffer @out that is @out_limit 16-bit integers long.  */
234 static void
235 lzms_output_bitstream_init(struct lzms_output_bitstream *os,
236                            le16 *out, size_t out_limit)
237 {
238         os->bitbuf = 0;
239         os->num_free_bits = 16;
240         os->out = out + out_limit;
241         os->num_le16_remaining = out_limit;
242         os->overrun = false;
243 }
244
245 /* Write @num_bits bits, contained in the low @num_bits bits of @bits (ordered
246  * from high-order to low-order), to the output bitstream @os.  */
247 static void
248 lzms_output_bitstream_put_bits(struct lzms_output_bitstream *os,
249                                u32 bits, unsigned num_bits)
250 {
251         bits &= (1U << num_bits) - 1;
252
253         while (num_bits > os->num_free_bits) {
254
255                 if (unlikely(os->num_le16_remaining == 0)) {
256                         os->overrun = true;
257                         return;
258                 }
259
260                 unsigned num_fill_bits = os->num_free_bits;
261
262                 os->bitbuf <<= num_fill_bits;
263                 os->bitbuf |= bits >> (num_bits - num_fill_bits);
264
265                 *--os->out = cpu_to_le16(os->bitbuf);
266                 --os->num_le16_remaining;
267
268                 os->num_free_bits = 16;
269                 num_bits -= num_fill_bits;
270                 bits &= (1U << num_bits) - 1;
271         }
272         os->bitbuf <<= num_bits;
273         os->bitbuf |= bits;
274         os->num_free_bits -= num_bits;
275 }
276
277 /* Flush the output bitstream, ensuring that all bits written to it have been
278  * written to memory.  Returns %true if all bits were output successfully, or
279  * %false if an overrun occurred.  */
280 static bool
281 lzms_output_bitstream_flush(struct lzms_output_bitstream *os)
282 {
283         if (os->num_free_bits != 16)
284                 lzms_output_bitstream_put_bits(os, 0, os->num_free_bits + 1);
285         return !os->overrun;
286 }
287
288 /* Initialize the range encoder @rc to write forwards to the specified
289  * compressed data buffer @out that is @out_limit 16-bit integers long.  */
290 static void
291 lzms_range_encoder_raw_init(struct lzms_range_encoder_raw *rc,
292                             le16 *out, size_t out_limit)
293 {
294         rc->low = 0;
295         rc->range = 0xffffffff;
296         rc->cache = 0;
297         rc->cache_size = 1;
298         rc->out = out;
299         rc->num_le16_remaining = out_limit;
300         rc->first = true;
301         rc->overrun = false;
302 }
303
304 /*
305  * Attempt to flush bits from the range encoder.
306  *
307  * Note: this is based on the public domain code for LZMA written by Igor
308  * Pavlov.  The only differences in this function are that in LZMS the bits must
309  * be output in 16-bit coding units instead of 8-bit coding units, and that in
310  * LZMS the first coding unit is not ignored by the decompressor, so the encoder
311  * cannot output a dummy value to that position.
312  *
313  * The basic idea is that we're writing bits from @rc->low to the output.
314  * However, due to carrying, the writing of coding units with value 0xffff, as
315  * well as one prior coding unit, must be delayed until it is determined whether
316  * a carry is needed.
317  */
318 static void
319 lzms_range_encoder_raw_shift_low(struct lzms_range_encoder_raw *rc)
320 {
321         LZMS_DEBUG("low=%"PRIx64", cache=%"PRIx64", cache_size=%u",
322                    rc->low, rc->cache, rc->cache_size);
323         if ((u32)(rc->low) < 0xffff0000 ||
324             (u32)(rc->low >> 32) != 0)
325         {
326                 /* Carry not needed (rc->low < 0xffff0000), or carry occurred
327                  * ((rc->low >> 32) != 0, a.k.a. the carry bit is 1).  */
328                 do {
329                         if (!rc->first) {
330                                 if (rc->num_le16_remaining == 0) {
331                                         rc->overrun = true;
332                                         return;
333                                 }
334                                 *rc->out++ = cpu_to_le16(rc->cache +
335                                                          (u16)(rc->low >> 32));
336                                 --rc->num_le16_remaining;
337                         } else {
338                                 rc->first = false;
339                         }
340
341                         rc->cache = 0xffff;
342                 } while (--rc->cache_size != 0);
343
344                 rc->cache = (rc->low >> 16) & 0xffff;
345         }
346         ++rc->cache_size;
347         rc->low = (rc->low & 0xffff) << 16;
348 }
349
350 static void
351 lzms_range_encoder_raw_normalize(struct lzms_range_encoder_raw *rc)
352 {
353         if (rc->range <= 0xffff) {
354                 rc->range <<= 16;
355                 lzms_range_encoder_raw_shift_low(rc);
356         }
357 }
358
359 static bool
360 lzms_range_encoder_raw_flush(struct lzms_range_encoder_raw *rc)
361 {
362         for (unsigned i = 0; i < 4; i++)
363                 lzms_range_encoder_raw_shift_low(rc);
364         return !rc->overrun;
365 }
366
367 /* Encode the next bit using the range encoder (raw version).
368  *
369  * @prob is the chance out of LZMS_PROBABILITY_MAX that the next bit is 0.  */
370 static void
371 lzms_range_encoder_raw_encode_bit(struct lzms_range_encoder_raw *rc, int bit,
372                                   u32 prob)
373 {
374         lzms_range_encoder_raw_normalize(rc);
375
376         u32 bound = (rc->range >> LZMS_PROBABILITY_BITS) * prob;
377         if (bit == 0) {
378                 rc->range = bound;
379         } else {
380                 rc->low += bound;
381                 rc->range -= bound;
382         }
383 }
384
385 /* Encode a bit using the specified range encoder. This wraps around
386  * lzms_range_encoder_raw_encode_bit() to handle using and updating the
387  * appropriate probability table.  */
388 static void
389 lzms_range_encode_bit(struct lzms_range_encoder *enc, int bit)
390 {
391         struct lzms_probability_entry *prob_entry;
392         u32 prob;
393
394         /* Load the probability entry corresponding to the current state.  */
395         prob_entry = &enc->prob_entries[enc->state];
396
397         /* Treat the number of zero bits in the most recently encoded
398          * LZMS_PROBABILITY_MAX bits with this probability entry as the chance,
399          * out of LZMS_PROBABILITY_MAX, that the next bit will be a 0.  However,
400          * don't allow 0% or 100% probabilities.  */
401         prob = prob_entry->num_recent_zero_bits;
402         if (prob == 0)
403                 prob = 1;
404         else if (prob == LZMS_PROBABILITY_MAX)
405                 prob = LZMS_PROBABILITY_MAX - 1;
406
407         /* Encode the next bit.  */
408         lzms_range_encoder_raw_encode_bit(enc->rc, bit, prob);
409
410         /* Update the state based on the newly encoded bit.  */
411         enc->state = ((enc->state << 1) | bit) & enc->mask;
412
413         /* Update the recent bits, including the cached count of 0's.  */
414         BUILD_BUG_ON(LZMS_PROBABILITY_MAX > sizeof(prob_entry->recent_bits) * 8);
415         if (bit == 0) {
416                 if (prob_entry->recent_bits & (1ULL << (LZMS_PROBABILITY_MAX - 1))) {
417                         /* Replacing 1 bit with 0 bit; increment the zero count.
418                          */
419                         prob_entry->num_recent_zero_bits++;
420                 }
421         } else {
422                 if (!(prob_entry->recent_bits & (1ULL << (LZMS_PROBABILITY_MAX - 1)))) {
423                         /* Replacing 0 bit with 1 bit; decrement the zero count.
424                          */
425                         prob_entry->num_recent_zero_bits--;
426                 }
427         }
428         prob_entry->recent_bits = (prob_entry->recent_bits << 1) | bit;
429 }
430
431 /* Encode a symbol using the specified Huffman encoder.  */
432 static void
433 lzms_huffman_encode_symbol(struct lzms_huffman_encoder *enc, u32 sym)
434 {
435         LZMS_ASSERT(sym < enc->num_syms);
436         lzms_output_bitstream_put_bits(enc->os,
437                                        enc->codewords[sym],
438                                        enc->lens[sym]);
439         ++enc->sym_freqs[sym];
440         if (++enc->num_syms_written == enc->rebuild_freq) {
441                 /* Adaptive code needs to be rebuilt.  */
442                 LZMS_DEBUG("Rebuilding code (num_syms=%u)", enc->num_syms);
443                 make_canonical_huffman_code(enc->num_syms,
444                                             LZMS_MAX_CODEWORD_LEN,
445                                             enc->sym_freqs,
446                                             enc->lens,
447                                             enc->codewords);
448
449                 /* Dilute the frequencies.  */
450                 for (unsigned i = 0; i < enc->num_syms; i++) {
451                         enc->sym_freqs[i] >>= 1;
452                         enc->sym_freqs[i] += 1;
453                 }
454                 enc->num_syms_written = 0;
455         }
456 }
457
458 static void
459 lzms_encode_length(struct lzms_huffman_encoder *enc, u32 length)
460 {
461         unsigned slot;
462         unsigned num_extra_bits;
463         u32 extra_bits;
464
465         slot = lzms_get_length_slot(length);
466
467         num_extra_bits = lzms_extra_length_bits[slot];
468
469         extra_bits = length - lzms_length_slot_base[slot];
470
471         lzms_huffman_encode_symbol(enc, slot);
472         lzms_output_bitstream_put_bits(enc->os, extra_bits, num_extra_bits);
473 }
474
475 static void
476 lzms_encode_offset(struct lzms_huffman_encoder *enc, u32 offset)
477 {
478         unsigned slot;
479         unsigned num_extra_bits;
480         u32 extra_bits;
481
482         slot = lzms_get_position_slot(offset);
483
484         num_extra_bits = lzms_extra_position_bits[slot];
485
486         extra_bits = offset - lzms_position_slot_base[slot];
487
488         lzms_huffman_encode_symbol(enc, slot);
489         lzms_output_bitstream_put_bits(enc->os, extra_bits, num_extra_bits);
490 }
491
492 static void
493 lzms_begin_encode_item(struct lzms_compressor *ctx)
494 {
495         ctx->lru.lz.upcoming_offset = 0;
496         ctx->lru.delta.upcoming_offset = 0;
497         ctx->lru.delta.upcoming_power = 0;
498 }
499
500 static void
501 lzms_end_encode_item(struct lzms_compressor *ctx, u32 length)
502 {
503         LZMS_ASSERT(ctx->window_size - ctx->cur_window_pos >= length);
504         ctx->cur_window_pos += length;
505         lzms_update_lru_queues(&ctx->lru);
506 }
507
508 /* Encode a literal byte.  */
509 static void
510 lzms_encode_literal(struct lzms_compressor *ctx, u8 literal)
511 {
512         LZMS_DEBUG("Position %u: Encoding literal 0x%02x ('%c')",
513                    ctx->cur_window_pos, literal, literal);
514
515         lzms_begin_encode_item(ctx);
516
517         /* Main bit: 0 = a literal, not a match.  */
518         lzms_range_encode_bit(&ctx->main_range_encoder, 0);
519
520         /* Encode the literal using the current literal Huffman code.  */
521         lzms_huffman_encode_symbol(&ctx->literal_encoder, literal);
522
523         lzms_end_encode_item(ctx, 1);
524 }
525
526 /* Encode a (length, offset) pair (LZ match).  */
527 static void
528 lzms_encode_lz_match(struct lzms_compressor *ctx, u32 length, u32 offset)
529 {
530         int recent_offset_idx;
531
532         LZMS_DEBUG("Position %u: Encoding LZ match {length=%u, offset=%u}",
533                    ctx->cur_window_pos, length, offset);
534
535         LZMS_ASSERT(length <= ctx->window_size - ctx->cur_window_pos);
536         LZMS_ASSERT(offset <= ctx->cur_window_pos);
537         LZMS_ASSERT(!memcmp(&ctx->window[ctx->cur_window_pos],
538                             &ctx->window[ctx->cur_window_pos - offset],
539                             length));
540
541         lzms_begin_encode_item(ctx);
542
543         /* Main bit: 1 = a match, not a literal.  */
544         lzms_range_encode_bit(&ctx->main_range_encoder, 1);
545
546         /* Match bit: 0 = a LZ match, not a delta match.  */
547         lzms_range_encode_bit(&ctx->match_range_encoder, 0);
548
549         /* Determine if the offset can be represented as a recent offset.  */
550         for (recent_offset_idx = 0;
551              recent_offset_idx < LZMS_NUM_RECENT_OFFSETS;
552              recent_offset_idx++)
553                 if (offset == ctx->lru.lz.recent_offsets[recent_offset_idx])
554                         break;
555
556         if (recent_offset_idx == LZMS_NUM_RECENT_OFFSETS) {
557                 /* Explicit offset.  */
558
559                 /* LZ match bit: 0 = explicit offset, not a recent offset.  */
560                 lzms_range_encode_bit(&ctx->lz_match_range_encoder, 0);
561
562                 /* Encode the match offset.  */
563                 lzms_encode_offset(&ctx->lz_offset_encoder, offset);
564         } else {
565                 int i;
566
567                 /* Recent offset.  */
568
569                 /* LZ match bit: 1 = recent offset, not an explicit offset.  */
570                 lzms_range_encode_bit(&ctx->lz_match_range_encoder, 1);
571
572                 /* Encode the recent offset index.  A 1 bit is encoded for each
573                  * index passed up.  This sequence of 1 bits is terminated by a
574                  * 0 bit, or automatically when (LZMS_NUM_RECENT_OFFSETS - 1) 1
575                  * bits have been encoded.  */
576                 for (i = 0; i < recent_offset_idx; i++)
577                         lzms_range_encode_bit(&ctx->lz_repeat_match_range_encoders[i], 1);
578
579                 if (i < LZMS_NUM_RECENT_OFFSETS - 1)
580                         lzms_range_encode_bit(&ctx->lz_repeat_match_range_encoders[i], 0);
581
582                 /* Initial update of the LZ match offset LRU queue.  */
583                 for (; i < LZMS_NUM_RECENT_OFFSETS; i++)
584                         ctx->lru.lz.recent_offsets[i] = ctx->lru.lz.recent_offsets[i + 1];
585         }
586
587         /* Encode the match length.  */
588         lzms_encode_length(&ctx->length_encoder, length);
589
590         /* Save the match offset for later insertion at the front of the LZ
591          * match offset LRU queue.  */
592         ctx->lru.lz.upcoming_offset = offset;
593
594         lzms_end_encode_item(ctx, length);
595 }
596
597 #if 0
598 static void
599 lzms_record_literal(u8 literal, void *_ctx)
600 {
601         struct lzms_compressor *ctx = _ctx;
602
603         lzms_encode_literal(ctx, literal);
604 }
605
606 static void
607 lzms_record_match(unsigned length, unsigned offset, void *_ctx)
608 {
609         struct lzms_compressor *ctx = _ctx;
610
611         lzms_encode_lz_match(ctx, length, offset);
612 }
613
614 static void
615 lzms_fast_encode(struct lzms_compressor *ctx)
616 {
617         static const struct lz_params lzms_lz_params = {
618                 .min_match      = 3,
619                 .max_match      = UINT_MAX,
620                 .max_offset     = UINT_MAX,
621                 .nice_match     = 64,
622                 .good_match     = 32,
623                 .max_chain_len  = 64,
624                 .max_lazy_match = 258,
625                 .too_far        = 4096,
626         };
627
628         lz_analyze_block(ctx->window,
629                          ctx->window_size,
630                          lzms_record_match,
631                          lzms_record_literal,
632                          ctx,
633                          &lzms_lz_params,
634                          ctx->prev_tab);
635
636 }
637 #endif
638
639 /* Fast heuristic cost evaluation to use in the inner loop of the match-finder.
640  * Unlike lzms_get_lz_match_cost(), which does a true cost evaluation, this
641  * simply prioritize matches based on their offset.  */
642 static input_idx_t
643 lzms_lz_match_cost_fast(input_idx_t length, input_idx_t offset, const void *_lru)
644 {
645         const struct lzms_lz_lru_queues *lru = _lru;
646
647         for (input_idx_t i = 0; i < LZMS_NUM_RECENT_OFFSETS; i++)
648                 if (offset == lru->recent_offsets[i])
649                         return i;
650
651         return offset;
652 }
653
654 #define LZMS_COST_SHIFT 5
655
656 /*#define LZMS_RC_COSTS_USE_FLOATING_POINT*/
657
658 static u32
659 lzms_rc_costs[LZMS_PROBABILITY_MAX + 1];
660
661 #ifdef LZMS_RC_COSTS_USE_FLOATING_POINT
662 #  include <math.h>
663 #endif
664
665 static void
666 lzms_do_init_rc_costs(void)
667 {
668         /* Fill in a table that maps range coding probabilities needed to code a
669          * bit X (0 or 1) to the number of bits (scaled by a constant factor, to
670          * handle fractional costs) needed to code that bit X.
671          *
672          * Consider the range of the range decoder.  To eliminate exactly half
673          * the range (logical probability of 0.5), we need exactly 1 bit.  For
674          * lower probabilities we need more bits and for higher probabilities we
675          * need fewer bits.  In general, a logical probability of N will
676          * eliminate the proportion 1 - N of the range; this information takes
677          * log2(1 / N) bits to encode.
678          *
679          * The below loop is simply calculating this number of bits for each
680          * possible probability allowed by the LZMS compression format, but
681          * without using real numbers.  To handle fractional probabilities, each
682          * cost is multiplied by (1 << LZMS_COST_SHIFT).  These techniques are
683          * based on those used by LZMA.
684          *
685          * Note that in LZMS, a probability x really means x / 64, and 0 / 64 is
686          * really interpreted as 1 / 64 and 64 / 64 is really interpreted as
687          * 63 / 64.
688          */
689         for (u32 i = 0; i <= LZMS_PROBABILITY_MAX; i++) {
690                 u32 prob = i;
691
692                 if (prob == 0)
693                         prob = 1;
694                 else if (prob == LZMS_PROBABILITY_MAX)
695                         prob = LZMS_PROBABILITY_MAX - 1;
696
697         #ifdef LZMS_RC_COSTS_USE_FLOATING_POINT
698                 lzms_rc_costs[i] = log2((double)LZMS_PROBABILITY_MAX / prob) *
699                                         (1 << LZMS_COST_SHIFT);
700         #else
701                 u32 w = prob;
702                 u32 bit_count = 0;
703                 for (u32 j = 0; j < LZMS_COST_SHIFT; j++) {
704                         w *= w;
705                         bit_count <<= 1;
706                         while (w >= (1U << 16)) {
707                                 w >>= 1;
708                                 ++bit_count;
709                         }
710                 }
711                 lzms_rc_costs[i] = (LZMS_PROBABILITY_BITS << LZMS_COST_SHIFT) -
712                                    (15 + bit_count);
713         #endif
714         }
715 }
716
717 static void
718 lzms_init_rc_costs(void)
719 {
720         static bool done = false;
721         static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
722
723         if (unlikely(!done)) {
724                 pthread_mutex_lock(&mutex);
725                 if (!done) {
726                         lzms_do_init_rc_costs();
727                         done = true;
728                 }
729                 pthread_mutex_unlock(&mutex);
730         }
731 }
732
733 /*
734  * Return the cost to range-encode the specified bit when in the specified
735  * state.
736  *
737  * @enc         The range encoder to use.
738  * @cur_state   Current state, which indicates the probability entry to choose.
739  *              Updated by this function.
740  * @bit         The bit to encode (0 or 1).
741  */
742 static u32
743 lzms_rc_bit_cost(const struct lzms_range_encoder *enc, u8 *cur_state, int bit)
744 {
745         u32 prob_zero;
746         u32 prob_correct;
747
748         prob_zero = enc->prob_entries[*cur_state & enc->mask].num_recent_zero_bits;
749
750         *cur_state = (*cur_state << 1) | bit;
751
752         if (bit == 0)
753                 prob_correct = prob_zero;
754         else
755                 prob_correct = LZMS_PROBABILITY_MAX - prob_zero;
756
757         return lzms_rc_costs[prob_correct];
758 }
759
760 static u32
761 lzms_huffman_symbol_cost(const struct lzms_huffman_encoder *enc, u32 sym)
762 {
763         return enc->lens[sym] << LZMS_COST_SHIFT;
764 }
765
766 static u32
767 lzms_offset_cost(const struct lzms_huffman_encoder *enc, u32 offset)
768 {
769         u32 slot;
770         u32 num_extra_bits;
771         u32 cost = 0;
772
773         slot = lzms_get_position_slot(offset);
774
775         cost += lzms_huffman_symbol_cost(enc, slot);
776
777         num_extra_bits = lzms_extra_position_bits[slot];
778
779         cost += num_extra_bits << LZMS_COST_SHIFT;
780
781         return cost;
782 }
783
784 static u32
785 lzms_length_cost(const struct lzms_huffman_encoder *enc, u32 length)
786 {
787         u32 slot;
788         u32 num_extra_bits;
789         u32 cost = 0;
790
791         slot = lzms_get_length_slot(length);
792
793         cost += lzms_huffman_symbol_cost(enc, slot);
794
795         num_extra_bits = lzms_extra_length_bits[slot];
796
797         cost += num_extra_bits << LZMS_COST_SHIFT;
798
799         return cost;
800 }
801
802 static u32
803 lzms_get_matches(struct lzms_compressor *ctx,
804                  const struct lzms_adaptive_state *state,
805                  struct raw_match **matches_ret)
806 {
807         *matches_ret = ctx->matches;
808         return lz_sarray_get_matches(&ctx->lz_sarray,
809                                      ctx->matches,
810                                      lzms_lz_match_cost_fast,
811                                      &state->lru);
812 }
813
814 static void
815 lzms_skip_bytes(struct lzms_compressor *ctx, input_idx_t n)
816 {
817         while (n--)
818                 lz_sarray_skip_position(&ctx->lz_sarray);
819 }
820
821 static u32
822 lzms_get_prev_literal_cost(struct lzms_compressor *ctx,
823                            struct lzms_adaptive_state *state)
824 {
825         u8 literal = ctx->window[lz_sarray_get_pos(&ctx->lz_sarray) - 1];
826         u32 cost = 0;
827
828         state->lru.upcoming_offset = 0;
829         lzms_update_lz_lru_queues(&state->lru);
830
831         cost += lzms_rc_bit_cost(&ctx->main_range_encoder,
832                                  &state->main_state, 0);
833
834         cost += lzms_huffman_symbol_cost(&ctx->literal_encoder, literal);
835
836         return cost;
837 }
838
839 static u32
840 lzms_get_lz_match_cost(struct lzms_compressor *ctx,
841                        struct lzms_adaptive_state *state,
842                        input_idx_t length, input_idx_t offset)
843 {
844         u32 cost = 0;
845         int recent_offset_idx;
846
847         cost += lzms_rc_bit_cost(&ctx->main_range_encoder,
848                                  &state->main_state, 1);
849         cost += lzms_rc_bit_cost(&ctx->match_range_encoder,
850                                  &state->match_state, 0);
851
852         for (recent_offset_idx = 0;
853              recent_offset_idx < LZMS_NUM_RECENT_OFFSETS;
854              recent_offset_idx++)
855                 if (offset == state->lru.recent_offsets[recent_offset_idx])
856                         break;
857
858         if (recent_offset_idx == LZMS_NUM_RECENT_OFFSETS) {
859                 /* Explicit offset.  */
860                 cost += lzms_rc_bit_cost(&ctx->lz_match_range_encoder,
861                                          &state->lz_match_state, 0);
862
863                 cost += lzms_offset_cost(&ctx->lz_offset_encoder, offset);
864         } else {
865                 int i;
866
867                 /* Recent offset.  */
868                 cost += lzms_rc_bit_cost(&ctx->lz_match_range_encoder,
869                                          &state->lz_match_state, 1);
870
871                 for (i = 0; i < recent_offset_idx; i++)
872                         cost += lzms_rc_bit_cost(&ctx->lz_repeat_match_range_encoders[i],
873                                                  &state->lz_repeat_match_state[i], 0);
874
875                 if (i < LZMS_NUM_RECENT_OFFSETS - 1)
876                         cost += lzms_rc_bit_cost(&ctx->lz_repeat_match_range_encoders[i],
877                                                  &state->lz_repeat_match_state[i], 1);
878
879
880                 /* Initial update of the LZ match offset LRU queue.  */
881                 for (; i < LZMS_NUM_RECENT_OFFSETS; i++)
882                         state->lru.recent_offsets[i] = state->lru.recent_offsets[i + 1];
883         }
884
885         cost += lzms_length_cost(&ctx->length_encoder, length);
886
887         state->lru.upcoming_offset = offset;
888         lzms_update_lz_lru_queues(&state->lru);
889
890         return cost;
891 }
892
893 static struct raw_match
894 lzms_get_near_optimal_match(struct lzms_compressor *ctx)
895 {
896         struct lzms_adaptive_state initial_state;
897
898         initial_state.lru = ctx->lru.lz;
899         initial_state.main_state = ctx->main_range_encoder.state;
900         initial_state.match_state = ctx->match_range_encoder.state;
901         initial_state.lz_match_state = ctx->lz_match_range_encoder.state;
902         for (int i = 0; i < LZMS_NUM_RECENT_OFFSETS - 1; i++)
903                 initial_state.lz_repeat_match_state[i] =
904                         ctx->lz_repeat_match_range_encoders[i].state;
905         return lz_get_near_optimal_match(&ctx->mc,
906                                          lzms_get_matches,
907                                          lzms_skip_bytes,
908                                          lzms_get_prev_literal_cost,
909                                          lzms_get_lz_match_cost,
910                                          ctx,
911                                          &initial_state);
912 }
913
914 /*
915  * The main loop for the LZMS compressor.
916  *
917  * Notes:
918  *
919  * - This uses near-optimal LZ parsing backed by a suffix-array match-finder.
920  *   More details can be found in the corresponding files (lz_optimal.h,
921  *   lz_sarray.{h,c}).
922  *
923  * - This does not output any delta matches.  It would take a specialized
924  *   algorithm to find them, then more code in lz_optimal.h and here to handle
925  *   evaluating and outputting them.
926  *
927  * - The costs of literals and matches are estimated using the range encoder
928  *   states and the semi-adaptive Huffman codes.  Except for range encoding
929  *   states, costs are assumed to be constant throughout a single run of the
930  *   parsing algorithm, which can parse up to LZMS_OPTIM_ARRAY_SIZE bytes of
931  *   data.  This introduces a source of inaccuracy because the probabilities and
932  *   Huffman codes can change over this part of the data.
933  */
934 static void
935 lzms_normal_encode(struct lzms_compressor *ctx)
936 {
937         struct raw_match match;
938
939         /* Load window into suffix array match-finder.  */
940         lz_sarray_load_window(&ctx->lz_sarray, ctx->window, ctx->window_size);
941
942         /* Reset the match-chooser.  */
943         lz_match_chooser_begin(&ctx->mc);
944
945         while (ctx->cur_window_pos != ctx->window_size) {
946                 match = lzms_get_near_optimal_match(ctx);
947                 if (match.len <= 1)
948                         lzms_encode_literal(ctx, ctx->window[ctx->cur_window_pos]);
949                 else
950                         lzms_encode_lz_match(ctx, match.len, match.offset);
951         }
952 }
953
954 static void
955 lzms_init_range_encoder(struct lzms_range_encoder *enc,
956                         struct lzms_range_encoder_raw *rc, u32 num_states)
957 {
958         enc->rc = rc;
959         enc->state = 0;
960         enc->mask = num_states - 1;
961         for (u32 i = 0; i < num_states; i++) {
962                 enc->prob_entries[i].num_recent_zero_bits = LZMS_INITIAL_PROBABILITY;
963                 enc->prob_entries[i].recent_bits = LZMS_INITIAL_RECENT_BITS;
964         }
965 }
966
967 static void
968 lzms_init_huffman_encoder(struct lzms_huffman_encoder *enc,
969                           struct lzms_output_bitstream *os,
970                           unsigned num_syms,
971                           unsigned rebuild_freq)
972 {
973         enc->os = os;
974         enc->num_syms_written = 0;
975         enc->rebuild_freq = rebuild_freq;
976         enc->num_syms = num_syms;
977         for (unsigned i = 0; i < num_syms; i++)
978                 enc->sym_freqs[i] = 1;
979
980         make_canonical_huffman_code(enc->num_syms,
981                                     LZMS_MAX_CODEWORD_LEN,
982                                     enc->sym_freqs,
983                                     enc->lens,
984                                     enc->codewords);
985 }
986
987 /* Initialize the LZMS compressor.  */
988 static void
989 lzms_init_compressor(struct lzms_compressor *ctx, const u8 *udata, u32 ulen,
990                      le16 *cdata, u32 clen16)
991 {
992         unsigned num_position_slots;
993
994         /* Copy the uncompressed data into the @ctx->window buffer.  */
995         memcpy(ctx->window, udata, ulen);
996         memset(&ctx->window[ulen], 0, 8);
997         ctx->cur_window_pos = 0;
998         ctx->window_size = ulen;
999
1000         /* Initialize the raw range encoder (writing forwards).  */
1001         lzms_range_encoder_raw_init(&ctx->rc, cdata, clen16);
1002
1003         /* Initialize the output bitstream for Huffman symbols and verbatim bits
1004          * (writing backwards).  */
1005         lzms_output_bitstream_init(&ctx->os, cdata, clen16);
1006
1007         /* Calculate the number of position slots needed for this compressed
1008          * block.  */
1009         num_position_slots = lzms_get_position_slot(ulen - 1) + 1;
1010
1011         LZMS_DEBUG("Using %u position slots", num_position_slots);
1012
1013         /* Initialize Huffman encoders for each alphabet used in the compressed
1014          * representation.  */
1015         lzms_init_huffman_encoder(&ctx->literal_encoder, &ctx->os,
1016                                   LZMS_NUM_LITERAL_SYMS,
1017                                   LZMS_LITERAL_CODE_REBUILD_FREQ);
1018
1019         lzms_init_huffman_encoder(&ctx->lz_offset_encoder, &ctx->os,
1020                                   num_position_slots,
1021                                   LZMS_LZ_OFFSET_CODE_REBUILD_FREQ);
1022
1023         lzms_init_huffman_encoder(&ctx->length_encoder, &ctx->os,
1024                                   LZMS_NUM_LEN_SYMS,
1025                                   LZMS_LENGTH_CODE_REBUILD_FREQ);
1026
1027         lzms_init_huffman_encoder(&ctx->delta_offset_encoder, &ctx->os,
1028                                   num_position_slots,
1029                                   LZMS_DELTA_OFFSET_CODE_REBUILD_FREQ);
1030
1031         lzms_init_huffman_encoder(&ctx->delta_power_encoder, &ctx->os,
1032                                   LZMS_NUM_DELTA_POWER_SYMS,
1033                                   LZMS_DELTA_POWER_CODE_REBUILD_FREQ);
1034
1035         /* Initialize range encoders, all of which wrap around the same
1036          * lzms_range_encoder_raw.  */
1037         lzms_init_range_encoder(&ctx->main_range_encoder,
1038                                 &ctx->rc, LZMS_NUM_MAIN_STATES);
1039
1040         lzms_init_range_encoder(&ctx->match_range_encoder,
1041                                 &ctx->rc, LZMS_NUM_MATCH_STATES);
1042
1043         lzms_init_range_encoder(&ctx->lz_match_range_encoder,
1044                                 &ctx->rc, LZMS_NUM_LZ_MATCH_STATES);
1045
1046         for (size_t i = 0; i < ARRAY_LEN(ctx->lz_repeat_match_range_encoders); i++)
1047                 lzms_init_range_encoder(&ctx->lz_repeat_match_range_encoders[i],
1048                                         &ctx->rc, LZMS_NUM_LZ_REPEAT_MATCH_STATES);
1049
1050         lzms_init_range_encoder(&ctx->delta_match_range_encoder,
1051                                 &ctx->rc, LZMS_NUM_DELTA_MATCH_STATES);
1052
1053         for (size_t i = 0; i < ARRAY_LEN(ctx->delta_repeat_match_range_encoders); i++)
1054                 lzms_init_range_encoder(&ctx->delta_repeat_match_range_encoders[i],
1055                                         &ctx->rc, LZMS_NUM_DELTA_REPEAT_MATCH_STATES);
1056
1057         /* Initialize LRU match information.  */
1058         lzms_init_lru_queues(&ctx->lru);
1059 }
1060
1061 /* Flush the output streams, prepare the final compressed data, and return its
1062  * size in bytes.
1063  *
1064  * A return value of 0 indicates that the data could not be compressed to fit in
1065  * the available space.  */
1066 static size_t
1067 lzms_finalize(struct lzms_compressor *ctx, u8 *cdata, size_t csize_avail)
1068 {
1069         size_t num_forwards_bytes;
1070         size_t num_backwards_bytes;
1071         size_t compressed_size;
1072
1073         /* Flush both the forwards and backwards streams, and make sure they
1074          * didn't cross each other and start overwriting each other's data.  */
1075         if (!lzms_output_bitstream_flush(&ctx->os)) {
1076                 LZMS_DEBUG("Backwards bitstream overrun.");
1077                 return 0;
1078         }
1079
1080         if (!lzms_range_encoder_raw_flush(&ctx->rc)) {
1081                 LZMS_DEBUG("Forwards bitstream overrun.");
1082                 return 0;
1083         }
1084
1085         if (ctx->rc.out > ctx->os.out) {
1086                 LZMS_DEBUG("Two bitstreams crossed.");
1087                 return 0;
1088         }
1089
1090         /* Now the compressed buffer contains the data output by the forwards
1091          * bitstream, then empty space, then data output by the backwards
1092          * bitstream.  Move the data output by the backwards bitstream to be
1093          * adjacent to the data output by the forward bitstream, and calculate
1094          * the compressed size that this results in.  */
1095         num_forwards_bytes = (u8*)ctx->rc.out - (u8*)cdata;
1096         num_backwards_bytes = ((u8*)cdata + csize_avail) - (u8*)ctx->os.out;
1097
1098         memmove(cdata + num_forwards_bytes, ctx->os.out, num_backwards_bytes);
1099
1100         compressed_size = num_forwards_bytes + num_backwards_bytes;
1101         LZMS_DEBUG("num_forwards_bytes=%zu, num_backwards_bytes=%zu, "
1102                    "compressed_size=%zu",
1103                    num_forwards_bytes, num_backwards_bytes, compressed_size);
1104         LZMS_ASSERT(compressed_size % 2 == 0);
1105         return compressed_size;
1106 }
1107
1108 static size_t
1109 lzms_compress(const void *uncompressed_data, size_t uncompressed_size,
1110               void *compressed_data, size_t compressed_size_avail, void *_ctx)
1111 {
1112         struct lzms_compressor *ctx = _ctx;
1113         size_t compressed_size;
1114
1115         LZMS_DEBUG("uncompressed_size=%zu, compressed_size_avail=%zu",
1116                    uncompressed_size, compressed_size_avail);
1117
1118         /* Make sure the uncompressed size is compatible with this compressor.
1119          */
1120         if (uncompressed_size > ctx->max_block_size) {
1121                 LZMS_DEBUG("Can't compress %zu bytes: LZMS context "
1122                            "only supports %u bytes",
1123                            uncompressed_size, ctx->max_block_size);
1124                 return 0;
1125         }
1126
1127         /* Don't bother compressing extremely small inputs.  */
1128         if (uncompressed_size < 4) {
1129                 LZMS_DEBUG("Input too small to bother compressing.");
1130                 return 0;
1131         }
1132
1133         /* Cap the available compressed size to a 32-bit integer and round it
1134          * down to the nearest multiple of 2.  */
1135         if (compressed_size_avail > UINT32_MAX)
1136                 compressed_size_avail = UINT32_MAX;
1137         if (compressed_size_avail & 1)
1138                 compressed_size_avail--;
1139
1140         /* Initialize the compressor structures.  */
1141         lzms_init_compressor(ctx, uncompressed_data, uncompressed_size,
1142                              compressed_data, compressed_size_avail / 2);
1143
1144         /* Preprocess the uncompressed data.  */
1145         lzms_x86_filter(ctx->window, ctx->window_size,
1146                         ctx->last_target_usages, false);
1147
1148         /* Compute and encode a literal/match sequence that decompresses to the
1149          * preprocessed data.  */
1150 #if 1
1151         lzms_normal_encode(ctx);
1152 #else
1153         lzms_fast_encode(ctx);
1154 #endif
1155
1156         /* Get and return the compressed data size.  */
1157         compressed_size = lzms_finalize(ctx, compressed_data,
1158                                         compressed_size_avail);
1159
1160         if (compressed_size == 0) {
1161                 LZMS_DEBUG("Data did not compress to requested size or less.");
1162                 return 0;
1163         }
1164
1165         LZMS_DEBUG("Compressed %zu => %zu bytes",
1166                    uncompressed_size, compressed_size);
1167
1168 #if defined(ENABLE_VERIFY_COMPRESSION) || defined(ENABLE_LZMS_DEBUG)
1169         /* Verify that we really get the same thing back when decompressing.  */
1170         {
1171                 struct wimlib_decompressor *decompressor;
1172
1173                 LZMS_DEBUG("Verifying LZMS compression.");
1174
1175                 if (0 == wimlib_create_decompressor(WIMLIB_COMPRESSION_TYPE_LZMS,
1176                                                     ctx->max_block_size,
1177                                                     NULL,
1178                                                     &decompressor))
1179                 {
1180                         int ret;
1181                         ret = wimlib_decompress(compressed_data,
1182                                                 compressed_size,
1183                                                 ctx->window,
1184                                                 uncompressed_size,
1185                                                 decompressor);
1186                         wimlib_free_decompressor(decompressor);
1187
1188                         if (ret) {
1189                                 ERROR("Failed to decompress data we "
1190                                       "compressed using LZMS algorithm");
1191                                 wimlib_assert(0);
1192                                 return 0;
1193                         }
1194                         if (memcmp(uncompressed_data, ctx->window,
1195                                    uncompressed_size))
1196                         {
1197                                 ERROR("Data we compressed using LZMS algorithm "
1198                                       "didn't decompress to original");
1199                                 wimlib_assert(0);
1200                                 return 0;
1201                         }
1202                 } else {
1203                         WARNING("Failed to create decompressor for "
1204                                 "data verification!");
1205                 }
1206         }
1207 #endif /* ENABLE_LZMS_DEBUG || ENABLE_VERIFY_COMPRESSION  */
1208
1209         return compressed_size;
1210 }
1211
1212 static void
1213 lzms_free_compressor(void *_ctx)
1214 {
1215         struct lzms_compressor *ctx = _ctx;
1216
1217         if (ctx) {
1218                 FREE(ctx->window);
1219 #if 0
1220                 FREE(ctx->prev_tab);
1221 #endif
1222                 FREE(ctx->matches);
1223                 lz_sarray_destroy(&ctx->lz_sarray);
1224                 lz_match_chooser_destroy(&ctx->mc);
1225                 FREE(ctx);
1226         }
1227 }
1228
1229 static const struct wimlib_lzms_compressor_params lzms_default = {
1230         .hdr = sizeof(struct wimlib_lzms_compressor_params),
1231         .min_match_length = 2,
1232         .max_match_length = UINT32_MAX,
1233         .nice_match_length = 32,
1234         .max_search_depth = 50,
1235         .max_matches_per_pos = 3,
1236         .optim_array_length = 1024,
1237 };
1238
1239 static const struct wimlib_lzms_compressor_params *
1240 lzms_get_params(const struct wimlib_compressor_params_header *_params)
1241 {
1242         const struct wimlib_lzms_compressor_params *params =
1243                 (const struct wimlib_lzms_compressor_params*)_params;
1244
1245         if (params == NULL)
1246                 params = &lzms_default;
1247
1248         return params;
1249 }
1250
1251 static int
1252 lzms_create_compressor(size_t max_block_size,
1253                        const struct wimlib_compressor_params_header *_params,
1254                        void **ctx_ret)
1255 {
1256         struct lzms_compressor *ctx;
1257         const struct wimlib_lzms_compressor_params *params = lzms_get_params(_params);
1258
1259         if (max_block_size == 0 || max_block_size >= INT32_MAX) {
1260                 LZMS_DEBUG("Invalid max_block_size (%u)", max_block_size);
1261                 return WIMLIB_ERR_INVALID_PARAM;
1262         }
1263
1264         ctx = CALLOC(1, sizeof(struct lzms_compressor));
1265         if (ctx == NULL)
1266                 goto oom;
1267
1268         ctx->window = MALLOC(max_block_size);
1269         if (ctx->window == NULL)
1270                 goto oom;
1271
1272 #if 0
1273         ctx->prev_tab = MALLOC(max_block_size * sizeof(ctx->prev_tab[0]));
1274         if (ctx->prev_tab == NULL)
1275                 goto oom;
1276 #endif
1277
1278         ctx->matches = MALLOC(min(params->max_match_length -
1279                                         params->min_match_length + 1,
1280                                   params->max_matches_per_pos) *
1281                                 sizeof(ctx->matches[0]));
1282         if (ctx->matches == NULL)
1283                 goto oom;
1284
1285         if (!lz_sarray_init(&ctx->lz_sarray, max_block_size,
1286                             params->min_match_length,
1287                             params->max_match_length,
1288                             params->max_search_depth,
1289                             params->max_matches_per_pos))
1290                 goto oom;
1291
1292         if (!lz_match_chooser_init(&ctx->mc,
1293                                    params->optim_array_length,
1294                                    params->nice_match_length,
1295                                    params->max_match_length))
1296                 goto oom;
1297
1298         /* Initialize position and length slot data if not done already.  */
1299         lzms_init_slots();
1300
1301         /* Initialize range encoding cost table if not done already.  */
1302         lzms_init_rc_costs();
1303
1304         ctx->max_block_size = max_block_size;
1305
1306         *ctx_ret = ctx;
1307         return 0;
1308
1309 oom:
1310         lzms_free_compressor(ctx);
1311         return WIMLIB_ERR_NOMEM;
1312 }
1313
1314 static u64
1315 lzms_get_needed_memory(size_t max_block_size,
1316                        const struct wimlib_compressor_params_header *_params)
1317 {
1318         const struct wimlib_lzms_compressor_params *params = lzms_get_params(_params);
1319
1320         u64 size = 0;
1321
1322         size += max_block_size;
1323         size += sizeof(struct lzms_compressor);
1324         size += lz_sarray_get_needed_memory(max_block_size);
1325         size += lz_match_chooser_get_needed_memory(params->optim_array_length,
1326                                                    params->nice_match_length,
1327                                                    params->max_match_length);
1328         size += min(params->max_match_length -
1329                     params->min_match_length + 1,
1330                     params->max_matches_per_pos) *
1331                 sizeof(((struct lzms_compressor*)0)->matches[0]);
1332         return size;
1333 }
1334
1335 static bool
1336 lzms_params_valid(const struct wimlib_compressor_params_header *_params)
1337 {
1338         const struct wimlib_lzms_compressor_params *params =
1339                 (const struct wimlib_lzms_compressor_params*)_params;
1340
1341         if (params->hdr.size != sizeof(*params) ||
1342             params->max_match_length < params->min_match_length ||
1343             params->min_match_length < 2 ||
1344             params->optim_array_length == 0 ||
1345             min(params->max_match_length, params->nice_match_length) > 65536)
1346                 return false;
1347
1348         return true;
1349 }
1350
1351 const struct compressor_ops lzms_compressor_ops = {
1352         .params_valid       = lzms_params_valid,
1353         .get_needed_memory  = lzms_get_needed_memory,
1354         .create_compressor  = lzms_create_compressor,
1355         .compress           = lzms_compress,
1356         .free_compressor    = lzms_free_compressor,
1357 };