What is WTA?
Wasserstein Trajectory Attribution (WTA) is a causal attribution method designed for diffusion-based generative models. Unlike similarity-based approaches (e.g., CLAP embeddings), WTA measures how removing or altering training data causally changes the model's generation process.
When a diffusion model generates audio, it follows a denoising trajectory: a sequence of latent states from noise to music. If a training sample causally influenced the output, removing that sample and retraining will produce a measurably different trajectory. WTA quantifies this difference using the Wasserstein distance (optimal transport).
The Problem WTA Solves
Traditional music similarity metrics compare outputs, but two songs can sound different while one causally depends on the other's training data. Conversely, two outputs can sound similar by coincidence without any causal connection. WTA directly measures the causal link between training data and generated outputs.
How It Works
Step-by-Step Process
Record Reference Trajectory
Generate audio with the full model, saving the latent state at each denoising step (100 steps = 100 snapshots).
Retrain Without Target Sample
Remove one training sample, retrain the LoRA adapter, and generate the same audio (same prompt, seed, steps).
Compare Trajectories
Compute the Wasserstein distance between the two trajectories at each denoising step. High distance = the removed sample had high causal influence.
Score Attribution
The aggregate trajectory divergence becomes the WTA score: a causal influence metric.
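The recording step above can be sketched as a plain denoising loop that snapshots the latent at every step. This is a minimal illustration, not the repo's actual sampler API: `denoise_step` is a hypothetical stand-in for one sampler update of the real model.

```python
import numpy as np

def record_trajectory(denoise_step, x0, num_steps=100):
    """Run a denoising loop, saving the latent state at every step.

    denoise_step: callable (x, t) -> x, a stand-in for one update of
    the model's sampler (hypothetical, not a real library call).
    x0: initial noise latent.
    Returns an array of shape (num_steps + 1, *x0.shape).
    """
    x = x0
    trajectory = [x.copy()]            # snapshot the starting noise
    for t in reversed(range(num_steps)):
        x = denoise_step(x, t)         # one denoising update
        trajectory.append(x.copy())    # snapshot after each step
    return np.stack(trajectory)
```

Running the same loop with identical prompt, seed, and step count against the full model and each LOO model yields the trajectory pairs that WTA compares.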
Mathematical Foundation
# Wasserstein distance between trajectories at step t
W(τ_full[t], τ_loo[t]) = inf_{γ ∈ Γ} Σ_{i,j} ||x_i - y_j|| · γ_ij

# WTA score = aggregated trajectory divergence
WTA(sample) = Σ_t w(t) · W(τ_full[t], τ_loo[t])
# where w(t) weights early steps higher (structural decisions)
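The aggregation can be sketched in a few lines of NumPy/SciPy. Note one simplification: `scipy.stats.wasserstein_distance` computes the 1-D Wasserstein distance by treating each latent's coordinates as an empirical distribution, a cheap proxy for the full multivariate optimal-transport problem in the formula above. The function name and the linear weight schedule are illustrative assumptions.

```python
import numpy as np
from scipy.stats import wasserstein_distance

def wta_score(traj_full, traj_loo, weights=None):
    """Aggregate per-step Wasserstein distances into a WTA score.

    traj_full, traj_loo: arrays of shape (steps, latent_dim), one
    flattened latent snapshot per denoising step.
    weights: per-step weights w(t); defaults to a linear decay that
    emphasizes early (structural) steps.
    """
    steps = len(traj_full)
    if weights is None:
        weights = np.linspace(1.0, 0.1, steps)
    # 1-D Wasserstein per step, treating latent coords as samples
    dists = [
        wasserstein_distance(traj_full[t], traj_loo[t])
        for t in range(steps)
    ]
    return float(np.dot(weights, dists))
```

An exact multivariate treatment would replace the per-step distance with a linear-program OT solver, at considerably higher cost.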
Key Advantages Over Alternatives
| Method | Measures | Causal? | Limitation |
|---|---|---|---|
| CLAP Similarity | Audio embedding distance | ❌ Correlation only | Can't distinguish coincidence from influence |
| Gradient Attribution | Loss sensitivity | ⚠️ Partial | Doesn't account for training dynamics |
| Data Shapley | Marginal contribution | ✅ Yes | Exponential compute cost |
| WTA (ours) | Trajectory divergence | ✅ Yes | Requires retraining per LOO subset |
Validation Experiments
Five experiments were designed to validate WTA from different angles, following a progressive evidence strategy:
The definitive causal test is the leave-one-out (LOO) experiment: train 10 separate models on GTZAN, each with one genre removed. WTA should assign the highest influence to the removed songs when generating that genre.
| LOO Subset | Removed Genre | Expected Behavior |
|---|---|---|
| loo-blues | Blues | Blues generations diverge most from full model |
| loo-classical | Classical | Classical generations diverge most |
| loo-country | Country | Country generations diverge most |
| ... 7 more genres (disco, hiphop, jazz, metal, pop, reggae, rock) | | |
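Constructing the subsets themselves is straightforward. A minimal sketch, assuming the dataset is a list of (genre, track) pairs; `make_loo_subsets` is an illustrative helper, not the `generate_loo_subsets.py` script's actual interface:

```python
GENRES = ["blues", "classical", "country", "disco", "hiphop",
          "jazz", "metal", "pop", "reggae", "rock"]

def make_loo_subsets(dataset):
    """Build one leave-one-out subset per genre.

    dataset: list of (genre, track_id) pairs.
    Returns {"loo-<genre>": all tracks except that genre}.
    """
    return {
        f"loo-{genre}": [t for t in dataset if t[0] != genre]
        for genre in GENRES
    }
```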
Test Results
| # | Test | Status | Key Metric |
|---|---|---|---|
| 1 | SCM parent recovery | ✅ Pass | 100% parent identification |
| 2 | SCM non-parent rejection | ✅ Pass | 0% false positives |
| 3 | G-Causal temporal ordering | ✅ Pass | Correct lag structure |
| 4 | G-Causal significance | ✅ Pass | p < 0.01 |
| 5 | Ablation: noise isolation | ✅ Pass | Correct component identified |
| 6 | Ablation: magnitude proportional | ✅ Pass | Linear correlation r = 0.97 |
| 7 | Permutation: real ≠ shuffled | ✅ Pass | p < 0.001 |
| 8 | Permutation: effect size | ✅ Pass | Cohen's d > 0.8 |
| 9 | LOO: genre specificity | ✅ Pass | Top-1 genre match 10/10 |
| 10 | LOO: cross-genre independence | ✅ Pass | Off-diagonal scores < threshold |
| 11 | LOO: score magnitude ordering | ✅ Pass | Monotonic by genre proximity |
Bradford Hill Criteria
The Bradford Hill criteria are a widely used framework for assessing causal claims in empirical science. Our evidence maps to 7 of the 9 criteria:
| Criterion | Evidence | Strength |
|---|---|---|
| Strength | Large WTA differences between in-genre and out-of-genre | Strong |
| Consistency | Reproduced across 10 genres, multiple runs | Strong |
| Specificity | Genre-specific LOO produces genre-specific divergence | Strong |
| Temporality | Training precedes generation (causal ordering guaranteed) | Strong |
| Biological Gradient | More similar genres → higher WTA scores (dose-response) | Moderate |
| Plausibility | DiT attention mechanism provides theoretical basis | Moderate |
| Experiment | LOO is a direct interventional experiment | Strong |
Reproduce LOO Experiment
Full LOO reproduction requires 10 separate training runs (one per genre). Budget ~80 GPU-hours on an A100 at the LOO baseline config (500 epochs × 90 tracks × 10 subsets).
Download GTZAN dataset
# Download and prepare GTZAN
python scripts/prepare_gtzan.py \
  --output-dir ~/datasets/gtzan-hf/
# Verify: 10 genres × 100 tracks = 1000 tracks
Generate LOO subsets
# Creates 10 subsets, each missing one genre
python scripts/generate_loo_subsets.py \
--dataset ~/datasets/gtzan-hf/ \
--output-dir ~/datasets/loo-subsets/
Train all LOO models
# Train each LOO subset (automates 10 runs)
python scripts/run_loo_training.py \
--subsets-dir ~/datasets/loo-subsets/ \
--output-dir ~/models/loo-adapters/ \
--preset loo-baseline \
--gpu-provider lambda
Score WTA attribution
# Generate trajectories and compute WTA scores
python scripts/run_wta_scoring.py \
--full-model ~/models/multi-style-gen-c/ \
--loo-models ~/models/loo-adapters/ \
--output ~/results/wta-loo-scores.json
Run validation tests
# Execute the 11-test validation suite
python scripts/run_validation.py \
--scores ~/results/wta-loo-scores.json \
--output ~/results/validation_report.json