All Projects

Diffusion on Circles

A tiny DDPM (denoising diffusion probabilistic model) trained on scikit-learn's two-concentric-circles dataset. The model is trained to predict the noise added to each point, and running that prediction in reverse generates new circles. A live demo is below; note that generating the circles and the animation may take some time.

Live demo

Press Generate to draw fresh samples from random noise. The animation plays the 100 reverse-diffusion steps in order — chaos at the top, clean rings at the bottom.

Generated samples (blue) overlaid on real data (green)

Samples
Diffusion step
Status idle

Press Generate to sample from the model.

Code

import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np

from sklearn.datasets import make_circles

# Training data: `samples` points on two noisy concentric circles from sklearn
# (fixed random_state, so the dataset is reproducible), then standardised per
# coordinate (axis 0) to zero mean and unit variance.
samples = 2000
noisy_circles_x, noisy_circles_y = make_circles(
    n_samples=samples,
    noise=0.05,
    random_state=6,
)
scaled_circles_x = (
    noisy_circles_x - noisy_circles_x.mean(axis=0)
) / noisy_circles_x.std(axis=0)

# Linear beta (noise-variance) schedule over t_max_steps diffusion steps.
# NOTE(review): with these constants the cumulative product of alphas at the
# last step is only about 0.36 (sqrt ≈ 0.6), so the fully-noised x_T still
# carries a sizeable fraction of the data signal instead of being close to
# pure N(0, I) noise. Sampling, however, starts from pure N(0, I).
beta_start = 1e-4
beta_end = 0.02
t_max_steps = 100


def f_beta(i):
    """Return beta_i, interpolated linearly from beta_start at i=0 to beta_end at i=t_max_steps-1.

    Uses only arithmetic, so ``i`` may be a Python number or an array/tensor.
    Requires ``t_max_steps > 1``; otherwise the division by zero raises.
    """
    fraction = i / (t_max_steps - 1)
    return beta_start + (beta_end - beta_start) * fraction


def f_alpha(i):
    """Return alpha_i = 1 - beta_i for step ``i``."""
    return 1.0 - f_beta(i)

class TinyDiffusionAutocoder(nn.Module):
    """Small MLP mapping a 2-D point and an integer diffusion step to a 2-vector.

    In this script it is trained as the noise predictor: given a noised point
    x_t and its timestep t, it regresses the Gaussian noise that was added.

    Architecture: the timestep is looked up in a learned 8-dim embedding
    table, concatenated to the point, then passed through
    Linear(10->64) -> ReLU -> Linear(64->64) -> ReLU -> Linear(64->2).
    """

    def __init__(self):
        super().__init__()
        # Keep these submodules in this exact order and under these exact
        # names. The order fixes how parameter initialisation consumes the
        # RNG. The names become the state_dict keys, so renaming them would
        # break loading previously saved weights.
        self.embed = nn.Embedding(t_max_steps, 8)  # one learned 8-d vector per timestep
        self.ln1 = nn.Linear(2 + 8, 64)
        self.ln12_a = nn.ReLU()
        self.ln2 = nn.Linear(64, 64)
        self.ln23_a = nn.ReLU()
        self.ln3 = nn.Linear(64, 2)

    def forward(self, x):
        """Run the network on a ``(points, timesteps)`` pair.

        Args:
            x: tuple ``(r, t)``. ``r`` is a float tensor whose last dim is 2;
                ``t`` is an integer tensor with shape ``r.shape[:-1]`` and
                values in ``[0, t_max_steps)`` (indices into the embedding).

        Returns:
            Tensor with the same leading shape as ``r`` and a last dim of 2.
        """
        points, steps = x
        hidden = torch.cat([points, self.embed(steps)], dim=-1)
        for linear, activation in ((self.ln1, self.ln12_a), (self.ln2, self.ln23_a)):
            hidden = activation(linear(hidden))
        return self.ln3(hidden)


# Shared state for training and sampling: the model, the standardised data as
# a float32 tensor, and per-step schedule tensors indexed by timestep.
model             = TinyDiffusionAutocoder()
circle_tensor     = torch.from_numpy(scaled_circles_x).float()
beta_tensor       = torch.tensor([f_beta(i)  for i in range(t_max_steps)], dtype=torch.float32)
alpha_tensor      = torch.tensor([f_alpha(i) for i in range(t_max_steps)], dtype=torch.float32)
# alpha_bar_t = prod_{s<=t} alpha_s, accumulated and stored in float64.
# Expressions that mix it with float32 data therefore need an explicit cast.
alpha_prod_tensor = torch.cumprod(alpha_tensor, dim=0, dtype=torch.float64)

optimizer         = optim.Adam(model.parameters(), lr=0.001)
loss_function     = nn.MSELoss()

batch_size = 500
# The training loop slices the data into batch_size chunks. Check that the
# split is even with an explicit raise rather than `assert`, because asserts
# are stripped when Python runs with -O.
if circle_tensor.shape[0] % batch_size != 0:
    raise ValueError(
        f"dataset size ({circle_tensor.shape[0]}) must be a multiple of "
        f"batch_size ({batch_size})"
    )


@torch.no_grad()
def sample_reverse(num_samples):
    """Generate ``num_samples`` 2-D points by running the reverse diffusion chain.

    Starts from standard Gaussian noise. For t = t_max_steps-1 down to 0, the
    module-level ``model`` predicts the noise in x_t, and the DDPM reverse
    update is applied:

        mu_t    = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_hat)
        x_{t-1} = mu_t + sqrt(beta_t) * z,   z ~ N(0, I)   (no noise added at t = 0)

    Gradients are disabled. The model is switched to eval mode while
    sampling and then restored to the train/eval mode it had before the call,
    so sampling between training steps does not leave the model in eval mode.

    Args:
        num_samples: number of points to draw.

    Returns:
        Tensor of shape (num_samples, 2) holding the final samples.
    """
    was_training = model.training
    model.eval()
    try:
        x_t = torch.randn(num_samples, 2)
        for t in reversed(range(t_max_steps)):
            t_en        = torch.full((num_samples,), t, dtype=torch.long)
            eps_hat     = model((x_t, t_en))
            alpha_t     = alpha_tensor[t]
            alpha_bar_t = alpha_prod_tensor[t]
            beta_t      = beta_tensor[t]

            # DDPM reverse mean
            term1 = 1.0 / torch.sqrt(alpha_t)
            term2 = (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_bar_t)
            mu_t  = term1 * (x_t - term2 * eps_hat)

            if t > 0:
                # Posterior variance chosen as sigma_t^2 = beta_t.
                z   = torch.randn_like(x_t)
                x_t = mu_t + torch.sqrt(beta_t) * z
            else:
                # Last step: return the mean without adding noise.
                x_t = mu_t
        return x_t
    finally:
        # Restore the caller's mode even if sampling raised.
        model.train(was_training)


# Training: 1000 epochs of noise-prediction regression. For each batch, draw a
# random timestep per point, noise the clean points in closed form
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# and train the model to recover eps with an MSE loss.
for epoch in range(1000):
    model.train()
    training_loss, training_count = 0, 0
    for batch_index in range(0, circle_tensor.shape[0], batch_size):
        optimizer.zero_grad()
        circles     = circle_tensor[batch_index:batch_index + batch_size]
        # Size noise and timesteps from the slice itself rather than from
        # batch_size, so a short final batch is still handled correctly.
        n_batch     = circles.shape[0]
        e_noise     = torch.randn_like(circles)
        t_en        = torch.randint(low=0, high=t_max_steps, size=(n_batch,), dtype=torch.long)
        # Shape (n_batch, 1) so the per-point coefficient broadcasts over both coordinates.
        alpha_bar_t = alpha_prod_tensor[t_en].unsqueeze(-1)
        # alpha_prod_tensor is float64; cast the noised input to float32 for the model.
        x_t         = (torch.sqrt(alpha_bar_t) * circles +
                       torch.sqrt(1 - alpha_bar_t) * e_noise).float()

        preds = model((x_t, t_en))
        loss  = loss_function(preds, e_noise)
        loss.backward()
        optimizer.step()
        training_loss  += loss.item()
        training_count += 1

    print(f"Epoch={epoch} | Loss={training_loss / training_count}")