
Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel

Subjects: Machine Learning; Data Structures and Algorithms; Statistics Theory (v2, 2024-10-18)

Abstract

Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works \cite{chen2023sampling,chen2023ode,benton2023error,lee2022convergence} have proposed schemes for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee's randomized midpoint method for log-concave sampling \cite{ShenL19}. We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde O(d^{5/12})$, compared to $\widetilde O(\sqrt{d})$ from prior work). We also show that our algorithm can be parallelized to run in only $\widetilde O(\log^2 d)$ parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models. As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving dimension dependence $\widetilde O(d^{5/12})$, compared to $\widetilde O(\sqrt{d})$ from prior work.
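The randomized midpoint idea the abstract builds on can be illustrated on a generic ODE dx/dt = f(t, x). The sketch below is only an illustration under stated assumptions, not the authors' diffusion sampler: the test ODE dx/dt = -x and the helper names (randomized_midpoint_step, euler_step, integrate) are hypothetical, and it omits score estimates, noise levels, and the parallel variant analyzed in the paper. The core trick is to evaluate the drift at a uniformly random time inside each step, so that h * f is an unbiased estimate of the integral of f over the step whenever the midpoint value is exact.

```python
import numpy as np


def randomized_midpoint_step(f, t, x, h, rng):
    """One randomized-midpoint step for dx/dt = f(t, x) (generic ODE sketch)."""
    alpha = rng.uniform()                    # random fraction of the step, alpha ~ U[0, 1]
    x_mid = x + alpha * h * f(t, x)          # Euler estimate of x(t + alpha * h)
    # h * f at a uniformly random time is unbiased for the integral over
    # [t, t + h] if x_mid were exact; here x_mid is only approximate.
    return x + h * f(t + alpha * h, x_mid)


def euler_step(f, t, x, h, rng=None):
    """Plain forward-Euler step, used only as a baseline (rng is ignored)."""
    return x + h * f(t, x)


def integrate(step, f, x0, t_end, n_steps, seed=0):
    """Apply `step` n_steps times with uniform step size from t = 0 to t_end."""
    rng = np.random.default_rng(seed)
    h = t_end / n_steps
    x = np.asarray(x0, dtype=float)
    for k in range(n_steps):
        x = step(f, k * h, x, h, rng)
    return x


if __name__ == "__main__":
    f = lambda t, x: -x                      # hypothetical test ODE, exact solution x0 * exp(-t)
    exact = np.exp(-1.0)
    err_euler = abs(float(integrate(euler_step, f, 1.0, 1.0, 100)) - exact)
    err_rm = abs(float(integrate(randomized_midpoint_step, f, 1.0, 1.0, 100)) - exact)
    print(f"Euler error: {err_euler:.2e}, randomized midpoint error: {err_rm:.2e}")
```

On this linear test problem, one randomized midpoint step from x gives x(1 + \lambda h + \alpha \lambda^2 h^2), whose expectation over alpha matches the second-order Taylor expansion of e^{\lambda h}; this is why the randomized midpoint error printed by the script is typically well below the Euler error at the same step count.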


Cite

@article{arxiv.2406.00924,
  title  = {Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel},
  author = {Shivam Gupta and Linda Cai and Sitan Chen},
  journal= {arXiv preprint arXiv:2406.00924},
  year   = {2024}
}