Table-based Asymmetric Numeral System (tANS)
2021-05-25
Overview
A fast and efficient way to implement entropy coding is table-based asymmetric numeral systems (tANS), also known as finite-state entropy (FSE). It uses a number of states (represented by integers) and a lookup table to determine the symbol to emit and the way to compute the next state. The lookup table has an entry for each state, and the entry specifies:
- The symbol to emit.
- The number of bits to read.
- A base value to add to the newly read bits.
Here is the code for decoding a symbol:
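```rust
pub fn decode_sym(&mut self, input: &mut dyn ReadBits) -> BoxResult<S> {
    // self.state refers to the current table entry: (symbol, nbits, base).
    let (_, nbits, base) = self.state;
    // base is a multiple of 2^nbits, so OR-ing in the new bits
    // is the same as adding them.
    let s = *base | input.read_bits(*nbits as u32)?;
    self.state = &self.table[s as usize];
    Ok(self.state.0)
}
```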
Compression is achieved by having nbits
be smaller for symbols that occur more frequently. To still ensure
that every state in the table can be reached after every symbol, the
table may contain multiple entries that code for a given symbol. To make
this possible, the table must have more entries (states) than there
are symbols. Optimal compression is achieved by having a symbol that
makes up 1/x of the input also make up 1/x of the states. The larger
the table size, the more closely we can approximate this optimum.
As an example, if we are using 8 states, a symbol sym that makes up 5/8 of the input would appear 5 times in the table. We then choose base and nbits for these 5 entries so that all states in the table can be reached from a state with symbol sym. Since there are 5 entries, we can reach all 8 states by choosing nbits = 0 for two of the entries and nbits = 1 for the remaining 3 entries: nbits = 0 gives a single successor state, nbits = 1 gives two possible successor states, and 2 * 1 + 3 * 2 = 8. If all 8 successor states are reached equally often, two of them cost 0 bits and the other six cost 1 bit, so we use an average of 6/8 = 0.75 bits per occurrence, close to the ideal of log2(8/5) ≈ 0.68 bits.
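The coverage and average-cost arithmetic above can be checked mechanically. The sketch below is a minimal illustration with hypothetical helper names (covered_states, average_bits): it takes the nbits values of one symbol's table entries, counts how many successor states they reach, and averages the cost assuming every successor state is equally likely.

```rust
/// Number of successor states reachable from a set of table entries:
/// an entry with nbits bits reaches 2^nbits successor states.
fn covered_states(nbits: &[u32]) -> u32 {
    nbits.iter().map(|&n| 1u32 << n).sum()
}

/// Average bits per occurrence when every successor state is equally
/// likely: each successor state costs the nbits of the entry reaching it.
fn average_bits(nbits: &[u32]) -> f64 {
    let total: u32 = nbits.iter().map(|&n| n * (1u32 << n)).sum();
    total as f64 / covered_states(nbits) as f64
}
```

For the entries [0, 0, 1, 1, 1], covered_states returns 8 and average_bits returns 0.75.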
Here is what a full table might look like for an alphabet of 3 symbols (a, b, c), with relative frequencies (a: 2, b: 5, c: 1):
state | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
sym | c | b | a | b | b | a | b | b |
nbits | 3 | 1 | 2 | 0 | 1 | 2 | 0 | 1 |
base | 0 | 6 | 4 | 1 | 4 | 0 | 0 | 2 |
Using this table, a decoder in state 2 will emit the symbol a, then read two bits and add them to the base value of 4, resulting in a new state in the range [4, 7]. A decoder in state 6 will emit symbol b and move to state 0 without needing to read any bits.
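A minimal sketch of a single decode step over this example table follows. It assumes a hypothetical MSB-first BitReader over a slice of 0/1 values and stores symbols as char; it follows the order described above (emit the state's symbol, then read nbits bits and combine them with base). It is an illustration only, not the article's ReadBits interface.

```rust
/// The example decode table: (symbol, nbits, base) for states 0..=7.
const TABLE: [(char, u32, u32); 8] = [
    ('c', 3, 0), ('b', 1, 6), ('a', 2, 4), ('b', 0, 1),
    ('b', 1, 4), ('a', 2, 0), ('b', 0, 0), ('b', 1, 2),
];

/// Hypothetical MSB-first reader over a slice of 0/1 values.
struct BitReader<'a> {
    bits: &'a [u8],
    pos: usize,
}

impl<'a> BitReader<'a> {
    /// Reads n bits as an unsigned integer; None if the input runs out.
    fn read_bits(&mut self, n: u32) -> Option<u32> {
        let mut v = 0;
        for _ in 0..n {
            v = (v << 1) | u32::from(*self.bits.get(self.pos)?);
            self.pos += 1;
        }
        Some(v)
    }
}

/// Emits the symbol of `state`, then reads that entry's nbits bits and
/// combines them with its base (a multiple of 2^nbits) to get the successor.
fn step(state: u32, input: &mut BitReader<'_>) -> Option<(char, u32)> {
    let (sym, nbits, base) = TABLE[state as usize];
    let bits = input.read_bits(nbits)?;
    Some((sym, base | bits))
}
```

For instance, state 2 with input bits 1, 0 yields a and successor state 6, while state 6 yields b and successor state 0 without consuming input.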
A Simple Encoder
Encoding works in the opposite direction from decoding. It starts from the end of the data and works towards the beginning, and creates a sequence of bits that will decode to the original data.
The decode step proceeds from an origin state to a successor state by emitting a symbol, reading a number of bits, and adding those bits to a base value to get the successor state. The encoder needs to determine the origin state, given the successor state and the symbol.
One way to implement an encoder would be to simply look at every state and see if it has the correct symbol and its base and nbits allow the successor state to be reached from there:
```rust
fn tans_prev_state<S: Copy + Eq>(successor: u32,
                                 sym: S,
                                 table: &[(S, u8, u32)]) -> u32 {
    for i in 0..table.len() {
        let (s, nbits, base) = table[i];
        if s == sym && base <= successor && (base + (1 << nbits)) > successor {
            return i as u32;
        }
    }
    panic!("Cannot encode symbol; table malformed.");
}
```
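To see the slow encoder in action, the sketch below runs a full round trip on the example table (symbols as char). Assumptions, all ours: the encoder starts from successor state 0 (any state works, because the decoder never takes the transition after the last symbol), bits are kept as (value, nbits) groups instead of being packed into a stream, and the names encode and decode are hypothetical.

```rust
/// The example decode table: (symbol, nbits, base) for states 0..=7.
const TABLE: [(char, u8, u32); 8] = [
    ('c', 3, 0), ('b', 1, 6), ('a', 2, 4), ('b', 0, 1),
    ('b', 1, 4), ('a', 2, 0), ('b', 0, 0), ('b', 1, 2),
];

fn tans_prev_state<S: Copy + Eq>(successor: u32, sym: S, table: &[(S, u8, u32)]) -> u32 {
    for (i, &(s, nbits, base)) in table.iter().enumerate() {
        if s == sym && base <= successor && base + (1u32 << nbits) > successor {
            return i as u32;
        }
    }
    panic!("Cannot encode symbol; table malformed.");
}

/// Encodes `msg` back to front. Returns the decoder's initial state and
/// the (value, nbits) groups in the order the decoder will read them.
fn encode(msg: &[char]) -> (u32, Vec<(u32, u8)>) {
    let mut successor = 0; // arbitrary: the last transition is never taken
    let mut groups = Vec::new();
    for &sym in msg.iter().rev() {
        let origin = tans_prev_state(successor, sym, &TABLE);
        let (_, nbits, base) = TABLE[origin as usize];
        groups.push((successor - base, nbits));
        successor = origin;
    }
    groups.reverse();
    (successor, groups)
}

/// Emits each state's symbol, then moves to base + (bits read).
fn decode(mut state: u32, groups: &[(u32, u8)]) -> Vec<char> {
    let mut out = Vec::new();
    for &(value, nbits) in groups {
        let (sym, n, base) = TABLE[state as usize];
        assert_eq!(n, nbits);
        out.push(sym);
        state = base + value;
    }
    out
}
```

Encoding "bacb" this way starts the decoder in state 4 and produces groups of 1, 2, 3 and 0 bits.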
This produces correct results (given an appropriate
table), but it's slow: O(nstates)
for every symbol in the
input.
Fast Encoding - Overview
To find the origin state quickly, we use the symbol and the successor state to compute an index in the range [0, number of states), then use this index to look up the origin state in an array.
First, let's consider a single symbol. In the decode table, each entry allows successor states in a certain interval to be reached. The interval runs from base to base + (1 << nbits) - 1. For example, an entry with base 4 and nbits 2 describes the interval from 4 (binary 100) to 7 (binary 111). We construct the decode table such that, for each symbol:
- Every successor state is in exactly one interval.
- The highest nbits value is at most one greater than the lowest nbits value.
- Intervals with lower nbits values have lower base values than intervals with higher nbits values.
For example, if we are creating a table with 8 entries and a symbol occurs 5 times in the table, we must create 5 intervals that cover all successor states. We can do this using 2 intervals of size 1 (nbits = 0) and 3 intervals of size 2 (nbits = 1). Then all the small intervals come first, followed by the larger intervals:
successor state | 0 | 1 | 2 3 | 4 5 | 6 7 |
base | 0 | 1 | 2 | 4 | 6 |
nbits | 0 | 0 | 1 | 1 | 1 |
This gives us a unique origin state for every successor state. It also gives us a way to determine nbits from the successor state: successor states below some threshold have the low nbits value, and successor states at or above the threshold have an nbits value one higher. In the example, this threshold is 2.
Using the nbits
value, we can assign each interval a
unique number from a consecutive range. At http://fastcompression.blogspot.com/2014/02/fse-tricks-memory-efficient-subrange.html,
Yann Collet describes a way to compute such a number: (state +
nstates) >> nbits
, where nstates is the total number of
states (in our example, 8). This maps successor states to a range of
consecutive integers:
successor state | 0 | 1 | 2 3 | 4 5 | 6 7 |
base | 0 | 1 | 2 | 4 | 6 |
nbits | 0 | 0 | 1 | 1 | 1 |
interval id | 8 | 9 | 5 | 6 | 7 |
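The mapping in the table above can be reproduced directly. In this sketch (hypothetical helper names), a symbol is described by its low nbits value and the threshold where the higher value starts; for the example symbol that is low nbits 0 and threshold 2.

```rust
/// Interval id of a successor state: (state + nstates) >> nbits.
fn interval_id(state: u32, nbits: u32, nstates: u32) -> u32 {
    (state + nstates) >> nbits
}

/// nbits for a successor state: low_nbits below the threshold,
/// low_nbits + 1 at or above it.
fn nbits_for(state: u32, low_nbits: u32, threshold: u32) -> u32 {
    if state >= threshold { low_nbits + 1 } else { low_nbits }
}

/// Interval id of every successor state 0..nstates for one symbol.
fn interval_ids(low_nbits: u32, threshold: u32, nstates: u32) -> Vec<u32> {
    (0..nstates)
        .map(|s| interval_id(s, nbits_for(s, low_nbits, threshold), nstates))
        .collect()
}
```

Successor states that share an interval get the same id, so interval_ids(0, 2, 8) lists 8, 9, 5, 5, 6, 6, 7, 7.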
By doing this for each symbol, we give each symbol its own mapping of successor state to interval id. For the { a: 2, b: 5, c: 1 } relative frequency example, we get:
successor state | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
a | 2 | 2 | 2 | 2 | 3 | 3 | 3 | 3 |
b | 8 | 9 | 5 | 5 | 6 | 6 | 7 | 7 |
c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
Each symbol gets a number of interval ids equal to the number of entries for that symbol in the decode table. As a result, the number of interval ids for all symbols together equals the number of states. This suggests that we can use a table of size nstates to look up the origin state for each successor state and symbol. To do this, we must convert each interval id to an index into such a table. Because the interval id ranges are contiguous for each symbol, we can do this by adjusting each interval id by a per-symbol offset. For example, if we pick the offsets
symbol | offset |
a | 6 |
b | 5 |
c | 6 |
and compute the position in the lookup table as ((interval id) + offset) mod nstates, we get non-overlapping ranges that fully use the range [0, nstates):
successor state | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
a | 0 | 0 | 0 | 0 | 1 | 1 | 1 | 1 |
b | 5 | 6 | 2 | 2 | 3 | 3 | 4 | 4 |
c | 7 | 7 | 7 | 7 | 7 | 7 | 7 | 7 |
These offsets are computed by doing the following for each symbol: (1) select a starting position in the lookup table, (2) find the lowest interval id for that symbol, and (3) subtract (2) from (1), modulo nstates, so that we get a nonnegative value.
For the starting position, we simply start at 0 for the first symbol, then increment by the number of entries for each symbol that we process. In the example, this means the starting position is 0 for symbol a, 2 for symbol b, and 7 for symbol c.
The lowest interval id for a symbol is equal to the number of occurrences of that symbol in the table. In the example, it is 2 for a, 5 for b, and 1 for c.
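That the lowest interval id always equals the symbol's frequency follows from a short derivation. It assumes the low nbits value k and threshold t that compute_coded_nbits (later in this article) produces for a symbol with frequency f in a table of N = 2^sbits states, and shows that the symbol's interval ids are exactly the integers in [f, 2f):

```latex
\begin{aligned}
&f = \text{freq},\quad N = 2^{\text{sbits}},\quad k = \text{sbits} - \lceil \log_2 f \rceil,\quad t = 2 f\, 2^{k} - N \\
&s \in [0, t):\quad \mathrm{id}(s) = \left\lfloor \frac{s + N}{2^{k}} \right\rfloor \in \left[\frac{N}{2^{k}},\ \frac{N + t}{2^{k}}\right) = \left[2^{\lceil \log_2 f \rceil},\ 2f\right) \\
&s \in [t, N):\quad \mathrm{id}(s) = \left\lfloor \frac{s + N}{2^{k+1}} \right\rfloor \in \left[\frac{N + t}{2^{k+1}},\ \frac{2N}{2^{k+1}}\right) = \left[f,\ 2^{\lceil \log_2 f \rceil}\right) \\
&\Longrightarrow\quad \{\mathrm{id}(s)\} = [f,\ 2f)
\end{aligned}
```

Both ranges are exact because N and N + t = f · 2^(k+1) are multiples of the interval sizes. For b in the example (f = 5, N = 8, k = 0, t = 2), the small intervals give ids 8 and 9 and the large ones give 5, 6 and 7, which together form [5, 10).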
Computing the offsets:
- For symbol a, the lowest interval id is 2, and we want to start at 0, so the offset is -2 modulo 8, which is 6.
- Since a occurs twice, we want to start the range for the next symbol at 2. Since that symbol, b, occurs 5 times, its lowest interval id is 5. Therefore, the offset we need is -3 modulo 8, which is 5.
- For symbol c, we want to start the range at 7. The symbol occurs once, so its lowest interval id is 1. This means the offset we need is 6 modulo 8, which is 6.
Now we have [0, 1] for a, [2, 6] for b, and [7] for c.
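The offset computation and the resulting lookup positions can be sketched as follows. The helper names are hypothetical, and the sketch relies on the fact stated above that a symbol's lowest interval id equals its frequency.

```rust
/// Per-symbol offsets: a symbol's range starts at the running total of
/// the previous frequencies, and its lowest interval id equals its
/// frequency, so offset = (start - freq) mod nstates.
fn offsets(freqs: &[u32], nstates: u32) -> Vec<u32> {
    let mut start = 0u32;
    freqs
        .iter()
        .map(|&f| {
            let off = (start + nstates - f) % nstates;
            start += f;
            off
        })
        .collect()
}

/// Lookup-table positions ((interval id) + offset) mod nstates for the
/// interval ids of one symbol.
fn positions(ids: &[u32], offset: u32, nstates: u32) -> Vec<u32> {
    ids.iter().map(|&id| (id + offset) % nstates).collect()
}
```

With freqs (2, 5, 1) and 8 states, offsets returns 6, 5, 6, and applying positions to each symbol's interval ids reproduces the rows of the position table above.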
Summarizing the above, the lookup tables we need are:
- Per symbol:
  - low nbits value
  - threshold where we transition to the high nbits value
  - offset by which to adjust the interval id
- Per adjusted interval id (table size nstates):
  - origin state
The next section puts this into code.
Fast Encoding - Code
The encoder determines the origin state, given a symbol and a successor state. To do this, it uses two lookup tables. The first, symtab, is indexed by symbol and contains the number of bits to emit (in the coded_nbits form described below) and an offset value that is used to determine the index into the second table. The second table, origin, gives the origin state. Besides the tables, the encoder also keeps track of the successor state and the total number of states.
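```rust
pub struct Encoder {
    /// One entry per symbol. First item is coded_nbits, second is an offset
    /// used to get an index into the origin table.
    symtab: Vec<(u32, u32)>,
    /// Lookup table used to find the origin state.
    origin: Vec<u32>,
    /// Current state.
    state: u32,
    /// Total number of states.
    nstates: u32,
    /// Output buffer and bit accumulator used by acc_bits() (not shown).
    /// Encoder::new initializes these; the types here are assumed.
    output: Vec<u32>,
    bits: u32,
    need_bits: u32,
}
```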
Here is the function that does the encoding:
```rust
pub fn encode_sym(&mut self, sym: u32) {
    // Get coded_nbits and offset for this symbol.
    let (coded_nbits, offset): (u32, u32) = self.symtab[sym as usize];
    // Number of bits to emit for the current (successor) state.
    let nbits = (self.state + coded_nbits) >> 24;
    self.acc_bits(self.state, nbits);
    // Set new state: the origin state from which the current state is reached.
    let idx = ((self.state + self.nstates) >> nbits) + offset;
    self.state = self.origin[(idx & (self.nstates - 1)) as usize];
}
```
Here, acc_bits()
is used to accumulate output
bits, which we will eventually output. The reason we don't output them directly
is that the bits the decoder will need first are generated last.
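acc_bits() itself is not shown in this article. The sketch below is one possible stand-in, under our own assumptions: a hypothetical BitStack that records each (value, nbits) group and, once encoding is finished, writes the groups in reverse order of accumulation, most significant bit first, so the decoder reads the bits for the first symbol first.

```rust
/// Hypothetical stand-in for acc_bits(): stores each group of bits until
/// encoding is done, because the decoder reads the last group first.
struct BitStack {
    groups: Vec<(u32, u32)>, // (value, nbits), with nbits <= 32
}

impl BitStack {
    fn new() -> Self {
        BitStack { groups: Vec::new() }
    }

    /// Records the low `nbits` bits of `value`.
    fn acc_bits(&mut self, value: u32, nbits: u32) {
        let mask = if nbits == 0 { 0 } else { u32::MAX >> (32 - nbits) };
        self.groups.push((value & mask, nbits));
    }

    /// Emits the groups in reverse order of accumulation, MSB first,
    /// padding the final byte with zero bits.
    fn finish(self) -> Vec<u8> {
        let mut out = Vec::new();
        let (mut acc, mut used) = (0u8, 0u32);
        for &(value, nbits) in self.groups.iter().rev() {
            for i in (0..nbits).rev() {
                acc = (acc << 1) | ((value >> i) & 1) as u8;
                used += 1;
                if used == 8 {
                    out.push(acc);
                    acc = 0;
                    used = 0;
                }
            }
        }
        if used > 0 {
            out.push(acc << (8 - used));
        }
        out
    }
}
```

Accumulating 1 (1 bit), 10 (2 bits) and 110 (3 bits) produces the single byte 11010100: the last group comes out first.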
The Encoder's Symbol Table
The symbol table (symtab) in the encoder allows us to look up two things, given a symbol:
- coded_nbits, from which we compute nbits, the number of bits we need to provide to the decoder
- offset, which is used to compute an index into the origin table
First, we compute how many bits we need to encode into the stream for the transition from an origin state with that symbol to a given successor state. The number of bits is either the same for all successor states, or it is one bit greater for successor states at or above some threshold. Given the lower number of bits and the threshold, we could compute the number of bits for a given successor state as
nbits = if successor >= threshold { low_nbits + 1 } else { low_nbits }
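but we save a branch instruction by encoding this as a single number, coded_nbits = ((low_nbits + 1) << 24) - threshold, after which we can compute nbits as
nbits = (coded_nbits + successor) >> 24;
This works because threshold and successor are at most nstates, which is at most 2**24: the sum stays at or above low_nbits << 24, reaches (low_nbits + 1) << 24 exactly when successor >= threshold, and never reaches (low_nbits + 2) << 24.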
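A quick way to convince yourself that the branch-free form matches the branchy one is to compare them exhaustively for a small table. The sketch below does that; the function names are hypothetical, and it assumes threshold <= nstates <= 2**24, as in the article.

```rust
/// Branchy version: low_nbits below the threshold, one more at or above it.
fn nbits_branchy(successor: u32, low_nbits: u32, threshold: u32) -> u32 {
    if successor >= threshold { low_nbits + 1 } else { low_nbits }
}

/// Branch-free version: one addition and one shift, using coded_nbits.
fn nbits_branch_free(successor: u32, low_nbits: u32, threshold: u32) -> u32 {
    let coded_nbits = ((low_nbits + 1) << 24) - threshold;
    (coded_nbits + successor) >> 24
}
```

The two agree for every successor state and threshold of a 32-state table, and also at the 2**24 limit.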
The following function computes coded_nbits:
```rust
/// Computes a value x such that (x + s) >> 24 gives the number
/// of bits to read in state s.
fn compute_coded_nbits(freq: u32, sbits: u32) -> u32 {
    // For a symbol with no occurrences, return 0.
    if freq == 0 { return 0 }
    // We need sbits bits to encode every possible state value.
    // Not all of these need to be read/written at every state
    // transition: if a symbol has 2**n occurrences in the table,
    // we can get n bits of information from knowing which of
    // the states for that symbol we are in, and only need to
    // encode sbits - n bits.
    //
    // We compute n as 32 - (freq - 1).leading_zeros(), which is
    // ceil(log2(freq)). This makes sbits - n a lower bound on the
    // number of bits that need to be encoded for the state transition.
    let low_nbits = sbits - (32 - (freq - 1).leading_zeros());
    // Number of successor states we can get to using freq
    // origin states and low_nbits per state.
    let covered = freq << low_nbits;
    // If we do not already cover all possible successor states,
    // we increase our coverage by adding an extra bit to
    // encode for some states. We do this for successor states
    // starting at some threshold. Since we have 1 << sbits
    // successor states, we need to cover an additional
    // (1 << sbits) - covered states. The threshold, then, is
    // the old value of covered minus the additional states to be
    // covered, which can be simplified to:
    let threshold = covered + covered - (1 << sbits);
    // The return value is computed so that adding the number
    // of the successor state, then right-shifting the result by
    // 24 results in the number of bits to encode for the state.
    ((low_nbits + 1) << 24) - threshold
}
```
The choice of 24 as the number of bits to shift by means we can accommodate up to 2**24 states. Since 24 is a multiple of 8, a shift by 24 can also be relatively efficiently implemented on processors that don't have an instruction for it.
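A useful sanity check on compute_coded_nbits is that it gives each symbol exactly freq intervals: every successor state s lies in an interval of 2**nbits(s) states, so summing 2**-nbits(s) over all successor states must give freq. The sketch below copies the function (comments trimmed) and checks this, scaled by 2**sbits to stay in integer arithmetic; scaled_interval_count is a hypothetical helper.

```rust
fn compute_coded_nbits(freq: u32, sbits: u32) -> u32 {
    if freq == 0 {
        return 0;
    }
    let low_nbits = sbits - (32 - (freq - 1).leading_zeros());
    let covered = freq << low_nbits;
    let threshold = covered + covered - (1 << sbits);
    ((low_nbits + 1) << 24) - threshold
}

/// Sum over successor states s of 2^(sbits - nbits(s)), i.e. the number
/// of intervals induced by coded_nbits, multiplied by 2^sbits.
fn scaled_interval_count(coded_nbits: u32, sbits: u32) -> u32 {
    (0..1u32 << sbits)
        .map(|s| 1u32 << (sbits - ((coded_nbits + s) >> 24)))
        .sum()
}
```

For freq = 5 and sbits = 3 this yields nbits 0, 0, 1, 1, 1, 1, 1, 1 for successor states 0 to 7, matching the example.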
Having computed coded_nbits
, we now compute the offset
to store in the symbol table. This is much simpler. The code is:
let offset = o.wrapping_sub(freqs[s]) & mask;
where freqs[s] is the number of entries in the decoder table for symbol s, and the bitwise AND with mask (which equals nstates - 1) ensures that the value we compute is in the range [0, nstates). Because nstates is a power of two, this is the same as reducing the value modulo nstates. The variable o is the index into the origin table where we want entries for symbol s to start. It is initialized to 0 for the first symbol, then incremented by freqs[s] for each symbol processed. As a result, the entries for symbol 0 start at index 0, the entries for symbol 1 start at index freqs[0], the entries for symbol 2 start at index freqs[0] + freqs[1], and so on.
The Encoder's Origin Table
The origin table is where the encoder looks up the origin state, given the index computed from the successor state, nbits, and the symbol's offset. To populate the origin table, we visit each symbol sym and select freqs[sym] states to code for that symbol. The most straightforward way would be to just select consecutive states:
state | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
sym | a | a | b | b | b | b | b | c |
While simple, this assignment is generally not the best choice for tANS. Lower-numbered successor states tend to require fewer bits to encode, so assigning states to symbols in order gives all the cheaper-to-encode states to lower-numbered symbols. We can avoid that by spreading the symbols across the states more evenly.
To illustrate how spreading the symbols helps, imagine that, instead of 8 states, we had 32. Keeping the same relative frequencies, we now have (a: 8, b: 20, c: 4). Of the 20 states for b, 8 have nbits = 0, and together they lead to successor states 0 to 7. By spreading the symbols throughout the state space, we might end up with those 8 successor states assigned as follows:
state | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | … |
sym | b | a | b | c | b | a | b | b | … |
Using this assignment, there is a 0-bit transition from b to every symbol, whereas with sequential assignment of states to symbols, all 0-bit transitions from b would lead to a.
To spread out symbols across states, instead of assigning symbols to consecutive states, we step through the states using a stride:
```rust
/// Compute stride so that:
/// (a) It is relatively prime to nstates.
/// (b) It is a bit over half of nstates.
/// Property (a) ensures that a single iteration will populate all states.
/// Property (b) ensures that symbols with multiple occurrences will
/// be spread roughly evenly across the state space.
fn compute_stride(nstates: u32) -> u32 {
    if nstates <= 8 {
        5
    } else {
        // nstates is a power of two >= 16, so both shifted terms are
        // even and the stride is odd.
        (nstates >> 1) + (nstates >> 3) + 3
    }
}
```
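The stride-based spreading can be checked against the example. This sketch copies compute_stride, fills the origin table the way Encoder::new does, and inverts it to see which symbol each state gets (symbols numbered a = 0, b = 1, c = 2). build_origin and state_symbols are hypothetical helper names.

```rust
fn compute_stride(nstates: u32) -> u32 {
    if nstates <= 8 { 5 } else { (nstates >> 1) + (nstates >> 3) + 3 }
}

/// Fills the origin table as Encoder::new does: entries for each symbol
/// are consecutive, and the states they name advance by the stride.
fn build_origin(freqs: &[u32], sbits: u32) -> Vec<u32> {
    let nstates = 1u32 << sbits;
    let mask = nstates - 1;
    let stride = compute_stride(nstates);
    let mut origin = Vec::with_capacity(nstates as usize);
    let mut s = stride & mask;
    for &f in freqs {
        for _ in 0..f {
            origin.push(s);
            s = (s + stride) & mask;
        }
    }
    origin
}

/// Symbol owned by each state: the inverse view of the origin table.
fn state_symbols(freqs: &[u32], origin: &[u32]) -> Vec<usize> {
    let mut syms = vec![0; origin.len()];
    let mut o = 0;
    for (sym, &f) in freqs.iter().enumerate() {
        for _ in 0..f {
            syms[origin[o] as usize] = sym;
            o += 1;
        }
    }
    syms
}
```

For frequencies (2, 5, 1) and 8 states, the resulting assignment is c, b, a, b, b, a, b, b, which is the sym row of the example decode table in the overview.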
Initializing the Encoder
Using the helper functions shown before, here is the code that initializes the encoder state:
```rust
// impl Encoder {
pub fn new(sbits: u32, freqs: &[u32]) -> Encoder {
    let nstates: u32 = 1 << sbits;
    let nsyms = freqs.len();
    let mask = nstates - 1;
    let mut encoder = Encoder {
        symtab: Vec::with_capacity(nsyms),
        origin: vec![0; nstates as usize],
        // Output state used by acc_bits() (not shown).
        output: Vec::new(),
        bits: 0,
        state: 0,
        nstates,
        need_bits: 32,
    };
    let mut o: u32 = 0;
    // Populate symbol table with coded_nbits and offset.
    for s in 0..nsyms {
        let coded_nbits = compute_coded_nbits(freqs[s], sbits);
        let offset = o.wrapping_sub(freqs[s]) & mask;
        encoder.symtab.push((coded_nbits, offset));
        o += freqs[s];
    }
    // Populate origin table.
    let stride = compute_stride(nstates);
    let mut o = 0; // index into origin table
    let mut s = stride & mask; // state number
    for sym in 0..nsyms {
        for _ in 0..freqs[sym] {
            encoder.origin[o] = s;
            o += 1;
            s = (s + stride) & mask;
        }
    }
    encoder
}
```
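Putting the pieces together, the sketch below encodes with the fast method and decodes with the example decode table from the overview, asserting at each step that encoder and decoder agree on nbits. To keep it self-contained, it replaces acc_bits() and the output fields with a hypothetical groups vector of (value, nbits) pairs, numbers the symbols a = 0, b = 1, c = 2, and hard-codes sbits = 3 and frequencies (2, 5, 1); as in Encoder::new, the encoder starts in state 0.

```rust
// Example decode table from the overview, as (symbol, nbits, base),
// with symbols numbered a = 0, b = 1, c = 2.
const TABLE: [(usize, u32, u32); 8] = [
    (2, 3, 0), (1, 1, 6), (0, 2, 4), (1, 0, 1),
    (1, 1, 4), (0, 2, 0), (1, 0, 0), (1, 1, 2),
];

fn compute_coded_nbits(freq: u32, sbits: u32) -> u32 {
    if freq == 0 {
        return 0;
    }
    let low_nbits = sbits - (32 - (freq - 1).leading_zeros());
    let covered = freq << low_nbits;
    let threshold = covered + covered - (1 << sbits);
    ((low_nbits + 1) << 24) - threshold
}

fn compute_stride(nstates: u32) -> u32 {
    if nstates <= 8 { 5 } else { (nstates >> 1) + (nstates >> 3) + 3 }
}

struct Encoder {
    symtab: Vec<(u32, u32)>, // (coded_nbits, offset) per symbol
    origin: Vec<u32>,
    state: u32,
    nstates: u32,
    groups: Vec<(u32, u32)>, // hypothetical stand-in for acc_bits() output
}

impl Encoder {
    fn new(sbits: u32, freqs: &[u32]) -> Encoder {
        let nstates = 1u32 << sbits;
        let mask = nstates - 1;
        let mut symtab = Vec::with_capacity(freqs.len());
        let mut o = 0u32;
        for &f in freqs {
            symtab.push((compute_coded_nbits(f, sbits), o.wrapping_sub(f) & mask));
            o += f;
        }
        let stride = compute_stride(nstates);
        let mut origin = Vec::with_capacity(nstates as usize);
        let mut s = stride & mask;
        for &f in freqs {
            for _ in 0..f {
                origin.push(s);
                s = (s + stride) & mask;
            }
        }
        Encoder { symtab, origin, state: 0, nstates, groups: Vec::new() }
    }

    fn encode_sym(&mut self, sym: usize) {
        let (coded_nbits, offset) = self.symtab[sym];
        let nbits = (self.state + coded_nbits) >> 24;
        // Record the low nbits bits of the successor state.
        self.groups.push((self.state & ((1 << nbits) - 1), nbits));
        let idx = ((self.state + self.nstates) >> nbits) + offset;
        self.state = self.origin[(idx & (self.nstates - 1)) as usize];
    }
}

/// Encodes back to front; returns the decoder's start state and the bit
/// groups in the order the decoder reads them.
fn encode(msg: &[usize]) -> (u32, Vec<(u32, u32)>) {
    let mut enc = Encoder::new(3, &[2, 5, 1]);
    for &sym in msg.iter().rev() {
        enc.encode_sym(sym);
    }
    enc.groups.reverse();
    (enc.state, enc.groups)
}

/// Emits each state's symbol, then moves to base | (bits read).
fn decode(mut state: u32, groups: &[(u32, u32)]) -> Vec<usize> {
    let mut out = Vec::new();
    for &(value, nbits) in groups {
        let (sym, n, base) = TABLE[state as usize];
        assert_eq!(n, nbits, "encoder and decoder disagree on nbits");
        out.push(sym);
        state = base | value;
    }
    out
}
```

Encoding b, a, c, b (1, 0, 2, 1) starts the decoder in state 4 with groups (1, 1), (0, 2), (6, 3), (0, 0).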
Other Bits and Pieces
The previous sections describe the core algorithms for encoding and decoding data using tANS. For a complete implementation, the following are also needed:
- Functions to read the input and write the output.
- A way for the decoder to build the decode table. For example, in some cases, we might specify the decode table ahead of time. In others, we might transmit symbol counts before the encoded data, and use those counts to build the decode table.
- A way to set the correct initial state for the decoder (the state the encoder finishes in). For example, the data to be decoded could start with enough bits to encode that state.
- The decoder must also know how many symbols to decode. For an example of why this is necessary, consider a situation where the decoder has processed all input, but finds itself in a state from which there is a 0-bit transition. Should the decoder take that transition and decode an additional symbol or not?
- In the encoder, we don't need to output any bits for the first symbol we encode (these would be bits that come after the last symbol the decoder decodes). We may want to add a special case for this.
The example code contains the support code necessary to produce a working decoder and encoder for tANS.
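For building the decode table from transmitted symbol counts, one possible approach (an assumption on our part; the article's example code may do it differently) is to have the decoder rerun the encoder's construction: for every symbol and successor state, find the origin state and record its symbol, nbits, and base. The hypothetical build_decode_table below does this; base is the successor state with its low nbits bits cleared, which relies on the bases being multiples of 2**nbits, as they are in this construction.

```rust
fn compute_coded_nbits(freq: u32, sbits: u32) -> u32 {
    if freq == 0 {
        return 0;
    }
    let low_nbits = sbits - (32 - (freq - 1).leading_zeros());
    let covered = freq << low_nbits;
    let threshold = covered + covered - (1 << sbits);
    ((low_nbits + 1) << 24) - threshold
}

fn compute_stride(nstates: u32) -> u32 {
    if nstates <= 8 { 5 } else { (nstates >> 1) + (nstates >> 3) + 3 }
}

/// Hypothetical decoder-side table builder: returns (symbol, nbits, base)
/// for every state, given the symbol counts and sbits.
fn build_decode_table(freqs: &[u32], sbits: u32) -> Vec<(usize, u32, u32)> {
    let nstates = 1u32 << sbits;
    let mask = nstates - 1;
    // Origin table, filled the same way as in Encoder::new.
    let stride = compute_stride(nstates);
    let mut origin = Vec::with_capacity(nstates as usize);
    let mut s = stride & mask;
    for &f in freqs {
        for _ in 0..f {
            origin.push(s);
            s = (s + stride) & mask;
        }
    }
    let mut table = vec![(0, 0, 0); nstates as usize];
    let mut o = 0u32; // start of the current symbol's range in origin
    for (sym, &f) in freqs.iter().enumerate() {
        if f == 0 {
            continue; // no states code for this symbol
        }
        let coded_nbits = compute_coded_nbits(f, sbits);
        let offset = o.wrapping_sub(f) & mask;
        o += f;
        for succ in 0..nstates {
            // Same computation as encode_sym: which origin state reaches succ?
            let nbits = (succ + coded_nbits) >> 24;
            let idx = (((succ + nstates) >> nbits) + offset) & mask;
            // That state reads nbits bits; its base is succ with them cleared.
            table[origin[idx as usize] as usize] = (sym, nbits, succ >> nbits << nbits);
        }
    }
    table
}
```

With counts (2, 5, 1) and 8 states, it reproduces the example decode table from the overview (symbols numbered a = 0, b = 1, c = 2).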