Model Partitioning and Deep Network Sharding
Running deep learning algorithms on a blockchain is a new setting, and Nesa's partitioning mechanism is designed for it. The mechanism aims to minimize the amount of data transferred between nodes, which reduces latency and bandwidth requirements in both the inference and training phases. The overall throughput of inference and training is limited by the slowest part of the system, the bottleneck stage, so the partitioning must also balance the computational load across all nodes. When the load is balanced, every stage runs close to its capacity and no single stage holds back the rest of the pipeline.
For distributed inference and training across multiple nodes, we use the available fast memory, such as Static Random-Access Memory (SRAM), to store intermediate model outputs (activations) and parameter weights. SRAM is significantly faster than conventional memory such as DRAM, but it is also more expensive and limited in size. Because each node in the network contributes its own SRAM, the total fast memory available to a distributed inference or training session grows with the number of participating nodes. This allows more model parameters to be cached in fast memory, which is critical for performance.
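Because throughput is set by the slowest stage, the load-balancing goal can be pictured as a contiguous "linear partition" of the model's layers: choose cut points so that the largest per-stage cost is as small as possible, while every stage still fits in one node's fast memory. The sketch below solves that toy version with a simple dynamic program. It assumes hypothetical inputs (a per-layer compute cost, a per-layer memory footprint, and one uniform SRAM capacity per node); the function name is illustrative and this is not Nesa's actual partitioning mechanism.

```python
from math import inf
from typing import List, Tuple

def partition_layers(
    layer_cost: List[float],
    layer_mem: List[float],
    num_stages: int,
    mem_capacity: float,
) -> Tuple[float, List[Tuple[int, int]]]:
    """Split layers into at most num_stages contiguous stages, minimizing the
    bottleneck (largest stage cost) subject to each stage fitting in one node's
    fast memory. Returns (bottleneck, [(start, end), ...]) with half-open
    layer ranges, or (inf, []) if no feasible split exists."""
    n = len(layer_cost)
    # Prefix sums give the cost and memory of layers[i:j] in O(1).
    pc = [0.0] * (n + 1)
    pm = [0.0] * (n + 1)
    for i in range(n):
        pc[i + 1] = pc[i] + layer_cost[i]
        pm[i + 1] = pm[i] + layer_mem[i]
    # best[s][j]: smallest bottleneck when the first j layers use exactly s stages.
    best = [[inf] * (n + 1) for _ in range(num_stages + 1)]
    cut = [[-1] * (n + 1) for _ in range(num_stages + 1)]
    best[0][0] = 0.0
    for s in range(1, num_stages + 1):
        for j in range(1, n + 1):
            for i in range(s - 1, j):
                if pm[j] - pm[i] > mem_capacity:
                    continue  # layers[i:j] would not fit in one node's memory
                cand = max(best[s - 1][i], pc[j] - pc[i])
                if cand < best[s][j]:
                    best[s][j] = cand
                    cut[s][j] = i
    # Using fewer stages than available nodes is allowed if it is not worse.
    s_best = min(range(1, num_stages + 1), key=lambda s: best[s][n])
    if best[s_best][n] == inf:
        return inf, []
    stages = []
    j = n
    for s in range(s_best, 0, -1):
        i = cut[s][j]
        stages.append((i, j))
        j = i
    return best[s_best][n], stages[::-1]
```

For example, `partition_layers([4, 2, 3, 5, 1, 3], [1] * 6, 3, 3.0)` splits six layers over three nodes as `[(0, 2), (2, 4), (4, 6)]` with a bottleneck cost of 8. The triple loop is O(K·n²) for K stages and n layers, which is fine for the few dozen repeated blocks of a typical model.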
Nesa developed a new approach to network topology-informed model sharding, named Blockchain-based Sequential deep neural network Sharding (BSNS). Our approach establishes a sequence of nodes, each of which typically hosts one of the repeated blocks that make up the architecture. For each inference session involving block sharding, it is crucial to find a chain of nodes that can collectively reconstruct the layers of the full model [19]. The key innovation of our approach is that the selection of nodes for sharding is informed by (i) the topology of the blockchain; (ii) our persistent homology metrics based on graph embeddings; and (iii) network-based variables, including latency and geographical distance between nodes. Taken together, the selected chain constitutes a full end-to-end neural network that performs inference across the blockchain at the optimal speed while fully embedding the blockchain's topology and security components.
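The requirement that a chain of nodes collectively reconstructs the full model can be illustrated as a shortest-path search over block positions: each node holds a contiguous range of blocks, and the chain must carry the computation from block 0 to the last block at minimal total time. The sketch below is a minimal Dijkstra search under stated assumptions: the input formats (`servers`, `latency`, a `"client"` source node) and the function name are hypothetical, a node always runs through the end of its range, and only latency plus per-block compute time enter the cost. The topology and persistent-homology terms that BSNS uses are not modelled here.

```python
import heapq
from typing import Dict, List, Tuple

def find_chain(
    servers: Dict[str, Tuple[int, int, float]],
    latency: Dict[Tuple[str, str], float],
    num_blocks: int,
    source: str = "client",
) -> Tuple[float, List[str]]:
    """Return (total_time, [node, ...]) for the cheapest chain covering blocks 0..num_blocks-1.

    servers[name] = (first_block, end_block, time_per_block): the node holds the
    half-open block range [first_block, end_block). latency[(a, b)] is the delay
    from a to b; pairs missing from the dict are treated as unreachable.
    """
    # Each heap entry: (time so far, next block to run, last node, path so far).
    heap: List[Tuple[float, int, str, List[str]]] = [(0.0, 0, source, [])]
    settled = set()
    while heap:
        cost, pos, last, path = heapq.heappop(heap)
        if pos == num_blocks:
            return cost, path  # Dijkstra: the first complete chain popped is the cheapest
        if (pos, last) in settled:
            continue
        settled.add((pos, last))
        for name, (start, end, per_block) in servers.items():
            # A node can continue the chain only if it holds the next block to run.
            if not (start <= pos < end) or (last, name) not in latency:
                continue
            step = latency[(last, name)] + per_block * (end - pos)
            heapq.heappush(heap, (cost + step, end, name, path + [name]))
    return float("inf"), []
```

With `servers = {"A": (0, 2, 1.0), "B": (0, 4, 3.0), "C": (2, 4, 1.0)}` and latencies `client→A = 1`, `client→B = 1`, `A→C = 2`, the search returns `(7.0, ["A", "C"])`: two fast partial nodes beat one slow node that holds the whole model. The return trip to the client is not counted in this sketch.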
Our approach is scalable: because the model is distributed across multiple nodes, it can run larger models than any single node could hold. Network resources are used efficiently, since each node stores only a portion of the full architecture and performs only that portion of the computation. Importantly, computational resources are also optimized with respect to the network structure, which is embedded as a constraint in the node-selection process that assembles the full architecture from the individual shards. This has strong practical benefits for our clients, since distance and latency heavily affect performance when computing on a blockchain.
Given a network of nodes $N = \{n_1, n_2, \ldots, n_M\}$ that need to execute the shards of a model made of $L$ blocks $b_1, \ldots, b_L$, and assuming that each block is held by one network node, the task involves finding a sequence of nodes

$$S = (s_1, s_2, \ldots, s_L), \qquad s_i \in N,$$

where node $s_i$ executes block $b_i$.
We can model this as a recursive sequence-finding problem, in which the selection of each node depends on the previously selected node(s) and on the network parameters:

$$s_i = f\left(s_{i-1}, s_{i-2}, \ldots, s_{i-k};\ \theta\right),$$

where $k$ sets the context length for network inference and determines how the network parameters $\theta$ are computed. If $k = 1$, the network parameters involve only distance metrics between the current node and the previous node; for $k > 1$, several preceding nodes are considered, so topology metrics can be computed to inform the selection of the current node.
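To make the role of the context length $k$ concrete, the sketch below implements one possible greedy $f$: for each block it picks, among the nodes holding that block, the candidate with the smallest mean distance to the last $k$ selected nodes. This scoring rule is an assumption standing in for $\theta$; it is not the persistent-homology metric described above, and the names `select_sequence`, `holders`, `dist`, and `start` are hypothetical. With $k = 1$ it reduces to the distance from the previous node only, as in the text.

```python
from typing import Dict, List, Sequence, Tuple

def select_sequence(
    holders: Sequence[List[str]],
    dist: Dict[Tuple[str, str], float],
    k: int,
    start: str,
) -> List[str]:
    """Greedy sketch of s_i = f(s_{i-1}, ..., s_{i-k}; theta).

    holders[i] lists the nodes holding block i; dist stores a symmetric distance
    (e.g. latency) for each unordered pair, in either key order. The candidate for
    block i is scored by its mean distance to the last k selected nodes, so with
    k == 1 only the previous node matters.
    """
    def d(a: str, b: str) -> float:
        if a == b:
            return 0.0  # the same node keeps consecutive blocks: no transfer
        return dist[(a, b)] if (a, b) in dist else dist[(b, a)]

    chain: List[str] = [start]  # start is the requesting node, not a shard holder
    for candidates in holders:
        context = chain[-k:]  # fewer than k nodes are available near the start
        chain.append(min(candidates, key=lambda c: sum(d(p, c) for p in context) / len(context)))
    return chain[1:]
```

For instance, with `holders = [["A"], ["B", "C"]]` and distances `U–A = 1`, `A–B = 1`, `A–C = 2`, `U–B = 10`, `U–C = 1` (start node `U`), $k = 1$ yields `["A", "B"]` while $k = 2$ yields `["A", "C"]`: the longer context penalizes `B` for being far from the earlier part of the chain.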
To adapt to changing network conditions and node performance, the BSNS framework includes a dynamic rebalancing algorithm. The algorithm periodically evaluates each node's performance and the prevailing network conditions, and uses this evaluation to reallocate model shards among nodes. The objective is to minimize latency and maximize throughput, thereby optimizing the performance of the distributed system.
This rebalancing operation is given by:

$$S' = R(S, \theta),$$

where $R$ is the rebalancing function, which transforms the current sequence of nodes $S$, given the network parameters $\theta$, into a new, optimized sequence $S'$. These parameters capture essential network characteristics, including node and edge positions in the network, hardware capacity, latency, and bandwidth, which collectively influence the data transmission speeds between nodes.
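A single rebalancing step $S' = R(S, \theta)$ can be sketched as follows, assuming that $\theta$ is reduced to one recently measured step time (compute plus network) per node. For each block, the sketch moves the shard to the fastest node that holds it, but only when that node is faster by more than a configurable margin, so that small fluctuations do not cause constant reallocation. The function name, the `step_time` input, and the hysteresis margin are hypothetical choices for illustration; the sketch ignores inter-node latency and would be called periodically by some scheduler that is not shown.

```python
from typing import Dict, List, Sequence

def rebalance(
    sequence: List[str],
    holders: Sequence[List[str]],
    step_time: Dict[str, float],
    margin: float = 0.2,
) -> List[str]:
    """Sketch of S' = R(S, theta): return a new node sequence for the same blocks.

    sequence[i] is the node currently running block i; holders[i] lists every node
    able to run it. step_time[n] is node n's latest measured time, standing in for
    theta; a node with no measurement (e.g. it went offline) counts as infinitely slow.
    """
    def t(node: str) -> float:
        return step_time.get(node, float("inf"))

    new_sequence = list(sequence)
    for i, current in enumerate(sequence):
        best = min(holders[i], key=t)
        # Hysteresis: switch only if the alternative beats the current node by > margin.
        if t(best) < (1 - margin) * t(current):
            new_sequence[i] = best
    return new_sequence
```

With `sequence = ["A", "B", "C"]`, `holders = [["A", "D"], ["B"], ["C", "E"]]` and measured times `A = 1.0`, `D = 0.9`, `B = 2.0`, `C = 3.0`, `E = 1.0`, only block 2 moves (to `E`): `D` is faster than `A`, but not by more than 20%.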