
SFT for LLMs: A Practical Guide to Supervised Fine-Tuning

Supervised fine-tuning teaches LLMs task behavior before preference tuning and RLHF.

Abstract Algorithms · 12 min read

AI-assisted content.

TLDR: Supervised fine-tuning (SFT) is the stage where a pretrained model learns task-specific response behavior from curated input-output examples. It is usually the first alignment step after pretraining and often the foundation for later RLHF. Good SFT depends more on data quality and format consistency than on exotic training tricks.


📖 Why SFT Is the First Practical Alignment Layer

LLaMA-2 was released as a capable base model, but the base weights alone do not follow instructions. Meta built LLaMA-2-Chat on top of it, starting with SFT on roughly 27,540 curated instruction-response annotations (followed by RLHF). Same base architecture, entirely different behavior. This post shows exactly how SFT works and how to run it.

Pretraining gives a model broad language competence. It does not automatically make the model a useful assistant for your product or domain.

SFT bridges this gap by teaching behavior through examples:

  • follow instructions,
  • keep output format constraints,
  • answer in the right tone,
  • avoid irrelevant verbosity,
  • prioritize domain-specific facts.

You can think of SFT as "behavior shaping with demonstrations." The model sees prompt-response pairs and learns to imitate the desired response distribution.

Training stage | Main objective | Typical data source
Pretraining | Learn broad language patterns | Large unsupervised corpora
SFT | Learn task/assistant behavior | Curated prompt-response pairs
RLHF | Optimize preference and helpfulness/safety trade-offs | Human or model preference data

Without SFT, RLHF usually has weak foundations. You cannot reliably optimize preference signals if base task behavior is still inconsistent.


🔍 Building SFT Data That Actually Improves Behavior

Most SFT failures are data failures.

Core design rules

  • Keep prompt format consistent across the dataset.
  • Make expected outputs unambiguous.
  • Include edge cases, not only happy-path examples.
  • Remove contradictory style instructions.
  • Version your data schema.
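Several of these rules can be checked mechanically before training. Below is a minimal lint pass over chat-format records; `lint_example` and its rule set are an illustrative sketch, not any standard API:

```python
def lint_example(example):
    """Return a list of problems found in one chat-format SFT example."""
    messages = example.get("messages", [])
    if not messages:
        return ["empty messages list"]
    problems = []
    if messages[-1].get("role") != "assistant":
        problems.append("last message is not an assistant response")
    for m in messages:
        if not m.get("content", "").strip():
            problems.append(f"empty content for role {m.get('role')}")
    return problems

def lint_dataset(examples):
    """Map example index -> problems, for every example that fails a rule."""
    return {i: p for i, ex in enumerate(examples) if (p := lint_example(ex))}
```

Running a pass like this on every dataset version catches malformed records before training, which is far cheaper than diagnosing them through loss curves.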

Data issue | Observable symptom | Fix
Inconsistent response style | Model changes tone unpredictably | Standardize answer templates
Weak negatives / no counterexamples | Hallucination on ambiguous prompts | Add hard prompts with strict references
Overly narrow data domain | Model fails outside training niches | Add broad but relevant coverage
No format exemplars in data | Output breaks JSON/markdown contracts | Include exact format exemplars

A compact, high-quality dataset often beats a massive noisy one.

📊 SFT Data-to-Training Pipeline

flowchart TD
    Raw[Raw Prompt-Response Pairs]
    Clean["Deduplicate + Normalize (schema, formatting)"]
    Format["Format as Chat Messages (system / user / assistant)"]
    Split[Train / Val / Holdout Split]
    Tokenize["Tokenize + Mask (loss on assistant only)"]
    Train["Train SFT Model (cross-entropy loss)"]
    Eval["Evaluate on Held-Out Set (task + behavioral metrics)"]
    Done{Quality Gate Pass?}
    Deploy[Publish Model Version]
    Fix[Improve Data Balance and Retrain]

    Raw --> Clean --> Format --> Split --> Tokenize --> Train --> Eval --> Done
    Done -->|Yes| Deploy
    Done -->|No| Fix --> Train

This diagram maps every stage of the SFT data journey, from raw prompt-response pairs through deduplication, format normalization, and tokenization, and into the training loop. The quality gate at the end, together with the feedback arrow back to data improvement, is the most important node: it reinforces that SFT is a data-centric iteration loop, not a one-shot training job. Teams that skip this gate typically discover regressions in production rather than in evaluation.
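The deduplication stage can start as simply as hashing a canonical form of each record. This sketch removes exact duplicates only; near-duplicate detection needs fuzzier matching such as MinHash:

```python
import hashlib
import json

def deduplicate(examples):
    """Drop exact-duplicate examples by hashing a canonical JSON encoding."""
    seen, unique = set(), []
    for ex in examples:
        key = hashlib.sha256(json.dumps(ex, sort_keys=True).encode()).hexdigest()
        if key not in seen:
            seen.add(key)
            unique.append(ex)
    return unique
```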

📊 SFT Training Loop Sequence

sequenceDiagram
    participant D as DataLoader
    participant T as Tokenizer
    participant M as Model (+ LoRA)
    participant L as Loss Function
    participant O as Optimizer

    D->>T: instruction + response batch
    T->>M: Token IDs + attention mask
    T->>M: Labels (assistant tokens only)
    M->>M: Forward pass
    M->>L: Predicted logits
    L->>L: Cross-entropy on assistant span
    L->>O: Backpropagate gradients
    O->>M: Update LoRA A & B weights

This sequence diagram shows exactly which components participate in one SFT training step and in what order. The DataLoader and Tokenizer prepare the batch; the model runs a forward pass; the loss function computes cross-entropy only on assistant-token spans; and the optimizer updates only the LoRA adapters, leaving the frozen base model unchanged. The key takeaway is that label masking (sending assistant-only labels to the loss function) is not optional: without it, the model wastes gradient signal on user and system tokens.


โš™๏ธ What SFT Optimizes in the Model

{
  "messages": [
    {"role": "system", "content": "You are a concise cloud architecture assistant."},
    {"role": "user", "content": "Compare event-driven and request-response systems."},
    {"role": "assistant", "content": "Event-driven systems react to events asynchronously ..."}
  ]
}

If your target deployment expects chat format, train in chat format. SFT should mirror inference conditions as much as possible.


โš™๏ธ What SFT Optimizes in the Model

SFT typically uses next-token prediction loss on labeled assistant responses, conditioned on prior context.

Given target tokens y_1 ... y_T, loss is:

\[ \mathcal{L}_{\mathrm{SFT}} = - \sum_{t=1}^{T} \log P(y_t \mid x, y_{<t}) \]

Where x includes system and user context.

Important practical point

You usually mask loss on user/system tokens and compute loss only on assistant target spans. If you compute loss on everything, the model may waste capacity modeling prompt boilerplate instead of response quality.
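In Hugging Face-style training code this masking is done by setting prompt positions in the labels to -100, the ignore index of PyTorch's cross-entropy loss. A minimal sketch over plain token-ID lists:

```python
IGNORE_INDEX = -100  # positions with this label are skipped by cross-entropy

def mask_labels(input_ids, assistant_start):
    """Copy input_ids into labels, masking every token before the assistant
    span so loss is computed only on the response tokens."""
    return [IGNORE_INDEX] * assistant_start + list(input_ids[assistant_start:])
```

For example, `mask_labels([1, 2, 3, 4, 5], assistant_start=3)` yields `[-100, -100, -100, 4, 5]`: the first three (prompt) positions contribute nothing to the loss.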

Configuration choice | Common option | Why it matters
Label masking | Assistant-only loss | Focuses optimization on response behavior
Sequence packing | Enabled for throughput | Better GPU utilization
Max context length | Task dependent (2k, 4k, 8k+) | Controls truncation risk and memory
Precision | BF16/FP16 | Throughput and stability balance

SFT is simple conceptually, but these operational details strongly affect quality.
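Sequence packing, for instance, can be sketched as a greedy concatenator. This toy version ignores the cross-example attention masking that real trainers (e.g. TRL with packing enabled) add on top:

```python
def pack_sequences(token_lists, max_len):
    """Greedily concatenate short examples into shared rows of at most
    max_len tokens, reducing padding waste and raising GPU utilization.
    An example longer than max_len is kept whole (no truncation here)."""
    packs, current = [], []
    for toks in token_lists:
        if current and len(current) + len(toks) > max_len:
            packs.append(current)
            current = []
        current.extend(toks)
    if current:
        packs.append(current)
    return packs
```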


🧠 Deep Dive: Distribution Shift, Forgetting, and Evaluation

Internals: why catastrophic forgetting happens

If your SFT dataset is narrow, the model can over-specialize and lose general capabilities from pretraining. This is catastrophic forgetting.

You will observe:

  • strong performance on narrow in-domain prompts,
  • degraded general QA or reasoning behavior,
  • brittle outputs when user phrasing changes.

A practical mitigation is mixing:

  • domain-specific instruction data,
  • general instruction-following exemplars,
  • format-control examples.

Mathematical intuition: balancing objectives

You can view SFT as optimizing a weighted objective over data subsets:

\[ \mathcal{L} = \lambda_d \mathcal{L}_{\mathrm{domain}} + \lambda_g \mathcal{L}_{\mathrm{general}} + \lambda_f \mathcal{L}_{\mathrm{format}} \]

If lambda_d dominates too hard, you may get excellent domain style but weaker general competence.
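In practice the lambdas become sampling weights over the data pools. A toy sketch of weight-proportional mixing, where the pool names and weight values are illustrative:

```python
import random

def mixed_samples(domain, general, fmt, weights=(0.6, 0.3, 0.1), n=1000, seed=0):
    """Draw n training examples from three pools with probabilities
    proportional to (lambda_d, lambda_g, lambda_f)."""
    rng = random.Random(seed)
    pools = (domain, general, fmt)
    picks = rng.choices(range(3), weights=weights, k=n)
    return [rng.choice(pools[i]) for i in picks]
```

The same idea scales up cleanly: dataset libraries typically expose interleaving with per-source probabilities, so the lambda weights become a tunable, versioned part of the data recipe.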

Performance analysis: what to track

Metric family | Example metrics | Why you need it
Task quality | Accuracy, F1, exact match | Domain success criteria
Behavioral quality | Instruction adherence, conciseness score | Assistant usability
Format reliability | JSON validity, schema pass rate | Production integration safety
Safety controls | Toxicity/refusal policy checks | Risk management

Teams that only track loss curves often ship models that look fine in training dashboards but fail product expectations.
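The format-reliability row is the easiest to automate. A minimal JSON-validity metric over a batch of model outputs:

```python
import json

def json_validity_rate(outputs):
    """Fraction of outputs that parse as JSON: a cheap format-reliability
    metric to track alongside task-quality scores."""
    if not outputs:
        return 0.0
    ok = 0
    for text in outputs:
        try:
            json.loads(text)
            ok += 1
        except json.JSONDecodeError:
            pass
    return ok / len(outputs)
```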


📊 SFT Pipeline from Dataset to Deployment

flowchart TD
    A[Collect prompts and gold responses] --> B[Clean and normalize schema]
    B --> C[Split train, validation, holdout]
    C --> D[Train SFT model or adapter]
    D --> E[Run task and behavioral evals]
    E --> F{Pass quality gate?}
    F -- No --> G[Fix data balance and retrain]
    G --> D
    F -- Yes --> H[Publish model version]
    H --> I[Monitor drift and regression]

Treat SFT as a data-centric iteration loop. Most gains come from better datasets and evaluations, not from endlessly changing optimizer settings.


๐ŸŒ Real-World Applications of SFT

Customer support copilots

SFT teaches:

  • policy-compliant tone,
  • escalation patterns,
  • concise troubleshooting sequences.

Developer assistants

SFT improves:

  • structured explanations,
  • code-style consistency,
  • safer command recommendations.

Internal enterprise knowledge bots

SFT aligns:

  • response templates,
  • references to approved documents,
  • role-specific answer depth.

Product type | Typical SFT focus
Chat assistant | Instruction following and tone
Workflow bot | Deterministic format outputs
Domain Q&A | Terminology and factual precision

โš–๏ธ Trade-offs & Failure Modes: Trade-offs and Failure Modes in SFT

Mistake | What happens | Better approach
Training on noisy auto-generated labels | Model imitates noise confidently | Human-curated or filtered labels
Overfitting on benchmark-like data | Great benchmark score, weak real usage | Include realistic production prompts
Ignoring holdout evaluation | Hidden regressions | Keep immutable holdout set
Skipping post-training safety checks | Deployment risk | Add policy and abuse test suite

SFT does not magically "fix" model truthfulness. It improves behavior patterns, but factual correctness still depends on knowledge freshness, retrieval design, and grounding strategy.


🧭 Decision Guide: When SFT Is Enough and When It Is Not

Scenario | Recommended path
You need format adherence and style control | SFT is often enough
You need better human preference alignment | SFT + RLHF (or DPO-like preference tuning)
You have strict hardware limits | SFT with PEFT adapters
You need broad factual updates | Add retrieval + data refresh, not only SFT

SFT is foundational, but it is one layer in a larger alignment and product architecture stack.


🧪 Practical Example with TRL SFTTrainer

This example shows a complete SFT training run on a chat-formatted dataset using Hugging Face TRL's SFTTrainer, the same library stack commonly used to fine-tune open-source models like LLaMA and Mistral variants. It was chosen because SFTTrainer handles the most error-prone details (assistant-only label masking, chat-template application, LoRA wiring) in a single coherent API. As you read, pay attention to the SFTConfig parameters max_seq_length, gradient_accumulation_steps, and bf16, and how they trade off GPU memory, throughput, and training stability.

from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.1-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

dataset = load_dataset("json", data_files="sft_train.jsonl", split="train")

config = SFTConfig(
    output_dir="./sft-out",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,   # effective batch = 2 × 8 = 16; accumulate to match larger-batch stability
    learning_rate=2e-5,              # lower than pretraining LR; prevents overwriting the base model's knowledge
    num_train_epochs=2,
    logging_steps=20,
    bf16=True,                       # BF16 mixed precision: faster and more stable than FP16 on Ampere+ GPUs
    max_seq_length=4096,             # set to your longest example; shorter truncates, wasting training signal
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=config,
)

trainer.train()

Production note: pair this with automatic eval jobs and regression gates. A training script without an eval pipeline is a repeatable way to regress behavior.
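Such a gate can start as a few hard thresholds compared against each eval run. A sketch in which the metric names and threshold values are illustrative, not a standard:

```python
# Hypothetical release thresholds; tune these to your product's acceptance criteria.
THRESHOLDS = {"exact_match": 0.70, "json_validity": 0.99, "max_refusal_rate": 0.05}

def quality_gate(metrics):
    """Compare eval metrics to release thresholds; return (passed, failures)."""
    failures = []
    if metrics.get("exact_match", 0.0) < THRESHOLDS["exact_match"]:
        failures.append("exact_match below threshold")
    if metrics.get("json_validity", 0.0) < THRESHOLDS["json_validity"]:
        failures.append("json_validity below threshold")
    if metrics.get("refusal_rate", 1.0) > THRESHOLDS["max_refusal_rate"]:
        failures.append("refusal_rate above threshold")
    return (not failures, failures)
```

Wire a check like this into CI so that publishing a model version is blocked, not merely warned about, when any threshold fails.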


๐Ÿ› ๏ธ Hugging Face TRL and Axolotl: SFT in Practice

Hugging Face TRL (SFTTrainer) is the standard Python-native SFT implementation: it handles assistant-only label masking, sequence packing, and LoRA integration in a single Trainer subclass. Axolotl is a YAML-driven fine-tuning framework built on top of TRL and PEFT that lets you run the entire SFT pipeline from a config file with no boilerplate Python, making it the preferred tool for teams that want reproducible, configuration-managed fine-tuning runs.

# TRL SFTTrainer: full SFT pipeline with assistant-only label masking and LoRA
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

model_name = "meta-llama/Llama-3.1-8B"
tokenizer  = AutoTokenizer.from_pretrained(model_name)
model      = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Dataset in chat format: list of {"messages": [{"role": ..., "content": ...}]}
dataset = load_dataset("json", data_files="sft_train.jsonl", split="train")

# Apply chat template so the data matches the model's expected format
def apply_template(examples):
    return {"text": tokenizer.apply_chat_template(examples["messages"], tokenize=False)}

dataset = dataset.map(apply_template)

lora_config = LoraConfig(
    r=16,         # rank 16: more behavioral capacity than r=8, appropriate for full instruction-following SFT
    lora_alpha=32,  # scaling = 2 × r; standard ratio keeps the adapter's gradient magnitude consistent
    target_modules=["q_proj", "v_proj"],  # query + value projections drive the most behavioral change per parameter
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    peft_config=lora_config,    # optional: enables LoRA-based SFT
    args=SFTConfig(
        output_dir="./sft-out",
        dataset_text_field="text",
        max_seq_length=4096,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=2,
        bf16=True,
        # packing=True,  # uncomment to enable sequence packing for better GPU utilization
    ),
)
trainer.train()
trainer.save_model("./sft-out/final")

Axolotl eliminates the Python boilerplate entirely; the same run above becomes a YAML config:

# axolotl_config.yml
base_model: meta-llama/Llama-3.1-8B
datasets:
  - path: sft_train.jsonl
    type: chat_template
sequence_len: 4096
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, v_proj]
bf16: true
num_epochs: 2
micro_batch_size: 2
gradient_accumulation_steps: 8
output_dir: ./axolotl-sft-out

Run the full SFT pipeline from the config file:

axolotl train axolotl_config.yml

Tool | Interface | Best for
TRL SFTTrainer | Python API | Custom data pipelines, programmatic hyperparameter search
Axolotl | YAML config | Reproducible runs, team collaboration, fast iteration

For a full deep-dive on Hugging Face TRL and Axolotl, dedicated follow-up posts are planned.


📚 Field Notes for Better SFT Runs

  • Write labeling guidelines before collecting large datasets.
  • Include adversarial prompts early; do not postpone them.
  • Compare against the base model on the same prompts.
  • Keep a changelog of data-mix and hyperparameter changes.
  • Tie every model release to a measurable acceptance threshold.

📌 TLDR: Summary & Key Takeaways

TLDR: SFT is the stage that converts pretrained language ability into reliable product behavior by fine-tuning on curated input-output demonstrations โ€” data quality is the most important lever.

  • SFT is the main stage for teaching LLM behavior from demonstrations.
  • Data schema quality and consistency matter more than most optimizer tweaks.
  • Label masking, sequence strategy, and eval design are practical quality levers.
  • SFT often precedes RLHF, but remains valuable on its own for many products.
  • Reliable deployment requires explicit quality gates, not just low training loss.

One-liner: SFT is where pretrained language ability becomes product behavior.

