GPT-NeoX-20B is a 20-billion-parameter LLM. Stored as 32-bit floating point numbers (4 bytes each), 20 billion parameters take up 80 gigabytes. That is before counting the memory needed for intermediate activations, gradients and optimizer states. Many large compute-oriented machines top out at 64 GB of RAM, and a mid-range GPU has far less. So can we load and fine-tune this model on a mid-range GPU at all? Quantization to the rescue.
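To make that arithmetic concrete, here is a quick weight-only estimate at a few precisions. It is a rough sketch that treats the model as exactly 20 billion parameters and 1 GB as 10^9 bytes, and it ignores activations, gradients and optimizer state, which all come on top.

```python
# Weight-only memory estimate; activations, gradients and optimizer state come on top.
NUM_PARAMS = 20e9  # simplification: GPT-NeoX-20B treated as exactly 20 billion parameters

BYTES_PER_PARAM = {"fp32": 4.0, "bf16/fp16": 2.0, "int8": 1.0, "4-bit": 0.5}


def weight_memory_gb(num_params, bytes_per_param):
    """Storage for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


footprints = {name: weight_memory_gb(NUM_PARAMS, b) for name, b in BYTES_PER_PARAM.items()}
for name, gb in footprints.items():
    print(f"{name:>9}: {gb:5.1f} GB")
```

Going from 32-bit to 4-bit storage shrinks the weights by a factor of 8, from 80 GB to roughly 10 GB.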
We will load our model in 4-bit quantized mode and set the quantization data type to NF4 (explained in my earlier post). We also enable double quantization, so that the quantization constants are themselves quantized, and set the computation data type to BF16 instead of the default 32-bit float. Finally, we enable gradient checkpointing: instead of keeping every intermediate activation from the forward pass in memory, only some are stored and the rest are recomputed when the backward pass needs them. This saves memory at the cost of a slower backward pass.
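The toy sketch below shows what a "quantization constant" is and why double quantization saves memory. It is a simplified stand-in, not how bitsandbytes works: it uses a uniform signed 4-bit grid instead of NF4's normal-quantile levels, and it assumes the block sizes reported in the QLoRA paper (64 weights share one first-level constant; with double quantization those constants are stored in 8 bits, with one 32-bit constant per 256 of them).

```python
import random

# Toy illustration only: uniform signed 4-bit codes (-7..7) with absmax scaling per block.
# Real NF4 places its 16 levels at normal-distribution quantiles instead.
BLOCK_SIZE = 64  # weights sharing one quantization constant (assumed, per the QLoRA paper)
LEVELS = 7       # codes run from -LEVELS to +LEVELS


def quantize(weights, block_size=BLOCK_SIZE):
    """Split weights into blocks; keep 4-bit codes plus one absmax constant per block."""
    blocks = []
    for i in range(0, len(weights), block_size):
        block = weights[i:i + block_size]
        scale = max(abs(w) for w in block) or 1.0  # this block's quantization constant
        codes = [round(w / scale * LEVELS) for w in block]
        blocks.append((codes, scale))
    return blocks


def dequantize(blocks):
    """Map codes back to approximate weights using each block's constant."""
    return [c / LEVELS * scale for codes, scale in blocks for c in codes]


def overhead_bits_per_param(block_size=64, double_quant=False, second_block_size=256):
    """Extra bits per weight spent on storing quantization constants."""
    if not double_quant:
        return 32 / block_size  # one FP32 constant per block
    # Constants stored as 8-bit values, plus one FP32 constant per group of them.
    return 8 / block_size + 32 / (block_size * second_block_size)


random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(256)]
restored = dequantize(quantize(weights))
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(f"max reconstruction error: {max_err:.5f}")
print(overhead_bits_per_param(), overhead_bits_per_param(double_quant=True))
```

Under these assumptions, double quantization cuts the constant overhead from 0.5 to about 0.127 bits per weight, which adds up to hundreds of megabytes on a 20-billion-parameter model.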
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
model_id = "EleutherAI/gpt-neox-20b"
# 4-bit NF4 weights, double-quantized constants, BF16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")
# Recompute activations during the backward pass instead of storing them all
model.gradient_checkpointing_enable()
# Prepare the quantized model for training with PEFT
model = prepare_model_for_kbit_training(model)
Parameter-efficient fine-tuning essentially involves freezing most of the model's parameters and either training a small subset of them or introducing a small number of new parameters to train. We will use QLoRA for our fine-tuning. We have already handled the 'Q' part of QLoRA by quantizing the model above. LoRA (Low-Rank Adaptation) is a relatively new fine-tuning method. Training adjusts a weight matrix W into a final weight W'. LoRA's premise is that the difference (W' - W) has very low rank, so it can be written as the product of two small matrices, one of size d x r and one of size r x d, where the rank r is much smaller than d. Instead of dealing with d x d parameters, we then only have to train 2 x d x r. Let's initialize our Low-Rank Adapters (LoRA) config.
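A small pure-Python sketch of the decomposition: the d x d update is formed as the product of a d x r matrix B and an r x d matrix A, so only 2 x d x r numbers are ever stored and trained. The sizes are arbitrary toy values, the names A and B follow the LoRA paper's convention rather than any PEFT API, and the last line assumes d = 6144 (GPT-NeoX-20B's hidden size) with r = 8 to show the savings at a realistic width.

```python
import random

random.seed(0)


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def random_matrix(rows, cols):
    """A rows x cols matrix of uniform random values in [-1, 1]."""
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


d, r = 16, 2                 # toy sizes; real weight matrices are thousands wide
B = random_matrix(d, r)      # d x r
A = random_matrix(r, d)      # r x d
delta_W = matmul(B, A)       # a full d x d update whose rank is at most r

full_params = d * d          # training the update directly
lora_params = d * r + r * d  # training only B and A
print(f"d x d update: {full_params} params; LoRA factors: {lora_params} params")
# → d x d update: 256 params; LoRA factors: 64 params

# Same comparison at an assumed width of d = 6144 with r = 8
print(6144 * 6144, 2 * 6144 * 8)  # 37748736 98304
```

Every row of delta_W is a weighted mix of the r rows of A, which is exactly why the update cannot have rank higher than r.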
from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
We can check the number of parameters that will be trained with model.print_trainable_parameters(): 8,650,752 out of a total of 10,597,552,128. That is approximately 0.0816% of the total parameters.
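The trainable count can be reproduced by hand. This sketch assumes GPT-NeoX-20B's published configuration (44 transformer layers, hidden size 6144) and that each layer's fused query_key_value projection maps 6144 inputs to 3 x 6144 outputs; with r = 8, each adapter pair adds an r x 6144 matrix and an 18432 x r matrix. The reported total of about 10.6 billion is well below 20 billion, most likely because the 4-bit weights are packed two per byte and counted in that packed form, so the percentage is relative to the reported total.

```python
# Reproduce the trainable-parameter count for r=8 LoRA on "query_key_value".
# Assumed GPT-NeoX-20B dimensions: 44 layers, hidden size 6144,
# fused query_key_value projection mapping 6144 -> 3 * 6144.
NUM_LAYERS = 44
HIDDEN = 6144
QKV_OUT = 3 * HIDDEN
R = 8

lora_a = R * HIDDEN   # A: r x in_features
lora_b = QKV_OUT * R  # B: out_features x r
per_layer = lora_a + lora_b
trainable = NUM_LAYERS * per_layer
print(per_layer, trainable)  # 196608 8650752

reported_total = 10_597_552_128  # total printed in this run
print(f"{100 * trainable / reported_total:.4f} %")  # 0.0816 %
```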
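Let us first understand the LoraConfig object. LoRA comes with a few hyperparameters. The rank of the two decomposed adapter matrices is set with r. It is usually set to 4 or 8, but can be as low as 2. lora_alpha is a scaling factor: the adapter's update is scaled by lora_alpha/r before being added to the output of the frozen weight matrix (here 32/8 = 4). lora_dropout sets the dropout probability applied to the adapter's input during training. target_modules selects which layers receive adapters; query_key_value is GPT-NeoX's fused attention projection. bias="none" keeps the bias terms frozen, and task_type tells PEFT that this is a causal language model. We can now proceed with training our LLM on a toy dataset.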
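To see where lora_alpha, r and lora_dropout act, here is a pure-Python sketch of one LoRA-adapted linear layer: h = W x + (alpha / r) * B(A(dropout(x))), with W frozen and only A and B trained. The helper names are hypothetical and the toy sizes are arbitrary; the zero initialization of B mirrors the common LoRA default, so the adapted layer starts out identical to the frozen one.

```python
import random


def matvec(m, x):
    """Matrix (list of rows) times vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def lora_forward(W, A, B, x, alpha, r, dropout_p=0.0, training=False, rng=random):
    """h = W x + (alpha / r) * B (A dropout(x)); W is frozen, only A and B train."""
    if training and dropout_p > 0:
        keep = 1.0 - dropout_p
        # Inverted dropout: zero some inputs, rescale the survivors
        x_drop = [xi / keep if rng.random() < keep else 0.0 for xi in x]
    else:
        x_drop = x
    base = matvec(W, x)                     # frozen path sees the undropped input
    update = matvec(B, matvec(A, x_drop))   # low-rank adapter path
    scale = alpha / r
    return [h + scale * u for h, u in zip(base, update)]


random.seed(0)
d, r, alpha = 4, 2, 32
W = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(d)]
A = [[random.gauss(0, 0.1) for _ in range(d)] for _ in range(r)]
B = [[0.0] * r for _ in range(d)]  # B starts at zero
x = [1.0, -0.5, 0.25, 2.0]

# With B = 0 the adapted layer reproduces the frozen layer exactly
print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True
```

Because the update is multiplied by alpha / r, changing r while keeping alpha fixed changes how strongly the adapter contributes, which is why the two are usually tuned together.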
import time
import transformers
from datasets import load_dataset
# Load the English quotes dataset and tokenize the "quote" column
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
# needed for gpt-neo-x tokenizer: it has no pad token, so reuse EOS
tokenizer.pad_token = tokenizer.eos_token
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        report_to="none",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        save_strategy="best",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
start = time.time()
trainer.train()
end = time.time()
print(f"Training took {(end - start) / 60:.1f} minutes")
We first load a small dataset of about 2,500 English quotes and tokenize it. We need to set a pad token so that sequences can be padded to the same length during training: the sequences in a batch must share one length before they can be stacked into a single tensor and processed in parallel. We pad each batch to the length of its longest sequence, using the end-of-sequence (EOS) token as padding. The Trainer is an especially convenient object for defining all the required training parameters. It is basically the complete training loop rolled into one object: it calculates the loss, performs backpropagation, updates the weights using the optimizer, and repeats until the stopping condition (here max_steps=10) is reached.
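A simplified picture of what DataCollatorForLanguageModeling(tokenizer, mlm=False) hands the model for each batch. The helper name collate_causal_lm is hypothetical and the token ids are illustrative; the sketch assumes right-side padding and plain Python lists instead of tensors. The labels are a copy of the inputs with padded positions set to -100 so the loss ignores them; the model itself shifts labels by one position to predict the next token. Because the pad token is the EOS token here, genuine EOS tokens get masked out as well.

```python
def collate_causal_lm(batch, pad_token_id):
    """Pad token-id lists to the longest in the batch and build causal-LM labels."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask, labels = [], [], []
    for seq in batch:
        pad = max_len - len(seq)
        ids = seq + [pad_token_id] * pad  # right-side padding
        input_ids.append(ids)
        attention_mask.append([1] * len(seq) + [0] * pad)  # 0 = ignore this position
        # Labels copy the inputs; pad-token positions become -100 so the loss skips them
        labels.append([t if t != pad_token_id else -100 for t in ids])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


EOS = 99  # illustrative id, also used as the pad token
out = collate_causal_lm([[101, 7, 8, 9], [101, 5]], pad_token_id=EOS)
print(out["input_ids"])  # [[101, 7, 8, 9], [101, 5, 99, 99]]
print(out["labels"])     # [[101, 7, 8, 9], [101, 5, -100, -100]]
```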
It should be clear from the code (fp16=True) that we are using 16-bit mixed-precision floats for the training calculations instead of the default 32-bit. We are also using the paged 8-bit AdamW optimizer (optim="paged_adamw_8bit"). It stores the optimizer states in 8 bits: for each update it dequantizes them to 32 bits, performs the update, and then quantizes the states back to 8 bits for storage. Paging means that if we run out of GPU memory, the optimizer state is moved page by page from GPU to CPU memory and brought back when needed. There are various other arguments that can be passed to the Trainer object, have a look here.
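Several of the arguments in the code interact. With per_device_train_batch_size=1 and gradient_accumulation_steps=4, each optimizer update averages gradients over 4 examples, and max_steps=10 caps the run at 10 updates. The sketch below computes these numbers and the learning rate at each step, assuming the Trainer's default "linear" scheduler (linear warmup over warmup_steps, then linear decay to zero by max_steps); the function name is hypothetical.

```python
# How the TrainingArguments above translate into data volume and an LR schedule.
# Assumes the default "linear" scheduler: linear warmup, then linear decay to 0.
PER_DEVICE_BATCH = 1
GRAD_ACCUM = 4
MAX_STEPS = 10
WARMUP_STEPS = 2
PEAK_LR = 2e-4


def linear_warmup_decay(step, warmup, total):
    """LR multiplier at a given (0-based) optimizer step."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))


effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM  # examples per optimizer step
examples_seen = effective_batch * MAX_STEPS
lrs = [PEAK_LR * linear_warmup_decay(s, WARMUP_STEPS, MAX_STEPS) for s in range(MAX_STEPS)]

print(effective_batch, examples_seen)  # 4 40
print([f"{lr:.2e}" for lr in lrs])
```

Gradient accumulation gives the effect of a batch of 4 while only one example's activations sit in GPU memory at a time.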
It took me about 2.5 hours to fine-tune a 20-billion-parameter LLM on some 2,500 data points on a very basic NVIDIA T4 16 GB GPU (for free). Not too bad!