Most Large Language Models are gargantuan. BigScience's BLOOM model has 176 billion parameters. Try fitting that on a single machine. 176 billion is 176 followed by 9 zeroes. Stored as 32-bit floating point numbers (4 bytes each), that translates to 704 gigabytes of memory. And those are just the parameters (weights and biases) of the model. Add the memory needed for the computations at each layer and the number climbs much higher. Discouraged? Fear not, quantization to the rescue.
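The memory cost of the weights is simply the parameter count multiplied by the bytes per parameter, so shrinking the data type shrinks the model proportionally. Here is a minimal sketch that counts weights only (no activations, optimizer state or per-block quantization scales) and uses 1 GB = 10^9 bytes, matching the 704 GB figure above:

```python
# Bytes needed to store one parameter in a few common formats
BYTES_PER_PARAM = {"fp32": 4, "fp16/bf16": 2, "int8": 1, "4-bit": 0.5}


def weight_memory_gb(num_params: int, bytes_per_param: float) -> float:
    """Memory for the weights alone, in decimal gigabytes."""
    return num_params * bytes_per_param / 1e9


for fmt, nbytes in BYTES_PER_PARAM.items():
    print(f"{fmt:>10}: {weight_memory_gb(176_000_000_000, nbytes):,.0f} GB")
```

Going from FP32 to 4 bits divides the weight memory by eight, which is what makes the rest of this article worth the trouble.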
The parameters of an LLM are represented as floating-point numbers. A floating-point number is made up of a sign, a significand (mantissa), a base and an exponent. The sign indicates whether the number is positive or negative. The significand holds the significant digits of the number, and the exponent determines its magnitude (the position of the binary point). Roughly speaking, the exponent sets the range while the mantissa sets the precision. A standard 'float' variable takes 32 bits of memory; a 'double' takes 64 bits. The standard FP32 format uses 1 bit for the sign, 8 bits for the exponent and the remaining 23 bits for the mantissa. This gives a range of ~1.18e-38 … ~3.40e38 (for positive normal values) with 6–9 significant decimal digits of precision. Do we need all this precision? Or this much range?
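To see the three FP32 fields in practice, Python's standard `struct` module can reinterpret a number's single-precision bytes as an unsigned integer, and the bits can then be sliced out with shifts and masks. A minimal sketch; the function names are illustrative, and the rebuild formula only covers normal numbers (not zero, subnormals, infinities or NaN):

```python
import struct


def fp32_fields(x: float) -> tuple[int, int, int]:
    """Split x, stored as IEEE 754 single precision, into (sign, exponent, mantissa)."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))  # the raw 32 bits as an int
    sign = bits >> 31                 # 1 bit
    exponent = (bits >> 23) & 0xFF    # 8 bits, stored with a bias of 127
    mantissa = bits & 0x7FFFFF        # 23 bits, the fraction after the implicit leading 1
    return sign, exponent, mantissa


def fp32_value(sign: int, exponent: int, mantissa: int) -> float:
    """Rebuild a normal FP32 value: (-1)^sign * 1.mantissa * 2^(exponent - 127)."""
    return (-1) ** sign * (1 + mantissa / 2**23) * 2.0 ** (exponent - 127)


s, e, m = fp32_fields(0.1)
print(s, e, m, fp32_value(s, e, m))
```

For 0.1 this yields sign 0, stored exponent 123 (2^-4 once the bias of 127 is removed) and mantissa 5033165. Rebuilding the value gives 0.10000000149011612: the closest FP32 number to 0.1, not 0.1 itself.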
The crucial question to ask ourselves is: what precision and range do I need to represent every possible parameter and calculation within my LLM? Research suggests that most LLM parameters lie between -3.5 and +3.5, with some occasional outliers. (Did that lightbulb in your brain turn on?) If most of the values lie within this small range, we can sacrifice some of the range of our floating-point numbers and use fewer bits. Et voilà, we have arrived at the idea behind quantization: representing values with fewer bits than the original format, at the cost of some range or precision.
There are a lot of different ways to represent floating point numbers. Each has its own advantages & drawbacks. It is essentially a trade-off between range & precision.
The IEEE 754 standard defines FP16, the half-precision floating-point format, with 1 bit for the sign, 5 bits for the exponent and 10 bits for the mantissa (fraction). Another 16-bit format, originally developed by Google, is called "Brain Floating Point Format", or "bfloat16" for short. It has 1 bit for the sign, 8 bits for the exponent and 7 bits for the fraction. Compared to FP16, BF16 has 8 exponent bits instead of 5, which gives it a much larger range (the same as FP32), but only 7 fraction bits instead of 10, which gives it lower precision.
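To make the range-versus-precision trade-off concrete, the largest finite value, the smallest normal value and the machine epsilon of each format can be computed from its exponent and mantissa bit counts alone. A small sketch, assuming IEEE-754-style rules (exponent bias of 2^(e-1) - 1, all-ones exponent reserved for infinity/NaN), which BF16 follows too since it is essentially a truncated FP32:

```python
def float_format_stats(exp_bits: int, mantissa_bits: int) -> tuple[float, float, float]:
    """Largest finite value, smallest normal value and machine epsilon
    of an IEEE-754-style binary format (normal numbers only)."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias          # all-ones exponent is reserved for inf/NaN
    largest = (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exp
    smallest_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -mantissa_bits               # gap between 1.0 and the next representable value
    return largest, smallest_normal, epsilon


for name, e, m in [("FP32", 8, 23), ("FP16", 5, 10), ("BF16", 8, 7)]:
    largest, smallest, eps = float_format_stats(e, m)
    print(f"{name}: max ~{largest:.3g}, min normal ~{smallest:.3g}, epsilon {eps:.3g}")
```

FP16 tops out at 65504 with an epsilon of about 0.001, while BF16 reaches about 3.4e38 like FP32 but with an epsilon of 0.0078, steps eight times coarser than FP16's.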
We can also go another way and simply represent our numbers as integers by scaling the original numbers with a constant. Of course, this will inevitably lead to a loss of precision, but we can live with that. There are two main ways to do this: absmax quantization and zero-point quantization.
Absmax quantization divides each original number by the absolute maximum value in our set of numbers, multiplies it by a scaling factor (127) and rounds the result, mapping the inputs into the range [-127, 127]. To retrieve the original numbers, the quantized integer is divided by the same quantization factor (127 / absmax), with some loss of precision due to rounding.
For example, let's say the maximum absolute value is 3.2. The FP32 number 0.1 would be quantized to round(0.1 × 127 / 3.2) = round(3.97) = 4. If we dequantize it, we get 4 × 3.2 / 127 ≈ 0.1008, which implies a rounding error of about 0.0008.
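The same absmax round trip can be written in a few lines of NumPy. A minimal sketch with a single scale for the whole array; the function names and sample values are illustrative, chosen so that 3.2 is the absolute maximum as in the example above:

```python
import numpy as np


def absmax_quantize(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Map x to int8 in [-127, 127] using one absmax-based scale."""
    scale = 127.0 / np.max(np.abs(x))          # the quantization factor
    q = np.round(x * scale).astype(np.int8)
    return q, scale


def absmax_dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Undo the scaling; the rounding error cannot be recovered."""
    return q.astype(np.float32) / scale


x = np.array([0.1, -1.5, 3.2, -0.75], dtype=np.float32)
q, scale = absmax_quantize(x)
x_hat = absmax_dequantize(q, scale)
print(q)          # integers in [-127, 127]
print(x_hat - x)  # per-element rounding error
```

The largest value always maps exactly to ±127, while 0.1 comes back as roughly 0.1008, matching the hand calculation.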
For zero-point quantization, the input values are first multiplied by the number of available integer steps (255) and divided by the difference between the maximum and minimum values. The resulting distribution is then shifted by the zero-point to map it into the range [-128, 127]. First, we calculate the scale factor and the zero-point value:
scale = 255 / (max(x) - min(x))
zero_point = -round(scale * min(x)) - 128
Next, we can use these values to quantize and dequantize our weights (the clip guards against rounding pushing a value to 128 or -129):
quantized = clip(round(scale * x + zero_point), -128, 127)
dequantized = (quantized - zero_point) / scale
Essentially, we first squeeze or stretch all our numbers so that they span a range of 255, then shift them by the zero-point so they fit within [-128, 127].
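The zero-point formulas translate directly to NumPy. A minimal sketch with illustrative function names; the `np.clip` call is there because rounding can otherwise produce 128 when `scale * min(x)` lands exactly on a .5 tie:

```python
import numpy as np


def zeropoint_quantize(x: np.ndarray):
    """Asymmetric (zero-point) quantization of x to int8 in [-128, 127]."""
    scale = 255.0 / (np.max(x) - np.min(x))          # stretch/squeeze to a span of 255
    zero_point = -np.round(scale * np.min(x)) - 128  # shift so min(x) lands on -128
    q = np.clip(np.round(scale * x + zero_point), -128, 127).astype(np.int8)
    return q, scale, zero_point


def zeropoint_dequantize(q: np.ndarray, scale, zero_point) -> np.ndarray:
    """Shift back and undo the scaling."""
    return (q.astype(np.float32) - zero_point) / scale


x = np.array([-1.0, 0.0, 1.0, 2.0], dtype=np.float32)
q, scale, zp = zeropoint_quantize(x)
print(q, scale, zp)
print(zeropoint_dequantize(q, scale, zp))
```

A nice property of this design: because the zero-point is an integer, the value 0.0 always maps to exactly `zero_point` and dequantizes back to exactly 0.0. Unlike absmax, the full [-128, 127] range is used even when the data is skewed to one side of zero.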
Both methods above can represent floats with 8-bit integers. Unfortunately, performing calculations with these rounded-off integers introduces even more rounding errors. The errors propagate across layers, and we eventually end up with a rather simplified, 'rounded-off' output that can be very different from the actual output. To prevent this, a common approach is to store the parameters as low-precision integers but perform the error-sensitive calculations in floating point. The bitsandbytes library, which integrates with Hugging Face Transformers, offers this functionality.
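To illustrate the "store in low precision, compute in floating point" idea, here is a minimal NumPy sketch of weight-only int8 quantization. The weights are kept as int8 with one absmax scale per output row (a choice made for this example) and are dequantized to FP32 right before the matrix multiplication. The function names are hypothetical, and this is not how bitsandbytes implements it internally:

```python
import numpy as np


def quantize_weights_int8(w: np.ndarray):
    """Store a (out, in) weight matrix as int8 plus one FP32 absmax per output row."""
    absmax = np.max(np.abs(w), axis=1, keepdims=True)
    q = np.round(w * 127.0 / absmax).astype(np.int8)
    return q, absmax.astype(np.float32)


def linear_weight_only(x: np.ndarray, q: np.ndarray, absmax: np.ndarray) -> np.ndarray:
    """Dequantize to FP32 just before the matmul, so the arithmetic stays in floating point."""
    w = q.astype(np.float32) * absmax / 127.0
    return w @ x


W = np.array([[0.5, -1.2], [2.0, 0.3]], dtype=np.float32)
x = np.array([1.0, 2.0], dtype=np.float32)
q, absmax = quantize_weights_int8(W)
print(linear_weight_only(x, q, absmax), W @ x)
```

The int8 copy takes a quarter of the FP32 memory, and since the multiplication itself runs in FP32, the main error left is the one-off rounding of each weight rather than rounding at every arithmetic step.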
bitsandbytes implements LLM.int8(), a mixed-precision matrix multiplication method. It multiplies the outlier features (values whose magnitude exceeds a threshold) in FP16 and the non-outliers in INT8 with vector-wise quantization. The outlier threshold defaults to 6 but can be set manually. The non-outlier results are then dequantized (INT8 to FP16) and added to the outlier results to obtain the full result in FP16. You can even load supported models directly in 8-bit mode: we just need to specify load_in_8bit=True when loading the model. Unfortunately, this has traditionally required a GPU, since the bitsandbytes kernels were written for CUDA.
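The decomposition can be sketched in NumPy to show the data flow. This is a simplified illustration, not the bitsandbytes kernel, and the function names are hypothetical. In this sketch a feature dimension (a column of the input X) counts as an outlier if any of its values reaches the threshold. The remaining rows of X and columns of W each get their own absmax scale (vector-wise quantization), are multiplied in integer arithmetic, and are dequantized and added to the FP16 outlier product. It assumes at least one non-outlier column exists:

```python
import numpy as np


def absmax_vectorwise(a: np.ndarray, axis: int):
    """int8 quantization with one absmax scale per row (axis=1) or per column (axis=0)."""
    absmax = np.max(np.abs(a), axis=axis, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax)        # avoid dividing by zero
    q = np.round(a * 127.0 / absmax).astype(np.int8)
    return q, absmax


def int8_matmul_with_outliers(X: np.ndarray, W: np.ndarray, threshold: float = 6.0) -> np.ndarray:
    """Approximate X @ W: outlier feature dimensions in FP16, the rest in INT8."""
    outlier_cols = np.any(np.abs(X) >= threshold, axis=0)
    regular = ~outlier_cols

    # Outlier part: the few large-magnitude features are multiplied in FP16
    out_fp16 = X[:, outlier_cols].astype(np.float16) @ W[outlier_cols, :].astype(np.float16)

    # Regular part: row-wise scales for X, column-wise scales for W
    Xq, sx = absmax_vectorwise(X[:, regular], axis=1)  # sx has shape (n, 1)
    Wq, sw = absmax_vectorwise(W[regular, :], axis=0)  # sw has shape (1, m)
    acc = Xq.astype(np.int32) @ Wq.astype(np.int32)    # int8 products accumulated in int32
    regular_out = acc * (sx * sw) / (127.0 * 127.0)    # dequantize with the outer product of scales

    return regular_out.astype(np.float16) + out_fp16


X = np.array([[1.0, 0.5, 8.0], [-0.3, 2.0, -7.5]], dtype=np.float32)
W = np.array([[0.2, -0.4], [0.7, 0.1], [-0.5, 0.3]], dtype=np.float32)
print(int8_matmul_with_outliers(X, W))
print(X @ W)
```

Here the third feature (8.0 and -7.5) is routed through FP16, so the large values never have to share an int8 scale with the small ones and cannot crush them to zero.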
We can go one step further and quantize our weights to 4 bits. The QLoRA paper introduces a data type for this called 4-bit NormalFloat, or NF4 for short. The weights are split into blocks, and each block is normalized by its absolute maximum value so that the weights fit within the quantization range [-1, 1]. The actual quantization then maps each normalized weight to the nearest of 16 predefined float values. You can go through the QLoRA paper to understand why these particular values are used. bitsandbytes supports 4-bit quantization as well: we just need to specify load_in_4bit=True when loading the model (and set bnb_4bit_quant_type="nf4" in the BitsAndBytesConfig to select NF4 instead of the default FP4).
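The block-wise procedure can be sketched in NumPy. Two loud assumptions: the 16-value codebook below is an approximation built from standard-normal quantiles (8 positive values, 7 negative values and an exact zero, rescaled to [-1, 1]) in the spirit of the QLoRA construction, and the offset 0.9677 is my reading of the bitsandbytes code. Treat the resulting numbers as close to, not guaranteed identical to, the library's NF4 table. The block size of 64 is also an assumption, and the 4-bit indices are stored one per uint8 for clarity, whereas a real implementation packs two per byte:

```python
import numpy as np
from statistics import NormalDist


def approx_nf4_codebook(offset: float = 0.9677) -> np.ndarray:
    """Approximate NF4-style codebook: normal quantiles, asymmetric so 0 is exact."""
    nd = NormalDist()
    pos = [nd.inv_cdf(p) for p in np.linspace(offset, 0.5, 9)[:-1]]   # 8 positive values
    neg = [-nd.inv_cdf(p) for p in np.linspace(offset, 0.5, 8)[:-1]]  # 7 negative values
    values = np.array(sorted(pos + neg + [0.0]))
    return values / np.max(np.abs(values))                            # rescale to [-1, 1]


def quantize_blockwise_4bit(w: np.ndarray, codebook: np.ndarray, block_size: int = 64):
    """Normalize each block by its absmax, then store the index of the nearest codebook value."""
    flat = w.reshape(-1)
    flat = np.pad(flat, (0, (-flat.size) % block_size))                # pad to whole blocks
    blocks = flat.reshape(-1, block_size)
    absmax = np.max(np.abs(blocks), axis=1, keepdims=True)
    absmax[absmax == 0] = 1.0
    normalized = blocks / absmax                                        # each block now in [-1, 1]
    idx = np.argmin(np.abs(normalized[..., None] - codebook), axis=-1).astype(np.uint8)
    return idx, absmax


def dequantize_blockwise_4bit(idx, absmax, codebook, shape):
    """Look up the codebook values and undo the per-block normalization."""
    flat = (codebook[idx] * absmax).reshape(-1)
    return flat[: int(np.prod(shape))].reshape(shape)


codebook = approx_nf4_codebook()
w = np.random.default_rng(0).normal(0.0, 0.02, size=(10, 100)).astype(np.float32)
idx, absmax = quantize_blockwise_4bit(w, codebook)
w_hat = dequantize_blockwise_4bit(idx, absmax, codebook, w.shape)
print(np.max(np.abs(w_hat - w)))
```

Because the codebook values come from normal quantiles, they sit closer together near zero, where normally distributed weights are most common, and further apart in the tails.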
In addition to quantizing the weights, we can also quantize the optimizer state during training. Take the Adam optimizer, for instance, when fine-tuning a model: for every parameter it keeps running averages of the gradient and of the squared gradient (the first and second moments), which it uses to adapt each parameter's step size. Storing these statistics also takes up a lot of memory, so why not quantize them as well? bitsandbytes provides an 8-bit Adam optimizer (bnb.optim.Adam8bit) for this purpose.
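How much this saves follows from simple arithmetic: Adam keeps two values per parameter, so its state costs 2 × parameters × bytes per value. A minimal sketch using the 176-billion-parameter figure from the start of the article; it ignores the small overhead of the per-block scales that 8-bit optimizers also store:

```python
def adam_state_gb(num_params: int, bytes_per_state: float) -> float:
    """Memory for Adam's two per-parameter states (first and second moment), in decimal GB."""
    return 2 * num_params * bytes_per_state / 1e9


n = 176_000_000_000
print(f"32-bit states: {adam_state_gb(n, 4):,.0f} GB")
print(f" 8-bit states: {adam_state_gb(n, 1):,.0f} GB")
```

With 32-bit states the optimizer alone would need about 1408 GB, twice the FP32 weights; 8-bit states cut that to about 352 GB. In PyTorch code, bitsandbytes' `bnb.optim.Adam8bit` is intended as a drop-in replacement for `torch.optim.Adam`.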