Dynamic Sharding of Arbitrary Neural Networks

While the sequential method provides a structured and heuristic-driven approach to partitioning, it operates under the constraints of a linear, predetermined exploration path. This may not fully capture the dynamism and complexity of modern neural network architectures, where computational and memory demands can vary significantly across different layers and segments of the network. Given an arbitrary neural network, our objective is to partition the network's computation graph for optimal execution across multiple nodes. This computation graph, $G=(V, E)$, consists of the computational operations $V$ and the data-flow edges $E$: each operation $u \in V$ outputs a tensor consumed by downstream operations $v$, forming edges $(u,v) \in E$. The graph represents the entirety of the model's computational workload, ranging from simple arithmetic operations to layer-specific matrix multiplications, each associated with specific computational and memory requirements: the running time $\text{work}(v)$, the memory footprint of the model parameters $\text{sizeparam}(v)$, and the size of the operation's output $\text{sizeout}(v)$.
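
As an illustration, such an annotated graph might be represented as in the sketch below. The class, operation names, and cost values are hypothetical placeholders; the fields simply mirror $\text{work}(v)$, $\text{sizeparam}(v)$, and $\text{sizeout}(v)$ and are not the output of any actual profiler.

```python
# A minimal sketch of the annotated computation graph G = (V, E).
# All names and numbers are illustrative, not measured values.
from dataclasses import dataclass

@dataclass(frozen=True)
class OpCosts:
    work: float        # estimated running time of the operation (seconds)
    size_param: float  # bytes of model parameters the operation owns
    size_out: float    # bytes of the output tensor it produces

# V: operations keyed by name; E: directed data-flow edges (u, v),
# meaning u's output tensor is consumed by v.
V = {
    "embed":  OpCosts(work=1e-4, size_param=4e8,   size_out=2e6),
    "attn_0": OpCosts(work=3e-4, size_param=6e8,   size_out=2e6),
    "mlp_0":  OpCosts(work=5e-4, size_param=1.2e9, size_out=2e6),
    "head":   OpCosts(work=2e-4, size_param=4e8,   size_out=1e6),
}
E = [("embed", "attn_0"), ("attn_0", "mlp_0"), ("mlp_0", "head")]
```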

Partitioning this graph involves dividing $V$ into $k$ distinct blocks such that each block can be processed on a different node in a swarm, under the constraint that the induced quotient graph of $G$ remains acyclic. This division aims to maximize throughput while minimizing inter-node communication, subject to the bandwidth $B$ between nodes, with the I/O cost from block $S$ to block $T$ given by:

$$
\text{io}(S,T) = \frac{1}{B} \sum_{v \in N^-(T) \cap S} \text{sizeout}(v),
$$

where $N^-(T)$ represents the set of operations whose outputs are consumed by block $T$.
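
A minimal sketch of this I/O cost follows, assuming a simple edge-list representation of $G$; the function name, operation names, sizes, and bandwidth are all illustrative.

```python
# Sketch of io(S, T): bytes produced in S and consumed in T, divided by B.
def io_cost(S, T, edges, size_out, B):
    """Sum sizeout(v) over producers v in N^-(T) that belong to S."""
    producers = {u for (u, v) in edges if v in T}   # N^-(T): in-neighbors of T
    return sum(size_out[v] for v in producers & set(S)) / B

# Illustrative values only.
size_out = {"attn_0": 2e6, "mlp_0": 2e6}            # bytes per output tensor
edges = [("attn_0", "mlp_0"), ("mlp_0", "head")]
B = 1e9                                              # bytes/sec between nodes
print(io_cost({"attn_0"}, {"mlp_0", "head"}, edges, size_out, B))  # 0.002 s
```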

The core challenge lies in efficiently distributing the model's parameters and activations across the available fast memory (e.g., SRAM) of each node. Parameters that do not fit in fast memory must be streamed from slower storage, which introduces additional latency. The overflow cost, which represents the time to stream parameters exceeding the fast memory limit $M$, is calculated as:

$$
\text{overflow}(S) = \left(\text{sizeparam}(S) + \text{peak}(S) - M\right) + \frac{\text{peak}(S)}{B},
$$

where $\text{peak}(S)$ denotes the peak memory requirement for activations within block $S$.
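
Transcribed directly into code, the overflow cost might look like the sketch below. The values are illustrative, and the note about treating the excess term as zero when a block fits entirely in fast memory is our assumption rather than part of the formula above.

```python
def overflow_cost(size_param_S, peak_S, M, B):
    # Literal transcription of the overflow formula above; in practice one
    # would presumably treat the first term as zero when the block's
    # parameters and activations fit within M (an assumption, not stated).
    return (size_param_S + peak_S - M) + peak_S / B

# Illustrative values: 1.8 GB of parameters, 0.2 GB peak activations,
# 1.6 GB of fast memory, 1 GB/s bandwidth.
print(overflow_cost(size_param_S=1.8e9, peak_S=2e8, M=1.6e9, B=1e9))
```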

The overall block cost, $f(S)$, combines the costs of receiving input tensors, executing the block's operations (including any overflow cost from streaming parameters), and sending output tensors downstream:

$$
f(S) = \text{io}(V\setminus S,\, S) + \sum_{v \in S} \text{work}(v) + \text{overflow}(S) + \text{io}(S,\, V\setminus S).
$$
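
Putting the pieces together, a self-contained sketch of $f(S)$ could look like the following. The helpers repeat the earlier formulas so the snippet runs on its own; every structure, name, and value is illustrative.

```python
def io_cost(src, dst, edges, size_out, B):
    producers = {u for (u, v) in edges if v in dst}          # N^-(dst)
    return sum(size_out[v] for v in producers & set(src)) / B

def overflow_cost(size_param_S, peak_S, M, B):
    return (size_param_S + peak_S - M) + peak_S / B          # formula as written above

def block_cost(S, V, edges, work, size_param, size_out, peak_S, M, B):
    """f(S): receive inputs, run the block (with overflow), send outputs."""
    S, rest = set(S), set(V) - set(S)
    return (io_cost(rest, S, edges, size_out, B)             # io(V \ S, S): receive inputs
            + sum(work[v] for v in S)                        # execute the block's operations
            + overflow_cost(sum(size_param[v] for v in S), peak_S, M, B)
            + io_cost(S, rest, edges, size_out, B))          # io(S, V \ S): send outputs
```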

The goal of partitioning, defined by the Max-Throughput Partitioning Problem (MTPP), is to minimize the maximum cost across all blocks, thereby optimizing the throughput of the entire pipeline. Formally, MTPP seeks a partition $P^*$ that minimizes the bottleneck cost:

$$
P^* = \text{argmin}_{P \in P_k(G)} \left\{ \max_{i\in[k]} f(P_i) \right\},
$$

where $P_k(G)$ denotes the set of all possible partitions of $G$ into $k$ blocks, and $\text{cost}^* = \max_{i\in[k]} f(P^*_i)$ is the minimum achievable bottleneck cost across these partitions.
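
For a tiny graph, the bottleneck objective can be checked by brute force. The sketch below enumerates candidate partitions and keeps the one whose most expensive block is cheapest. As a simplifying assumption beyond the text, candidates are restricted to contiguous blocks of one fixed topological order (which always yields an acyclic quotient graph), and the toy cost function merely stands in for $f$.

```python
# Brute-force check of the MTPP objective on a toy chain of operations.
from itertools import combinations

def mtpp_bruteforce(topo_order, k, f):
    """Return (best_partition, bottleneck_cost) minimizing max_i f(P_i)."""
    n = len(topo_order)
    best, best_cost = None, float("inf")
    for cuts in combinations(range(1, n), k - 1):            # k-1 cut points
        bounds = (0, *cuts, n)
        blocks = [topo_order[bounds[i]:bounds[i + 1]] for i in range(k)]
        cost = max(f(block) for block in blocks)             # bottleneck cost
        if cost < best_cost:
            best, best_cost = blocks, cost
    return best, best_cost

# Toy usage with a stand-in cost: f(S) = total "work" of the block.
work = {"embed": 1, "attn_0": 3, "mlp_0": 5, "head": 2}
f = lambda S: sum(work[v] for v in S)
print(mtpp_bruteforce(["embed", "attn_0", "mlp_0", "head"], k=2, f=f))
```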
