Machine Learning
Tiny Transformer — Character-Level Language Model
April 2025
Python · PyTorch · CUDA
Project 6
I thought it would be fun to train a decoder-only transformer from scratch on the WhatsApp chat logs between me and my best friend.
This was quite a jump from my previous projects and had me learning a whole bunch, but it was a good experience.
Overview
The model has around 5 million parameters, tiny by any modern standard,
but enough to pick up conversational rhythm, names, punctuation habits,
and the general cadence of two people chatting over years.
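As a rough sanity check on that number, the parameter count can be estimated from the architecture in the Code section below (d_model = 256, 8 heads, 6 blocks, 4× FFN expansion). The vocabulary size depends on the data, so this sketch assumes roughly 100 unique characters:

# Rough parameter-count estimate for the architecture in the Code
# section. vocab_guess is an assumption (~100 unique characters);
# the real value comes from the chat data.
d_model, n_layers, vocab_guess, ctx = 256, 6, 100, 128

attn = 4 * (d_model * d_model + d_model)            # Q, K, V, proj (with bias)
ffn = (d_model * 4 * d_model + 4 * d_model) \
    + (4 * d_model * d_model + d_model)             # expand + project back
norms = 2 * 2 * d_model                             # two LayerNorms per block
per_block = attn + ffn + norms

embeds = vocab_guess * d_model + ctx * d_model      # token + positional tables
head = d_model * vocab_guess + vocab_guess + 2 * d_model  # final LN + logits

total = n_layers * per_block + embeds + head
print(f"{total:,}")  # ≈ 4.8 million

The blocks dominate: at this vocabulary size the embeddings and output head together contribute well under 100k parameters.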
Training data
The dataset is the exported WhatsApp chat history between me and my best friend Gareth:
three years of messages, jokes, plans, and general nonsense — exported, cleaned, and fed in as a single
flat text file. The model sees everything at the character level, so it learns spelling, spacing,
line breaks, and even the Declan: / Gareth: speaker prefixes.
The data was split 80/20 into training and validation sets, with a context window of 128 characters.
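Character-level tokenisation here is just a lookup table in each direction. A minimal sketch of the round trip, using the vocab and vocab_inv tables built in the Code section below (and assuming every character in the prompt appears somewhere in the chat data):

# Encode/decode round trip with the character-level lookup tables
# built in the Code section below.
text = "Declan: Henlo boyo"
ids = [vocab[c] for c in text]              # characters -> integer ids
decoded = "".join(vocab_inv[i] for i in ids)  # ids -> characters
assert decoded == text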
Hardware
Training was done on a rented RTX 5060 Ti through Vast.ai for a few dollars. This was the first project
where running on CPU alone would have been painful: I estimate a single epoch would have taken around 10 hours.
The GPU made each epoch fast enough to iterate on hyperparameters in a reasonable session.
Batch size 128, AdamW with a learning rate of 3e-5 and weight decay of 0.1,
gradient clipping at 1.0, trained for 10 epochs.
What I learnt
- How multi-head self-attention works end-to-end — splitting Q/K/V across heads, scaling by √d, applying a causal mask, then projecting back.
- The role of residual connections and layer norm in keeping gradients healthy through deep stacks.
- Why the FFN expands to 4 × d_model before projecting back down.
- GELU vs ReLU in practice, and why the smoother activation helps transformers.
- Dropout placement — both on attention weights and on the FFN output.
- Weight decay and gradient clipping to stabilise training.
- Debugging shape mismatches in CrossEntropyLoss: it expects logits of shape (B×T, V) and targets of shape (B×T,), not (B, T, V); see the sketch after this list.
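To make that last point concrete, here is a minimal sketch of the reshape with illustrative sizes (B, T, V are made up; the fix is the same flatten used in the training loop below):

import torch
import torch.nn as nn

B, T, V = 4, 128, 100                    # illustrative batch, context, vocab sizes
logits = torch.randn(B, T, V)            # what the model emits
targets = torch.randint(0, V, (B, T))    # next-character ids

loss_fn = nn.CrossEntropyLoss()
# loss_fn(logits, targets)               # fails: the class dim must be dim 1
loss = loss_fn(logits.reshape(-1, V), targets.reshape(-1))  # (B*T, V) vs (B*T,)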
Example output
Prompted with "Declan: Henlo boyo", the model generated:
Code
import os
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Hyperparameters
batch_size = 128        # sequences per batch
context_window = 128    # characters of context the model sees
model_size = 256        # embedding dimension (d_model)
model_heads = 8         # attention heads per block
# Load the cleaned chat export as one flat string
with open("data/garethdeclan/clean_data.txt", "r") as file:
    content = file.read()

# Character-level vocabulary: one id per unique character
unique_chars = sorted(set(content))
vocab = {char: i for i, char in enumerate(unique_chars)}
vocab_inv = {i: char for i, char in enumerate(unique_chars)}
vocab_size = len(vocab)

# Encode the full corpus and split 80/20 into train/validation
content_vocab = [vocab[c] for c in content]
split = int(0.8 * len(content_vocab))
raw_training_data = content_vocab[:split]
raw_validation_data = content_vocab[split:]
class TinyDataset(Dataset):
    """Sliding-window dataset: each item is a context_window-length
    input and the same sequence shifted one character to the right."""

    def __init__(self, raw_data):
        self.data = raw_data

    def __len__(self):
        return len(self.data) - context_window

    def __getitem__(self, idx):
        inputs = torch.tensor(self.data[idx : idx + context_window], dtype=torch.long)
        targets = torch.tensor(self.data[idx + 1 : idx + context_window + 1], dtype=torch.long)
        return inputs, targets
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: masked multi-head self-attention followed
    by a position-wise feed-forward network, each with a residual."""

    def __init__(self, d_model, n_heads, device):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.device = device
        # Q/K/V projections and the output projection
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        # FFN expands to 4 * d_model, then projects back down
        self.ffn1 = nn.Linear(d_model, 4 * d_model)
        self.ffn2 = nn.Linear(4 * d_model, d_model)
        self.ffnA = nn.GELU()
        self.softmax = nn.Softmax(-1)
        # Dropout on the attention weights and on the FFN output
        self.attn_drop = nn.Dropout(0.2)
        self.ffn_drop = nn.Dropout(0.2)

    def forward(self, x):
        B, T, _ = x.shape
        dims = self.d_model // self.n_heads  # per-head dimension
        # Project, then split d_model across heads: (B, n_heads, T, dims)
        norm = self.layernorm1(x)
        q = self.Q(norm).reshape(B, T, self.n_heads, dims).transpose(1, 2)
        k = self.K(norm).reshape(B, T, self.n_heads, dims).transpose(1, 2)
        v = self.V(norm).reshape(B, T, self.n_heads, dims).transpose(1, 2)
        # Scaled dot-product attention scores: (B, n_heads, T, T)
        a = q @ k.transpose(-2, -1) / math.sqrt(dims)
        # Causal mask: positions above the diagonal cannot be attended to
        mask = ~torch.tril(torch.ones(T, T, dtype=torch.bool, device=self.device))
        a = self.attn_drop(self.softmax(a.masked_fill(mask, float("-inf")))) @ v
        # Merge the heads back and project: (B, T, d_model)
        a_out = self.proj(a.transpose(1, 2).reshape(B, T, self.d_model))
        ar = x + a_out  # residual around attention
        # Residual around the feed-forward sublayer
        out = ar + self.ffn_drop(self.ffn2(self.ffnA(self.ffn1(self.layernorm2(ar)))))
        return out
class TinyDecoder(nn.Module):
    """Six-block decoder-only transformer with learned token and
    positional embeddings."""

    def __init__(self, d_model, n_heads, device):
        super().__init__()
        self.device = device
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(context_window, d_model)
        self.embed_drop = nn.Dropout(0.2)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, device) for _ in range(6)
        ])
        # Final layer norm, then project to vocabulary logits
        self.outln = nn.LayerNorm(d_model)
        self.outffn = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        t = x.shape[1]
        # Sum token and positional embeddings: (B, T, d_model)
        e = self.embed_drop(
            self.token_embed(x) +
            self.pos_embed(torch.arange(t, dtype=torch.long, device=self.device))
        )
        for block in self.blocks:
            e = block(e)
        return self.outffn(self.outln(e))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyDecoder(d_model=model_size, n_heads=model_heads, device=device).to(device)

losser = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.1)
train_loader = DataLoader(TinyDataset(raw_training_data), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TinyDataset(raw_validation_data), batch_size=batch_size)
os.makedirs("checkpoints", exist_ok=True)  # ensure the checkpoint directory exists
for epoch in range(10):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} train")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # Flatten to (B*T, V) logits and (B*T,) targets for the loss
        logits = model(inputs).reshape(-1, vocab_size)
        loss = losser(logits, targets.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    # Validation pass
    model.eval()
    val_loss, n = 0.0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs).reshape(-1, vocab_size)
            val_loss += losser(logits, targets.reshape(-1)).item()
            n += 1
    print(f"epoch={epoch+1} val_loss={val_loss/n:.4f}")
    torch.save(model.state_dict(), f"checkpoints/checkpoint_{epoch}.pth")
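The listing stops at training, so for completeness here is a minimal sampling loop in the same style. This is a sketch, not necessarily the exact code that produced the example output; temperature and max_new_chars are illustrative choices:

@torch.no_grad()
def sample(prompt, max_new_chars=200, temperature=0.8):
    """Sketch of autoregressive sampling: feed the last context_window
    characters in, sample the next one from the softmax, repeat."""
    model.eval()
    ids = [vocab[c] for c in prompt]
    for _ in range(max_new_chars):
        window = torch.tensor([ids[-context_window:]], dtype=torch.long, device=device)
        logits = model(window)[0, -1] / temperature   # logits for the next character
        probs = torch.softmax(logits, dim=-1)
        ids.append(torch.multinomial(probs, 1).item())
    return "".join(vocab_inv[i] for i in ids)

print(sample("Declan: Henlo boyo"))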