All Projects

Diffusion on MNIST

A class-conditional DDPM trained on MNIST. A small U-Net with FiLM-modulated ResNet-style blocks predicts the noise at each diffusion step, conditioned on a sinusoidal time embedding plus a learned class-label embedding. Pick a digit below and the model will denoise pure Gaussian noise into a handwritten sample.


Generated 28×28 sample (upscaled)

Selected digit
Status idle

Pick a digit, then press Generate.

Samples

Ten samples per digit, drawn at the end of training; each column holds one digit (0–9).

MNIST diffusion samples

Live demo

Pick a digit and press Generate. The runner performs all 1000 reverse-diffusion steps on the CPU, so expect a few seconds per sample.

Code

Full source on GitHub — abridged here for readability.

import torch
import torch.nn as nn

# Linear beta (noise-variance) schedule: beta rises evenly from beta_start
# at step 0 to beta_end at step t_max_steps - 1.
beta_start = 1e-4
beta_end = 0.02
t_max_steps = 1000


def f_beta(i):
    """Return beta_i, the noise variance of the linear schedule at step ``i``."""
    fraction = i / (t_max_steps - 1)
    return beta_start + (beta_end - beta_start) * fraction


def f_alpha(i):
    """Return alpha_i = 1 - beta_i at step ``i``."""
    return 1.0 - f_beta(i)


class TinyDiffusion(nn.Module):
    """U-Net noise predictor with FiLM-conditioned blocks (abridged).

    Conditioning = sinusoidal(t) + Embedding(label): time features pass
    through the small MLP ``t_embed`` and the class label through
    ``l_embed``. Only these embeddings and the FiLM helper are shown here;
    the encoder / bottleneck / decoder layers are elided.
    """

    def __init__(self):
        super().__init__()
        # Feature widths of the U-Net levels. Stored on the module (instead of
        # an unread, misspelled local) so the elided layer definitions can use
        # it. NOTE(review): confirm against the full source how it is consumed.
        self.channels = [64, 128, 256]

        # Time-embedding MLP (32 -> 32 -> 32); presumably fed a 32-dim
        # sinusoidal encoding of t — confirm against the full source.
        self.t_embed = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
        # Learned embedding table: 10 class labels, 32 dims each.
        self.l_embed = nn.Embedding(10, 32)

        # encoder, bottleneck, decoder — each conv block followed by GroupNorm
        # and FiLM (gamma, beta) generated from `cond` via a tiny linear layer.
        # See diffusion_mnist.py for the full architecture.

    def film(self, linear, cond):
        """Compute FiLM scale/shift parameters from a conditioning tensor.

        Args:
            linear: layer whose output on ``cond`` holds gamma and beta
                concatenated along dim 1 (first half gamma, second half beta).
            cond: conditioning tensor; assumed (batch, cond_dim) — TODO confirm.

        Returns:
            ``(gamma, beta)``, each with two trailing singleton dims appended,
            e.g. (B, C) -> (B, C, 1, 1), so they presumably broadcast over the
            spatial dims of a (B, C, H, W) feature map.
        """
        gamma, beta = linear(cond).chunk(2, dim=1)
        return gamma.unsqueeze(-1).unsqueeze(-1), beta.unsqueeze(-1).unsqueeze(-1)


# Per-timestep lookup tables for t = 0 .. t_max_steps - 1:
#   beta_tensor[t]       -> beta_t               (f_beta)
#   alpha_tensor[t]      -> alpha_t = 1 - beta_t (f_alpha)
#   alpha_prod_tensor[t] -> alpha_bar_t, the running product of alpha_0..alpha_t
# Built from Python floats, so they take torch's default float dtype.
beta_tensor = torch.tensor(list(map(f_beta, range(t_max_steps))))
alpha_tensor = torch.tensor(list(map(f_alpha, range(t_max_steps))))
alpha_prod_tensor = alpha_tensor.cumprod(dim=0)


@torch.no_grad()
def sample_reverse(net, num_samples, value, *, device=None):
    """Draw class-conditional samples by running the full DDPM reverse chain.

    Starting from Gaussian noise, for t = t_max_steps - 1 down to 0:

        eps_hat = net((x_t, t, label))
        mu_t    = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)
        x_t     = mu_t + sqrt(beta_t) * z,  z ~ N(0, I)   (no noise added at t = 0)

    Args:
        net: noise-prediction model, called as ``net((x_t, t, label))``.
        num_samples: number of images to draw.
        value: class label used to condition every sample.
        device: where to allocate the noise and index tensors; ``None`` keeps
            torch's default device (the original CPU behaviour).

    Returns:
        The final denoised batch, starting from shape (num_samples, 1, 28, 28)
        (assumes ``net`` returns eps_hat of the same shape as x_t).

    The model's train/eval mode is restored on exit, so this can be called
    in the middle of a training loop without leaving the model in eval mode.
    """
    was_training = net.training
    net.eval()
    try:
        x_t = torch.randn(num_samples, 1, 28, 28, device=device)
        # The conditioning label is the same at every step; build it once.
        label = torch.full((num_samples,), value, dtype=torch.long, device=device)

        for t in reversed(range(t_max_steps)):
            t_en = torch.full((num_samples,), t, dtype=torch.long, device=device)
            eps_hat = net((x_t, t_en, label))

            # Indexing yields 0-dim tensors, which combine with x_t as scalars.
            alpha_t = alpha_tensor[t]
            alpha_bar_t = alpha_prod_tensor[t]
            beta_t = beta_tensor[t]

            mu_t = (1.0 / torch.sqrt(alpha_t)) * (
                x_t - ((1.0 - alpha_t) / torch.sqrt(1.0 - alpha_bar_t)) * eps_hat
            )

            if t > 0:
                # Posterior variance sigma_t^2 = beta_t.
                x_t = mu_t + torch.sqrt(beta_t) * torch.randn_like(x_t)
            else:
                # Final step: return the mean without added noise.
                x_t = mu_t
    finally:
        net.train(was_training)

    return x_t


# Training loop: standard DDPM — sample t uniformly, noise the input
# with sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps, train MSE
# between predicted and true epsilon. 20 epochs, Adam(1e-4), cosine LR.