
Parallel Sampling via Counting

2024-08-20 (v1) · Data Structures and Algorithms; Artificial Intelligence; Machine Learning; Probability

Abstract

We show how to use parallelization to speed up sampling from an arbitrary distribution $\mu$ on a product space $[q]^n$, given oracle access to counting queries $\mathbb{P}_{X\sim\mu}[X_S=\sigma_S]$ for any $S\subseteq[n]$ and $\sigma_S\in[q]^S$. Our algorithm takes $O(n^{2/3}\cdot\operatorname{polylog}(n,q))$ parallel time, which is, to the best of our knowledge, the first runtime sublinear in $n$ for arbitrary distributions. Our results have implications for sampling in autoregressive models. Our algorithm works directly with an equivalent oracle that answers conditional marginal queries $\mathbb{P}_{X\sim\mu}[X_i=\sigma_i \mid X_S=\sigma_S]$, a role played by a trained neural network in autoregressive models. This suggests that a roughly $n^{1/3}$-factor speedup is possible for sampling in any-order autoregressive models. We complement our positive result with a lower bound of $\widetilde{\Omega}(n^{1/3})$ on the runtime of any parallel sampling algorithm that makes at most $\operatorname{poly}(n)$ queries to the counting oracle, even for $q=2$.
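To make the two oracles concrete, here is a minimal sketch of the *baseline* sequential sampler, not the paper's parallel algorithm, which the abstract does not describe. It relies only on the definition of conditional probability: a conditional marginal query is a ratio of two counting queries, $\mathbb{P}[X_i=\sigma_i \mid X_S=\sigma_S] = \mathbb{P}[X_{S\cup\{i\}}=\sigma_{S\cup\{i\}}] / \mathbb{P}[X_S=\sigma_S]$ whenever the denominator is positive. Conversely, by the chain rule, a counting query is a product of conditional marginals. The sampler fixes one coordinate per round, so it needs $n$ sequential rounds; this is the linear-in-$n$ cost that the $O(n^{2/3}\cdot\operatorname{polylog}(n,q))$ result improves on. Assumptions: $\mu$ is stored as an explicit table (feasible only for tiny $n$ and $q$), and all names (`make_counting_oracle`, `conditional_marginal`, `sequential_sample`) are hypothetical.

```python
import itertools
import random
from typing import Callable, Dict, Tuple

Assignment = Tuple[int, ...]
Partial = Dict[int, int]  # maps index i in S to its value sigma_i


def make_counting_oracle(mu: Dict[Assignment, float]) -> Callable[[Partial], float]:
    """Hypothetical counting oracle: count(partial) = P_{X~mu}[X_S = sigma_S].
    Assumes mu is an explicit table over [q]^n (toy sizes only)."""
    def count(partial: Partial) -> float:
        return sum(p for x, p in mu.items()
                   if all(x[i] == v for i, v in partial.items()))
    return count


def conditional_marginal(count, partial: Partial, i: int, value: int) -> float:
    """P[X_i = value | X_S = sigma_S] as a ratio of two counting queries.
    Requires P[X_S = sigma_S] > 0."""
    return count({**partial, i: value}) / count(partial)


def sequential_sample(count, n: int, q: int, rng: random.Random):
    """Baseline autoregressive sampler (NOT the paper's parallel method):
    one coordinate per round, so exactly n sequential rounds."""
    partial: Partial = {}
    rounds = 0
    for i in range(n):
        # All q conditional queries for coordinate i could run in parallel,
        # but round i+1 must wait for the value chosen in round i.
        weights = [conditional_marginal(count, partial, i, v) for v in range(q)]
        # Zero-weight values are never chosen, so the prefix stays in the support.
        partial[i] = rng.choices(range(q), weights=weights)[0]
        rounds += 1
    return tuple(partial[i] for i in range(n)), rounds


# Usage: a random strictly positive distribution on [2]^3.
rng = random.Random(0)
n, q = 3, 2
raw = {x: rng.random() + 1e-3 for x in itertools.product(range(q), repeat=n)}
total = sum(raw.values())
mu = {x: w / total for x, w in raw.items()}
count = make_counting_oracle(mu)
x, rounds = sequential_sample(count, n, q, rng)
```

The sketch enumerates the whole table on every query purely for transparency. In the autoregressive setting the abstract refers to, `conditional_marginal` would instead be answered by a trained network, and the sequential round count, not the per-query cost, is the quantity that parallel sampling aims to reduce.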


Cite

@article{arxiv.2408.09442,
  title  = {Parallel Sampling via Counting},
  author = {Nima Anari and Ruiquan Gao and Aviad Rubinstein},
  journal= {arXiv preprint arXiv:2408.09442},
  year   = {2024}
}