Implementation: Split Learning (HE)

Recognizing the challenges posed by encrypting data for use in decentralized inference systems, Nesa adopts Split Learning (SL) as a pragmatic solution to facilitate secure and efficient computation on encrypted data. Traditional encryption methods, while securing data at rest and in transit, render it unusable for direct computation by obscuring its format and structure. This limitation is particularly problematic for processing with LLMs within a decentralized framework, where data privacy cannot be compromised.

Split Learning addresses these concerns by partitioning the model itself, so that data can be processed in stages without revealing sensitive information. In essence, user data is protected because it is never transmitted directly to any node: only intermediate embeddings are passed between nodes, and each node sees the embeddings of only the layers it is assigned.

Consider a neural network model ( \mathcal{N} ), such as Llama 2, composed of a sequence of 32 layers ( \{L_1, L_2, \ldots, L_{32}\} ), each with its own set of parameters ( \Theta_i ) and activation function ( \sigma_i ). The input to the network is ( X ), and the output of the ( i )-th layer, given input ( x_i ), is:

$$a_i = L_i(x_i; \Theta_i) = \sigma_i(W_i x_i + b_i)$$

where ( W_i ) and ( b_i ) are the weight matrix and bias vector of the ( i )-th layer, respectively, and ( \sigma_i ) is a nonlinear activation function such as ReLU, sigmoid, or tanh.
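As a minimal sketch of this per-layer computation (the dimensions and activation below are illustrative placeholders, not Llama 2's actual configuration):

```python
import torch
import torch.nn as nn

# One layer L_i: a_i = sigma_i(W_i x_i + b_i). Sizes are placeholders.
layer_i = nn.Linear(in_features=4096, out_features=4096)  # holds W_i and b_i
sigma_i = nn.ReLU()                                        # example nonlinearity

x_i = torch.randn(1, 4096)    # input to the i-th layer
a_i = sigma_i(layer_i(x_i))   # a_i = sigma_i(W_i x_i + b_i)
```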

Assume the model is split at layer ( k ), so that the client handles layers ( \{L_1, \ldots, L_k\} ) and the server handles layers ( \{L_{k+1}, \ldots, L_{32}\} ). The client computes the intermediate representation ( Z ) as follows:

$$Z = \sigma_k(W_k \cdot \sigma_{k-1}( \ldots \sigma_1(W_1 X + b_1) \ldots ) + b_k)$$

This intermediate representation ( Z ) is then transmitted to the server, which continues the computation:

$$Y = \sigma_{32}(W_{32} \cdot \sigma_{31}( \ldots \sigma_{k+1}(W_{k+1} Z + b_{k+1}) \ldots ) + b_{32})$$
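A minimal sketch of this split forward pass is given below. The 32 "layers" are stand-in linear blocks rather than the real Llama 2 architecture, and the split point ( k ) is chosen arbitrarily for illustration; in practice the client and server would hold disjoint parameter shards rather than sharing one process.

```python
import torch
import torch.nn as nn

# Stand-in model: 32 linear blocks with a shared activation.
layers = nn.ModuleList([nn.Linear(4096, 4096) for _ in range(32)])
sigma = nn.ReLU()
k = 8  # split point: client runs L_1..L_k, server runs L_{k+1}..L_32

def client_forward(x: torch.Tensor) -> torch.Tensor:
    """Client side: compute the intermediate representation Z from raw input X."""
    for layer in layers[:k]:
        x = sigma(layer(x))
    return x  # Z -- the only tensor that leaves the client

def server_forward(z: torch.Tensor) -> torch.Tensor:
    """Server side: continue from Z through the remaining layers to obtain Y."""
    for layer in layers[k:]:
        z = sigma(layer(z))
    return z  # Y

X = torch.randn(1, 4096)   # raw user input (never leaves the client)
Z = client_forward(X)      # transmitted to the server node
Y = server_forward(Z)      # final output computed server-side
```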

During training, a loss function ( \mathcal{L}(Y, Y_{true}) ) measures the error between the network output ( Y ) and the true labels ( Y_{true} ), and the gradient of the loss with respect to each layer's parameters ( \Theta_i ) is obtained through backpropagation via the chain rule:

$$\frac{\partial \mathcal{L}}{\partial \Theta_i} = \text{ChainRule}\left(\frac{\partial \mathcal{L}}{\partial Y}, \frac{\partial Y}{\partial a_{32}}, \ldots, \frac{\partial a_i}{\partial \Theta_i}\right)$$
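Across the split, this backward pass is coordinated by sending the gradient of ( Z ) back from the server to the client. The sketch below builds on the forward sketch above; the MSE loss and random target are illustrative stand-ins.

```python
import torch

# Assumes client_forward, server_forward, and X from the previous sketch.
loss_fn = torch.nn.MSELoss()
Y_true = torch.randn(1, 4096)  # illustrative labels

Z = client_forward(X)
Z_server = Z.detach().requires_grad_(True)  # server receives Z, not the client's graph

Y = server_forward(Z_server)
loss = loss_fn(Y, Y_true)
loss.backward()                 # server computes dL/dTheta_i for layers k+1..32

grad_Z = Z_server.grad          # only this gradient is sent back to the client
Z.backward(grad_Z)              # client continues the chain rule for layers 1..k
```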

To address privacy concerns during the transmission of ( Z ) from client to server, differential privacy methods may be applied. Defining a privacy metric ( \mathcal{P} ) that quantifies the information leakage from the intermediate representation ( Z ), a proof of privacy preservation would show that, under an ( \epsilon )-differential-privacy guarantee, the leakage remains below the threshold:

$$\mathcal{P}(Z) \leq \epsilon$$

Note that applying differential privacy on top of SL improves security at the cost of inference quality. In Nesa's framework, the noise level is therefore exposed as a tunable parameter, chosen according to user requirements.
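A minimal sketch of such a mechanism is shown below, assuming Laplace noise calibrated to an ( \epsilon ) budget applied to ( Z ) before transmission. The clipping step, sensitivity bound, and choice of mechanism are illustrative assumptions, not Nesa's exact implementation.

```python
import torch

def add_dp_noise(z: torch.Tensor, epsilon: float, sensitivity: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: perturb the intermediate representation Z with
    Laplace noise scaled to an epsilon budget before sending it to the server."""
    # Clip Z so its L2 norm (and hence sensitivity) is bounded.
    norm = z.norm(p=2)
    z = z * torch.clamp(sensitivity / (norm + 1e-12), max=1.0)
    # Laplace noise with scale sensitivity / epsilon.
    scale = sensitivity / epsilon
    noise = torch.distributions.Laplace(0.0, scale).sample(z.shape)
    return z + noise

# Smaller epsilon -> stronger privacy, noisier Z, lower inference quality.
Z_private = add_dp_noise(Z, epsilon=1.0)  # Z from the forward-pass sketch above
```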

By leveraging Split Learning, Nesa effectively navigates the complexities of data encryption within its decentralized inference system for LLMs. This approach not only preserves the confidentiality and integrity of user data but also ensures the operational feasibility of complex model computations, demonstrating a sophisticated balance between privacy preservation and computational pragmatism.
