Fine-tuning LLMs

TeeTracker
6 min read, Jul 22, 2023


Tasks to fine-tune

Before Fine-tuning (base or pretrained model)

Fine-tuning

Goal: make a pretrained LLM perform better.

Fine-tuning is a supervised learning process that further trains a pretrained large language model (LLM) using instruction prompts.

Dataset: a prompt dataset made of prompt-completion pairs.
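As a minimal sketch, raw labelled examples can be wrapped in an instruction template to produce prompt-completion pairs. The template wording, the `make_pair` helper, and the two reviews are hypothetical, chosen only for illustration.

```python
# Hypothetical instruction template; real projects define their own wording.
TEMPLATE = "Classify this review:\n{review}\nSentiment:"

def make_pair(review: str, label: str) -> dict:
    """Turn one labelled example into a prompt-completion pair."""
    return {"prompt": TEMPLATE.format(review=review), "completion": " " + label}

# Made-up labelled data (review text, sentiment label).
raw = [("I loved this DVD.", "positive"), ("It broke after a day.", "negative")]
dataset = [make_pair(text, label) for text, label in raw]
```

Each dictionary is one training example: the model sees the prompt and is trained to produce the completion.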

The aim is to improve the performance and adaptability of a pretrained language model by training it on specific tasks with instruction prompts.

Classical training process

A common supervised learning process.

PyTorch-style pseudocode
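The loop below is a plain-Python stand-in for that PyTorch-style loop, so it runs without any framework. The "model" is an assumption for illustration: a tiny table of next-token logits instead of an LLM. The structure is the same supervised recipe: forward pass, cross-entropy loss against the label, gradient, optimizer step.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def train(weights, examples, lr=0.5, epochs=50):
    """Toy supervised training loop.

    weights[prev][nxt] is the logit for token nxt following token prev (a stand-in model).
    examples is a list of (input_token, target_token) labelled pairs.
    Returns the mean loss of each epoch.
    """
    history = []
    for _ in range(epochs):
        epoch_loss = 0.0
        for prev, target in examples:
            probs = softmax(weights[prev])           # forward pass: model prediction
            epoch_loss += -math.log(probs[target])   # cross-entropy against the label
            for j, p in enumerate(probs):            # backward: dL/dlogit_j = p_j - 1[j == target]
                grad = p - (1.0 if j == target else 0.0)
                weights[prev][j] -= lr * grad        # optimizer step (plain SGD)
        history.append(epoch_loss / len(examples))
    return history

vocab_size = 3
weights = [[0.0] * vocab_size for _ in range(vocab_size)]  # all-zero starting "model"
examples = [(0, 1), (1, 2), (2, 0)]                         # toy labelled pairs
history = train(weights, examples)
```

With all-zero logits the first epoch's loss is ln(3), and it falls every epoch as the table learns the labelled pairs. In PyTorch the three inner steps would be handled by the model's forward call, `loss.backward()`, and `optimizer.step()`.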

Instruction fine-tuning: two types of task

Type 1: fine-tuning on a single task

Catastrophic Forgetting (degrades model performance)

Catastrophic forgetting occurs when a machine learning model forgets previously learned information as it learns new information.

This is especially problematic in sequential learning scenarios, where the model is trained on multiple tasks one after another.

Catastrophic forgetting is a common problem in machine learning, especially in deep learning models.

Example

Consider a sentiment-classification task. We fine-tune the model to output a sentiment label instead of full sentences, and it works.

However, the model gets worse at other tasks. For example, when a prompt asks it to identify a name, it answers incorrectly, producing sentiment-style output instead.
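A deliberately tiny illustration of the mechanism, assuming a one-parameter linear model shared by two conflicting toy tasks (task A: y = 2x, task B: y = -x). This is not an LLM experiment; it only shows that updating all shared parameters for a new task can overwrite what the old task needed.

```python
def mse(w, data):
    """Mean squared error of the model y = w * x on a list of (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def fine_tune(w, data, lr=0.05, steps=200):
    """Full-batch gradient descent on the MSE, updating the only (shared) parameter."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

xs = (1.0, 2.0, 3.0)
task_a = [(x, 2.0 * x) for x in xs]   # old task: y = 2x
task_b = [(x, -x) for x in xs]        # new task: y = -x

w = fine_tune(0.0, task_a)            # learn task A
loss_a_before = mse(w, task_a)        # close to 0
w = fine_tune(w, task_b)              # then fine-tune on task B only
loss_a_after = mse(w, task_a)         # task A performance has collapsed
```

After the second fine-tuning run, w ends near -1, and the task A loss jumps from roughly 0 to about 42.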

Full fine-tuning can lead to catastrophic forgetting because it changes all of the model's parameters. Since PEFT (parameter-efficient fine-tuning) updates only a small subset of parameters, it is more robust against this effect.
  • With PEFT, most parameters of the LLM are unchanged, which makes it less prone to catastrophic forgetting.
  • With PEFT, only a small number of parameters change during fine-tuning, so at inference time you can combine the original model with the new parameters instead of duplicating the entire model for every task you fine-tune for.
  • Because most parameters are frozen, in some cases only 15%-20% of the original LLM weights need to be trained, which makes training less expensive (less memory is required).
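A back-of-the-envelope sketch of the storage point in the second bullet. The 7-billion-parameter model, the five tasks, and 2 bytes per parameter (16-bit weights) are illustrative assumptions; the 15% trainable fraction is the figure quoted in the last bullet. The helper `storage_gb` is hypothetical.

```python
def storage_gb(n_params, n_tasks, trainable_fraction, bytes_per_param=2):
    """Return (full fine-tuning GB, PEFT GB) needed to serve n_tasks task variants."""
    model_gb = n_params * bytes_per_param / 1e9
    full = n_tasks * model_gb                                    # one full model copy per task
    peft = model_gb + n_tasks * model_gb * trainable_fraction    # one shared base + small per-task weights
    return full, peft

full, peft = storage_gb(n_params=7e9, n_tasks=5, trainable_fraction=0.15)
```

Under these assumptions, full fine-tuning stores 70 GB of weights, while PEFT stores one 14 GB base plus 10.5 GB of task-specific weights, 24.5 GB in total.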

Some PEFT methods

  • Reparameterization methods create a new low-rank transformation of the original network weights to train, decreasing the trainable parameter count while still working with high-dimensional matrices. LoRA is a common technique in this category.
  • Additive methods freeze all of the original LLM weights and introduce new model components to fine-tune for a specific task.
  • LoRA represents large weight matrices as two smaller rank-decomposition matrices and trains those instead of the full weights. The product of these smaller matrices is then added to the original weights for inference.
Use LoRA to adapt a model to new, different tasks.
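A minimal pure-Python sketch of the LoRA update described above, using lists of rows as matrices; no real model or library is involved. The helper names are hypothetical. Following the LoRA paper, B starts at zero so the merged weight initially equals W, and the optional `scale` stands in for the paper's alpha / r factor. The dimensions 512 x 64 with rank 8 are an illustrative choice.

```python
def matmul(X, Y):
    """Multiply an (m x n) matrix X by an (n x p) matrix Y, both as lists of rows."""
    return [[sum(X[i][t] * Y[t][j] for t in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

def lora_merge(W, B, A, scale=1.0):
    """Return W + scale * (B @ A), where W is d x k, B is d x r, and A is r x k."""
    delta = matmul(B, A)
    return [[W[i][j] + scale * delta[i][j] for j in range(len(W[0]))] for i in range(len(W))]

def trainable_params(d, k, r):
    """(parameters in a full update of W, parameters in the LoRA factors B and A)."""
    return d * k, d * r + r * k

# Toy frozen 4 x 3 weight with rank r = 1 factors.
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
A = [[0.5, -0.2, 0.1]]             # r x k, randomly initialized in practice
B = [[0.0], [0.0], [0.0], [0.0]]   # d x r, zero-initialized so training starts from W
merged_at_init = lora_merge(W, B, A)

B_trained = [[1.0], [0.0], [0.0], [0.0]]   # pretend training changed B
merged_after = lora_merge(W, B_trained, A)

full_count, lora_count = trainable_params(d=512, k=64, r=8)
```

Only A and B are trained: 4,608 parameters instead of 32,768 for this 512 x 64 matrix, roughly an 86% reduction. After training, B @ A can be added into W once, so inference costs no extra matrix multiplications.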
  • A soft prompt refers to a set of trainable tokens that are added to a prompt. Unlike the tokens that represent language, these tokens can take on any value within the embedding space. The token values may not be interpretable by humans but are located in the embedding space close to words related to the language prompt or task to be completed.
  • Prompt tuning optimizes the prompt given to the model using trainable tokens that don't correspond directly to human language. The number of soft-prompt tokens to train is a hyperparameter of the training process.
  • By training a smaller number of parameters, whether through selecting a subset of model layers to train, adding new, small components to the model architecture, or through the inclusion of soft prompts, the amount of memory needed for training is reduced compared to full fine-tuning.
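A toy sketch of the soft-prompt idea, assuming a hypothetical three-word vocabulary and a 4-dimensional embedding space. The base model's embedding table stays frozen. A few free vectors are prepended to the embedded prompt, and only those vectors would receive gradient updates during prompt tuning.

```python
import random

EMBED_DIM = 4
VOCAB = {"classify": 0, "this": 1, "review": 2}   # hypothetical toy vocabulary

# Frozen embedding table of the (toy) base model: one vector per vocabulary token.
embedding_table = [[random.gauss(0.0, 1.0) for _ in range(EMBED_DIM)] for _ in VOCAB]

def init_soft_prompt(n_virtual_tokens, dim):
    """Trainable virtual-token vectors; they can take any value in the embedding space."""
    return [[random.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(n_virtual_tokens)]

def build_inputs(soft_prompt, token_ids):
    """Prepend the soft-prompt vectors to the frozen embeddings of the real tokens."""
    return soft_prompt + [embedding_table[t] for t in token_ids]

soft_prompt = init_soft_prompt(n_virtual_tokens=5, dim=EMBED_DIM)  # the only trainable parameters
token_ids = [VOCAB[w] for w in "classify this review".split()]
inputs = build_inputs(soft_prompt, token_ids)
```

Here `n_virtual_tokens` is the hyperparameter mentioned above, and the trainable parameter count is n_virtual_tokens × embedding size (20 in this toy), independent of the model's size.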

Another way to mitigate catastrophic forgetting is by using regularization techniques to limit the amount of change that can be made to the weights of the model during training. This can help to preserve the information learned during earlier training phases and prevent overfitting to the new data.
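One simple form of such regularization is sketched below, assuming a one-parameter toy model: add a penalty lam * (w - w0)^2 that pulls the weight back toward its value before fine-tuning. This is only an illustrative L2-style penalty. The names, the toy task, and lam = 5.0 are assumptions, not a specific published method.

```python
def fine_tune(w0, data, lam, lr=0.05, steps=200):
    """Full-batch gradient descent on MSE(new task) + lam * (w - w0)**2."""
    w = w0
    n = len(data)
    for _ in range(steps):
        grad_mse = sum(2 * (w * x - y) * x for x, y in data) / n
        grad_reg = 2 * lam * (w - w0)      # pulls w back toward the starting weight
        w -= lr * (grad_mse + grad_reg)
    return w

task_b = [(x, -x) for x in (1.0, 2.0, 3.0)]   # new task: y = -x
w_start = 2.0                                 # weight learned on an earlier task (y = 2x)

w_plain = fine_tune(w_start, task_b, lam=0.0)  # unregularized: drifts all the way to -1
w_reg = fine_tune(w_start, task_b, lam=5.0)    # regularized: settles at 16/29, about 0.55
```

The penalty trades some accuracy on the new task for staying closer to the earlier weights. A larger lam keeps w closer to 2.0.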

Type 2: multi-task fine-tuning

To avoid catastrophic forgetting, fine-tune on multiple tasks at the same time. This requires a lot of training data.
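A minimal sketch of building a multi-task training set, assuming hypothetical task names and toy (prompt, completion) examples. Examples from all tasks are pooled and shuffled, so each batch mixes tasks instead of the model seeing one task at a time.

```python
import random

def mix_tasks(task_datasets, seed=0):
    """Pool (task, example) pairs from every task and shuffle them together."""
    mixed = [(task, ex) for task, examples in task_datasets.items() for ex in examples]
    random.Random(seed).shuffle(mixed)
    return mixed

def batches(items, batch_size):
    """Yield consecutive fixed-size batches (the last one may be smaller)."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

# Hypothetical toy data; a real multi-task dataset is far larger.
tasks = {
    "sentiment": [("Review: great film. Sentiment:", " positive"),
                  ("Review: dull plot. Sentiment:", " negative")],
    "ner": [("Find the name: Ana adopted a cat. Name:", " Ana")],
    "translate": [("Translate to French: cat", " chat")],
}
mixed = mix_tasks(tasks)
first_batch = next(batches(mixed, batch_size=2))
```

Shuffling across tasks means every gradient step balances several objectives, which helps keep earlier abilities from being overwritten.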

Prompting or Fine-tuning

Evaluation

Common Framework

ROUGE-1 performs simple metric calculations on unigrams (single words), similar to other machine-learning tasks, using:

  • recall

The recall metric measures the number of words or unigrams that are matched between the reference and the generated output divided by the number of words or unigrams in the reference.

  • precision

Precision measures the number of unigram matches divided by the number of unigrams in the generated output.

  • F-score

The harmonic mean of recall and precision: F1 = 2 × (precision × recall) / (precision + recall).
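A minimal sketch of ROUGE-1 recall, precision, and F-score. It assumes lowercase whitespace tokenization with no stemming or punctuation handling, and it is not the implementation of any particular ROUGE package. Overlap is clipped, so a word counts only as many times as it appears in both texts.

```python
from collections import Counter

def rouge1(reference: str, generated: str) -> dict:
    """ROUGE-1 scores with naive lowercase whitespace tokenization."""
    ref = Counter(reference.lower().split())
    gen = Counter(generated.lower().split())
    matches = sum((ref & gen).values())       # clipped unigram overlap
    recall = matches / sum(ref.values())      # matches / unigrams in the reference
    precision = matches / sum(gen.values())   # matches / unigrams in the output
    f1 = 0.0 if matches == 0 else 2 * precision * recall / (precision + recall)
    return {"recall": recall, "precision": precision, "f1": f1}

scores = rouge1("It is cold outside", "It is very cold outside")
```

For this pair, recall is 1.0, precision is 0.8 (4 of 5 output words match), and F1 is about 0.89. Because only word counts matter, "It is not cold outside" gets exactly the same scores, and a reordering such as "outside cold is it" scores a perfect 1.0.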

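These basic ROUGE-1 metrics focus only on individual words (hence the "1" in the name) and ignore word order, so they can be deceptive. For example, with the reference "It is cold outside", the generated output "It is not cold outside" gets exactly the same scores as "It is very cold outside" (recall 1.0, precision 0.8), even though it means the opposite.

Using bigrams instead of individual words gives ROUGE-2. Matching word pairs is harder than matching single words, so ROUGE-2 scores are usually lower than ROUGE-1 scores for the same output.

ROUGE-L looks for the longest common subsequence (LCS) present in both the generated output and the reference output, and uses its length in place of the matched-word count when computing recall, precision, and F-score.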
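A minimal ROUGE-L sketch, again assuming naive lowercase whitespace tokenization and a plain (unweighted) F1. The LCS length comes from the standard dynamic-programming table. The helper names are hypothetical.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of token lists a and b."""
    # dp[i][j] holds the LCS length of a[:i] and b[:j].
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]

def rouge_l(reference: str, generated: str) -> dict:
    """ROUGE-L recall, precision, and F1 based on LCS length."""
    ref, gen = reference.lower().split(), generated.lower().split()
    lcs = lcs_length(ref, gen)
    recall = lcs / len(ref)
    precision = lcs / len(gen)
    f1 = 0.0 if lcs == 0 else 2 * precision * recall / (precision + recall)
    return {"recall": recall, "precision": precision, "f1": f1}

in_order = rouge_l("It is cold outside", "It is very cold outside")
shuffled = rouge_l("It is cold outside", "cold outside it is")
```

Unlike ROUGE-1, word order now matters. The in-order output keeps an LCS of 4 words (recall 1.0, precision 0.8), while the reordered output only shares a 2-word subsequence, so its recall and precision drop to 0.5.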

Note

  • A unigram is a single word
  • A bigram is two consecutive words
  • An n-gram is a group of n consecutive words
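The definitions above generalize ROUGE-1 to ROUGE-N. A minimal sketch follows, assuming the same naive lowercase whitespace tokenization as a simplification. With n = 2 it computes ROUGE-2.

```python
from collections import Counter

def ngrams(words, n):
    """All contiguous groups of n words, as tuples (n=1 unigrams, n=2 bigrams, ...)."""
    return [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]

def rouge_n(reference: str, generated: str, n: int):
    """ROUGE-N (recall, precision, f1) using clipped n-gram overlap."""
    ref = Counter(ngrams(reference.lower().split(), n))
    gen = Counter(ngrams(generated.lower().split(), n))
    matches = sum((ref & gen).values())
    recall = matches / max(sum(ref.values()), 1)      # guard against texts shorter than n
    precision = matches / max(sum(gen.values()), 1)
    f1 = 0.0 if matches == 0 else 2 * precision * recall / (precision + recall)
    return recall, precision, f1

r2 = rouge_n("It is cold outside", "It is very cold outside", n=2)
```

The reference has 3 bigrams and the output has 4, but only ("it", "is") and ("cold", "outside") match. That gives recall 2/3, precision 0.5, and F1 4/7, all lower than the ROUGE-1 scores for the same pair.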

Model optimizations for deployment
