ZKP on Spot Instances: MSM

Giro Research
Dec. 30, 2023

It is difficult to compute zero-knowledge proofs (ZKPs) on commercial hardware given the high memory requirement, and doing it on the cloud using memory-optimized instances can quickly rack up hefty bills. One way to lower the cloud computing cost is to use AWS EC2 spot instances.

Spot instances are much cheaper than on-demand instances because they are allocated from AWS’s spare capacity. However, spot instances are also subject to interruptions when the spare capacity needs to be reclaimed due to higher user demand. In this article, we go over a method for computing multi-scalar multiplications (MSMs), one of the most time-consuming operations in ZKP generation, that is both efficient and tolerant of interruptions, so as to take maximum advantage of what spot instances have to offer.

Introduction

Spot instances are a way for AWS to generate extra revenue from its spare computing power. They allow users to save up to 90% of cloud computing costs by using idle resources in the cloud, while bearing the risk that AWS terminates the instance at any time with very short advance notice (2 minutes). When using spot instances to generate ZKPs, which can take hours for large circuits, these interruptions are the main challenge we need to tackle.

The key idea is to turn ZKP computation into a preemptible workload by saving intermediate state to persistent storage so that progress is not lost when an interruption happens. Broadly speaking, we want to construct a computational graph whose nodes are input variables and intermediate results. This graph encodes dependencies as a DAG, and at the time of interruption, variables that are no longer needed can be safely released while the others are serialized. In addition, we can split a single complex calculation (such as a large MSM or NTT) so that each step of the calculation can always be completed within the 2-minute notice period. Finally, we need to consider serialization throughput. If serialization cannot be completed within two minutes, we would need to persist in the background and/or prioritize variables with a smaller size-to-compute-time ratio. Usually, this shouldn’t be a concern: assuming 500 MB/s throughput for an EBS volume, we can write approximately 60 GB of data in 2 minutes, which is enough to cover all surviving variables of a large-scale (such as $2^{25}$) ZKP computation.
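
As a rough illustration of the bookkeeping described above, the sketch below picks out the surviving variables of a small dependency graph and estimates how long they would take to write at 500 MB/s. The graph, its variable names, and its sizes are hypothetical and do not come from any real proof system.

```python
# Hypothetical toy graph: each variable maps to (inputs, size in bytes).
GRAPH = {
    "witness":  ([], 8 * 2**30),
    "poly_a":   (["witness"], 2 * 2**30),
    "poly_b":   (["witness"], 2 * 2**30),
    "commit_a": (["poly_a"], 64),
    "quotient": (["poly_a", "poly_b"], 4 * 2**30),
}
OUTPUTS = {"commit_a", "quotient"}  # assumed final results, always kept


def surviving(graph, outputs, done):
    """Computed variables that are outputs or still feed an uncomputed node."""
    live = {name for name in outputs if name in done}
    for name, (inputs, _size) in graph.items():
        if name not in done:
            live.update(i for i in inputs if i in done)
    return live


def write_seconds(graph, names, bytes_per_second=500 * 10**6):
    """Time to serialize the given variables at the assumed throughput."""
    return sum(graph[n][1] for n in names) / bytes_per_second


done = {"witness", "poly_a", "poly_b", "commit_a"}
live = surviving(GRAPH, OUTPUTS, done)
print(sorted(live), round(write_seconds(GRAPH, live), 1), "seconds")
```

Here `witness` is released because everything that reads it has already been computed, while `poly_a` and `poly_b` must be written because `quotient` still depends on them.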

To summarize, our strategy can be broken into four steps.

  1. Sort out the proof generation process and variable dependency / lifecycle
  2. Split time-consuming operations into multiple steps
  3. After receiving the interruption notice, serialize all surviving variables
  4. When the program resumes, reconstruct ZKP computation state from serialized variables and continue

In this article, we will focus on a particular aspect of step (2): splitting large MSMs, the dominant operation in most ZKP workloads.

MSMs and Pippenger’s Algorithm

Introduction to MSMs

Consider a point $P$ on an elliptic curve, and denote the cyclic group generated by $P$ as $(P)$ with order $p$. Given $k_i \in \mathbb{F}_p$ and $P_i \in (P)$, with $i \in \{1,2,3,\cdots, n\}$, multi-scalar multiplication (MSM) computes $Q = \sum_{i=1}^n k_i \cdot P_i$.

Here, $n$ is the MSM size and $p$ is represented as a $\lambda$-bit value. $k_i \cdot P_i$ is the scalar multiplication on an elliptic curve. Elliptic curve arithmetic supports point addition (PADD) and point doubling (PDBL). The scalar multiplication of $k_i$ and $P_i$ is defined as $\sum_{j=1}^{k_i} P_i$ and can be calculated using the double-and-add technique.

If we use the double-and-add technique to calculate each scalar multiplication individually and then add them up to obtain the MSM result, we need to perform a maximum of $n\lambda - 1$ PADDs and $n\lambda - n$ PDBLs ($\lambda -1$ PADDs and $\lambda -1$ PDBLs for each $k_i \cdot P_i$, and another $n-1$ PADDs to add them all). In many ZKP applications, $\lambda$ often exceeds $200$, and $n$ can be as big as $2^{20}$. Both PADD and PDBL require approximately $10$ time-consuming large-number modular multiplications or modular squarings, so computing an MSM with double-and-add alone is too expensive. There are several more efficient methods for calculating MSMs. Among them, Pippenger’s algorithm performs best for large $n$.
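
To make the operation counts above concrete, here is a minimal sketch of the naive approach with PADD and PDBL counters. As a loudly stated assumption, it uses the integers modulo a prime under addition as a stand-in for the elliptic-curve group, so `padd` and `pdbl` are ordinary modular arithmetic; only the counting logic is meant to carry over.

```python
Q_ORDER = 2**61 - 1  # stand-in group: integers mod a prime, NOT a real curve


class Counter:
    def __init__(self):
        self.padd = 0
        self.pdbl = 0


def padd(a, b, c):
    c.padd += 1
    return (a + b) % Q_ORDER


def pdbl(a, c):
    c.pdbl += 1
    return (2 * a) % Q_ORDER


def scalar_mul(k, point, c):
    """Left-to-right double-and-add; None represents the identity."""
    acc = None
    for bit in bin(k)[2:]:
        if acc is not None:
            acc = pdbl(acc, c)
        if bit == "1":
            acc = point if acc is None else padd(acc, point, c)
    return acc


def naive_msm(scalars, points, c):
    """Compute each k_i * P_i separately, then add the terms together."""
    total = None
    for k, p in zip(scalars, points):
        term = scalar_mul(k, p, c)
        if term is None:
            continue
        total = term if total is None else padd(total, term, c)
    return total


# Worst case for lambda = 8 and n = 4: every scalar has all bits set.
counter = Counter()
result = naive_msm([255] * 4, [3, 5, 7, 11], counter)
print(result, counter.padd, counter.pdbl)
```

With $n = 4$ and $\lambda = 8$ the counters reach $n\lambda - 1 = 31$ PADDs and $n\lambda - n = 28$ PDBLs, matching the bounds in the text.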

Pippenger’s Algorithm

When calculating an MSM, Pippenger’s algorithm first selects a window size $s < \lambda$ and decomposes each scalar $k_i$ into $\lceil{\frac{\lambda}{s}}\rceil$ windows of $s$ bits, then restricts the MSM task to one window at a time. Within a window, the scalars are small (each value falls in $[0,2^s-1]$ as opposed to $[0,2^\lambda-1]$), so many points share the same scalar, and $2^s$ buckets are used to aggregate the points $P_i$ that share a scalar value. The aggregated points in all buckets are then accumulated to obtain the MSM result of each window. The use of buckets speeds up the calculation by eliminating redundant multiplications by the same (windowed) scalar. Finally, the results from all windows are combined.

In Pippenger’s algorithm, the choice of the window size $s$ is critical to performance because it determines the trade-off between the number of windows and the number of buckets per window. In general, a larger window size results in fewer windows, and therefore fewer passes over the $n$ points, at the cost of exponentially more buckets to accumulate in each window, and vice versa. The optimal window size depends on the specific use case, the number of points involved, and the hardware or platform being used. We describe the algorithm in detail below.

The first step of Pippenger’s algorithm is to decompose the original task $Q = \sum_{i=1}^n k_i \cdot P_i$ into multiple subtasks. Choose the window size $s < \lambda$ and decompose each scalar $k_i$ into $\lceil{\frac{\lambda}{s}}\rceil$ parts. Each part is a scalar $k_{i,j}$ represented by $s$ bits, satisfying $k_i = \sum_{j=1}^{\lceil{\frac{\lambda}{s}}\rceil} 2^{(j-1)s}k_{i,j}$. We define the subtask as $G_j = \sum_{i=1}^n k_{i,j}P_i$, for $j \in \{1,2,\cdots, \lceil{\frac{\lambda}{s}}\rceil \}$. The original task and the subtasks satisfy Equation \eqref{1}.

\begin{equation} Q = \sum_{i=1}^n k_i \cdot P_i = \sum_{i=1}^n \sum_{j=1}^{\lceil{\frac{\lambda}{s}}\rceil} 2^{(j-1)s}k_{i,j} \cdot P_i = \sum_{j=1}^{\lceil{\frac{\lambda}{s}}\rceil} 2^{(j-1)s} (\sum_{i=1}^n k_{i,j}P_i) = \sum_{j=1}^{\lceil{\frac{\lambda}{s}}\rceil} 2^{(j-1)s} G_j. \tag{1} \label{1} \end{equation}
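
The scalar decomposition can be sketched in a few lines. The function names `decompose` and `recompose` are illustrative, and the list index is zero-based, so entry `j` of the result corresponds to $k_{i,j+1}$ in the notation above.

```python
def decompose(k, s, lam):
    """Split a lam-bit scalar into ceil(lam / s) digits of s bits, lowest first."""
    num_windows = -(-lam // s)  # ceiling division
    mask = (1 << s) - 1
    return [(k >> (j * s)) & mask for j in range(num_windows)]


def recompose(digits, s):
    """Inverse of decompose: sum of 2^(j*s) * digit_j."""
    return sum(d << (j * s) for j, d in enumerate(digits))


digits = decompose(182, s=3, lam=8)
print(digits)                # [6, 6, 2]
print(recompose(digits, 3))  # 182
```

For $k = 182$, $\lambda = 8$ and $s = 3$ this gives three windows, and $6 + 6 \cdot 2^3 + 2 \cdot 2^6 = 182$.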

The second step of Pippenger’s algorithm is to calculate the subtasks $G_j$, with $j \in \{1,2,\cdots, \lceil{\frac{\lambda}{s}}\rceil \}$. For each subtask, we put all points $P_i$ with the same scalar $k_{i,j}$ into the same bucket, labeled by the value of $k_{i,j}$. Points whose scalar is $0$ have no effect on the result, so only $2^s - 1$ buckets are needed. We then add up all points in each bucket and record the result as $B_l$, where $l \in \{1,2,\cdots, 2^s-1\}$ is the bucket label, so that $G_j = \sum_{l=1}^{2^s-1} l \cdot B_l$. This weighted sum can be computed efficiently with running sums. Define $M_l = \sum_{u = 2^s-l}^{2^s-1}B_u$. Noting that $M_1 = B_{2^s-1}$ and $M_l = M_{l-1} + B_{2^s-l}$, we can calculate $M_1,M_2,\cdots,M_{2^s-1}$ in order, and then $G_j = \sum_{l=1}^{2^s-1} M_l$, as shown in Equation \eqref{2}.

\begin{equation} \sum_{l=1}^{2^s-1}M_l = \sum_{l=1}^{2^s-1}\sum_{u=2^s-l}^{2^s-1} B_u = \sum_{l=1}^{2^s-1}lB_l = G_j. \tag{2} \label{2} \end{equation}
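
The bucket step and the running-sum trick of Equation \eqref{2} might look as follows. This is a sketch under the assumption that the group is replaced by the integers modulo `q` under addition (so `0` is the identity); `window_sum` is a hypothetical name.

```python
def window_sum(digits, points, s, q):
    """G_j = sum(digit_i * P_i) for one window, using buckets and running sums."""
    buckets = [0] * (1 << s)  # buckets[l] holds B_l; index 0 is never used
    for d, p in zip(digits, points):
        if d != 0:  # a zero digit contributes nothing
            buckets[d] = (buckets[d] + p) % q

    running = 0  # M_l: sum of the l highest-labeled buckets seen so far
    total = 0    # M_1 + M_2 + ... + M_l
    for label in range((1 << s) - 1, 0, -1):
        running = (running + buckets[label]) % q
        total = (total + running) % q
    return total


q = 2**61 - 1  # stand-in group modulus, not a real curve order
print(window_sum([6, 6, 2, 0, 7], [3, 5, 7, 11, 13], s=3, q=q))  # 153
```

The loop never multiplies by a bucket label: bucket $B_u$ is added into `running` once and then carried into `total` on each of the remaining iterations, $u$ times in all, which is where the factor $u$ comes from.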

The last step is to use the results of all subtasks to calculate the original MSM, that is, $Q = \sum_{j=1}^{\lceil{\frac{\lambda}{s}}\rceil} 2^{(j-1)s}G_j$. We use Equation \eqref{3} to recursively calculate a series of intermediate values $T_u = \sum_{j=1}^{\lceil \frac{\lambda}{s}\rceil - u +1} 2^{(j-1)s}G_{j+u-1}$, where $u \in \{1,2,\cdots, \lceil \frac{\lambda}{s}\rceil \}$, starting from $T_{\lceil \frac{\lambda}{s}\rceil} = G_{\lceil \frac{\lambda}{s}\rceil}$. $T_1$ is exactly $Q$.

\begin{equation} T_u = 2^sT_{u+1} + G_u. \tag{3} \label{3} \end{equation}
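
Equation \eqref{3} is a Horner-style evaluation from the highest window down. A minimal sketch, again assuming the integers modulo `q` as a stand-in for the curve group and using the hypothetical name `combine_windows`:

```python
def combine_windows(window_sums, s, q):
    """Q = sum(2^(j*s) * G_j); window_sums is ordered lowest window first."""
    acc = 0  # identity element of the stand-in group
    for g in reversed(window_sums):
        for _ in range(s):
            acc = (2 * acc) % q  # one PDBL per bit of the window
        acc = (acc + g) % q      # one PADD per window
    return acc


q = 2**61 - 1  # stand-in group modulus, not a real curve order
print(combine_windows([6, 6, 2], s=3, q=q))  # 182
```

The first pass only doubles the identity, which a real implementation would skip; that is why the text counts about $\lambda - s$ PDBLs rather than $\lambda$.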

Time complexity analysis: Each subtask requires roughly $n + 2^s - 2$ PADDs (about $n-2^s$ PADDs to obtain all bucket sums, and $2(2^s-1)$ PADDs to accumulate them using Equation \eqref{2}). Using Equation \eqref{3} to combine the subtasks into the final result requires approximately $\lambda - s$ PDBLs and $\lceil \frac{\lambda}{s}\rceil - 1$ PADDs. The overall time complexity of Pippenger’s algorithm is therefore approximately $\lceil \frac{\lambda}{s}\rceil (n + 2^s)$ PADDs plus $\lambda$ PDBLs.
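
The estimate above can be turned into a small cost model for picking the window size. This is only a sketch of the formula $\lceil \frac{\lambda}{s}\rceil (n + 2^s)$; it ignores memory traffic, parallelism, and the PDBL term, all of which matter on real hardware.

```python
def pippenger_cost(n, lam, s):
    """Approximate PADD count: ceil(lam / s) * (n + 2^s)."""
    return -(-lam // s) * (n + (1 << s))


def best_window(n, lam):
    """Window size s < lam that minimizes the estimated PADD count."""
    return min(range(1, lam), key=lambda s: pippenger_cost(n, lam, s))


n, lam = 2**20, 256
s = best_window(n, lam)
print(s, pippenger_cost(n, lam, s), n * lam)
```

Under this model, $n = 2^{20}$ and $\lambda = 256$ give $s = 16$ with about $1.8 \times 10^7$ PADDs, compared with roughly $2.7 \times 10^8$ for double-and-add.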

Using Spot Instances to Compute MSMs for Halo2

We will now walk through a concrete example of using spot instances to compute MSMs in a specific proof system, halo2.

MSM in Halo2

Halo2 is a popular proof system written in Rust. In halo2, MSM is used to calculate polynomial commitments. The function is commit_lagrange, and its input is the coefficients of a polynomial. It first pads the coefficients with zeros up to length $n = 2^m$, the smallest power of 2 that fits them, to obtain the scalars $k_i$. It then reads the same number of base points $P_i$ and invokes best_multiexp to compute the MSM.

The function best_multiexp obtains the thread count $k$ from the device configuration and divides the MSM task into $k$ parts. Denote $T = \lceil n/k \rceil$, so that $\sum_{i=1}^n k_iP_i = \sum_{j=1}^{k-1} \sum_{i=(j-1)T +1}^{jT}k_iP_i + \sum_{i=(k-1)T+1}^n k_iP_i .$ Each part calls multiexp_serial to calculate a smaller MSM separately. The results of the parts are then added together to get the answer to the original MSM. The function multiexp_serial implements Pippenger’s algorithm by first computing the optimal window size and then processing each window in sequence, from high-bit windows to low-bit windows. At the end, it combines the results of each window using Equations \eqref{2} and \eqref{3}.
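
The partitioning used by best_multiexp can be illustrated with a short sketch. This is not halo2's code: `msm` is a stand-in that sums directly over the integers modulo a prime, `split_msm` is a hypothetical name, and the parts are computed sequentially here rather than on separate threads.

```python
Q_ORDER = 2**61 - 1  # stand-in group: integers mod a prime, NOT a real curve


def msm(scalars, points):
    """Stand-in for one serial MSM over a slice of the input."""
    return sum(k * p for k, p in zip(scalars, points)) % Q_ORDER


def split_msm(scalars, points, num_parts):
    """Split into parts of size T = ceil(n / num_parts) and add the partial sums."""
    n = len(scalars)
    chunk = -(-n // num_parts)  # ceiling division
    partials = [
        msm(scalars[i:i + chunk], points[i:i + chunk])
        for i in range(0, n, chunk)
    ]
    return sum(partials) % Q_ORDER


scalars = list(range(1, 11))
points = [3 * i + 1 for i in range(10)]
print(split_msm(scalars, points, 4) == msm(scalars, points))  # True
```

With $n = 10$ and $k = 4$ the parts have sizes 3, 3, 3 and 1, which matches the split in the formula above.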

Spot Instance Considerations

When a spot instance receives the interruption notice, we assume that the ongoing polynomial commitment is the only workload whose state needs to be persisted to disk. Any previous MSM calculations done for other polynomials would have already been completed and the results serialized. So without loss of generality, let us focus on a single large-scale MSM calculation process. We propose the following.

  1. Determine the maximum MSM size $N$ that can be completed by best_multiexp in 2 minutes. This is platform- and hardware-dependent and requires careful benchmarking for exact results.
  2. In commit_lagrange, if the size of MSM exceeds $N$, we first decompose it into multiple smaller MSMs and then invoke best_multiexp.
  3. Sum the results of the completed best_multiexp calls.

The smaller each MSM subtask is, the more tolerant of interruptions our setup will be. However, due to the nature of Pippenger’s algorithm, dividing a single MSM into multiple MSM subtasks increases the time complexity of the entire calculation. Therefore, in Step 1 we choose the largest MSM size $N$ that can be solved within 2 minutes. In practice, if interruptions do not occur frequently, it is more performant to use a bigger $N$.
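
The cost of splitting can be estimated with the approximate PADD count $\lceil \frac{\lambda}{s}\rceil (n + 2^s)$ of Pippenger’s algorithm, choosing the best window size separately for the full MSM and for each chunk. The sketch below is a model only, not a benchmark, and `split_overhead` is a hypothetical helper.

```python
def pippenger_cost(n, lam, s):
    """Approximate PADD count: ceil(lam / s) * (n + 2^s)."""
    return -(-lam // s) * (n + (1 << s))


def best_cost(n, lam):
    """Estimated PADD count at the best window size s < lam."""
    return min(pippenger_cost(n, lam, s) for s in range(1, lam))


def split_overhead(n, chunk, lam):
    """Ratio of estimated cost when split into chunks vs. one large MSM."""
    parts = -(-n // chunk)
    return parts * best_cost(chunk, lam) / best_cost(n, lam)


print(round(split_overhead(2**25, 2**20, 256), 2))
```

Under this model, splitting an MSM of size $2^{25}$ into chunks of $2^{20}$ costs roughly a quarter more PADDs than running it in one piece, which is the price paid for interruption tolerance.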

On receiving the interruption notice, we begin writing the sum of the results of the completed best_multiexp calls while simultaneously waiting for the currently running best_multiexp to finish. In addition, we need to add continuation code that reads those intermediate results back and resumes the calculation. With these changes, MSM in halo2 becomes robust against spot instance interruptions.
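
A minimal sketch of the checkpoint-and-resume flow, under several assumptions: `msm` is a direct-sum stand-in for best_multiexp over the integers modulo a prime, the checkpoint is a small JSON file, and `interrupted` is a hypothetical callable that reports whether the interruption notice has arrived. For simplicity the sketch checks the flag only after each chunk completes, rather than writing the checkpoint concurrently with the running chunk.

```python
import json
import os

Q_ORDER = 2**61 - 1  # stand-in group: integers mod a prime, NOT a real curve


def msm(scalars, points):
    """Stand-in for one best_multiexp call on a chunk of size at most N."""
    return sum(k * p for k, p in zip(scalars, points)) % Q_ORDER


def save_checkpoint(path, next_index, partial):
    """Write the state to a temp file, then rename so the file is never half-written."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"next_index": next_index, "partial": partial}, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)


def resumable_msm(scalars, points, chunk, path, interrupted=lambda: False):
    """Return the MSM result, or None if interrupted after saving a checkpoint."""
    next_index, partial = 0, 0
    if os.path.exists(path):  # continuation: pick up where the last run stopped
        with open(path) as f:
            state = json.load(f)
        next_index, partial = state["next_index"], state["partial"]
    while next_index < len(scalars):
        end = next_index + chunk
        part = msm(scalars[next_index:end], points[next_index:end])
        partial = (partial + part) % Q_ORDER
        next_index = end
        if interrupted():
            save_checkpoint(path, next_index, partial)
            return None
    return partial
```

Only the running sum and the index of the next chunk need to be persisted, so the checkpoint stays tiny regardless of the MSM size. Calling `resumable_msm` again with the same inputs and path continues from the saved index.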

Conclusion

In this article, we propose a method for provers with limited computational resources to use the cloud for ZKP generation at a relatively low cost with AWS EC2 spot instances. The method is generally applicable to the entire ZKP generation process, but we focus on MSMs and describe in detail how to divide up the work and checkpoint intermediate results. In the future, we will extend our method to NTTs and handle the idiosyncrasies of different proof systems.